From 2fec1d09e1c04df1c733dd14f76cc20e0bbd5ba0 Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Fri, 13 Aug 2021 16:09:20 +0530 Subject: [PATCH] separate E2EFromCSR and M2LUsingTranslationClassesDependentData --- sumpy/e2e.py | 202 ++++++++++++++++++++++++++++++++++++++++++--------- sumpy/fmm.py | 13 ++-- 2 files changed, 175 insertions(+), 40 deletions(-) diff --git a/sumpy/e2e.py b/sumpy/e2e.py index 8d6ecf95..5f4897e5 100644 --- a/sumpy/e2e.py +++ b/sumpy/e2e.py @@ -150,12 +150,9 @@ class E2EFromCSR(E2EBase): default_name = "e2e_from_csr" def __init__(self, ctx, src_expansion, tgt_expansion, - name=None, device=None, - m2l_use_translation_classes_dependent_data=False): + name=None, device=None): super().__init__(ctx, src_expansion, tgt_expansion, name=name, device=device) - self.m2l_use_translation_classes_dependent_data = \ - m2l_use_translation_classes_dependent_data def get_translation_loopy_insns(self, result_dtype): from sumpy.symbolic import make_sym_vector @@ -165,18 +162,167 @@ class E2EFromCSR(E2EBase): tgt_rscale = sym.Symbol("tgt_rscale") - extra_kwargs = {} - if self.m2l_use_translation_classes_dependent_data: - m2l_translation_classes_dependent_ndata = \ + ncoeff_src = len(self.src_expansion) + + src_coeff_exprs = [sym.Symbol("src_coeff%d" % i) + for i in range(ncoeff_src)] + + from sumpy.assignment_collection import SymbolicAssignmentCollection + sac = SymbolicAssignmentCollection() + tgt_coeff_names = [ + sac.assign_unique("coeff%d" % i, coeff_i) + for i, coeff_i in enumerate( + self.tgt_expansion.translate_from( + self.src_expansion, src_coeff_exprs, src_rscale, + dvec, tgt_rscale, sac))] + + sac.run_global_cse() + + from sumpy.codegen import to_loopy_insns + return to_loopy_insns( + sac.assignments.items(), + vector_names=set(["d"]), + pymbolic_expr_maps=[self.tgt_expansion.get_code_transformer()], + retain_names=tgt_coeff_names, + complex_dtype=to_complex_dtype(result_dtype), + ) + + def get_kernel(self, result_dtype): + ncoeff_src = len(self.src_expansion) + ncoeff_tgt = len(self.tgt_expansion) + + # To clarify terminology: + # + # isrc_box -> The index in a list of (in this case, source) boxes + # src_ibox -> The (global) box number for the (in this case, source) box + # + # (same for itgt_box, tgt_ibox) + + from sumpy.tools import gather_loopy_arguments + loopy_knl = lp.make_kernel( + [ + "{[itgt_box]: 0<=itgt_box tgt_ibox = target_boxes[itgt_box] + <> tgt_center[idim] = centers[idim, tgt_ibox] + <> isrc_start = src_box_starts[itgt_box] + <> isrc_stop = src_box_starts[itgt_box+1] + + for isrc_box + <> src_ibox = src_box_lists[isrc_box] \ + {id=read_src_ibox} + <> src_center[idim] = centers[idim, src_ibox] {dup=idim} + <> d[idim] = tgt_center[idim] - src_center[idim] \ + {dup=idim} + """] + [""" + <> src_coeff{coeffidx} = \ + src_expansions[src_ibox - src_base_ibox, {coeffidx}] \ + {{dep=read_src_ibox}} + """.format(coeffidx=i) for i in range(ncoeff_src)] + [ + ] + self.get_translation_loopy_insns() + [""" + end + + """] + [f""" + tgt_expansions[tgt_ibox - tgt_base_ibox, {coeffidx}] = \ + simul_reduce(sum, isrc_box, coeff{coeffidx}) \ + {{id_prefix=write_expn}} + """ for coeffidx in range(ncoeff_tgt)] + [""" + end + """], + [ + lp.GlobalArg("centers", None, shape="dim, aligned_nboxes"), + lp.ValueArg("src_rscale,tgt_rscale", None), + lp.GlobalArg("src_box_starts, src_box_lists", + None, shape=None, strides=(1,), offset=lp.auto), + lp.ValueArg("aligned_nboxes,tgt_base_ibox,src_base_ibox", + np.int32), + lp.ValueArg("nsrc_level_boxes,ntgt_level_boxes", + np.int32), + lp.GlobalArg("src_expansions", None, + shape=("nsrc_level_boxes", ncoeff_src), offset=lp.auto), + lp.GlobalArg("tgt_expansions", None, + shape=("ntgt_level_boxes", ncoeff_tgt), offset=lp.auto), + "..." + ] + gather_loopy_arguments([self.src_expansion, + self.tgt_expansion]), + name=self.name, + assumptions="ntgt_boxes>=1", + silenced_warnings="write_race(write_expn*)", + default_offset=lp.auto, + fixed_parameters=dict(dim=self.dim), + lang_version=MOST_RECENT_LANGUAGE_VERSION + ) + + for knl in [self.src_expansion.kernel, self.tgt_expansion.kernel]: + loopy_knl = knl.prepare_loopy_kernel(loopy_knl) + + loopy_knl = lp.tag_inames(loopy_knl, "idim*:unr") + loopy_knl = lp.set_options(loopy_knl, + enforce_variable_access_ordered="no_check") + + return loopy_knl + + def get_optimized_kernel(self): + # FIXME + knl = self.get_kernel() + knl = lp.split_iname(knl, "itgt_box", 16, outer_tag="g.0") + + return knl + + def get_cache_key(self): + return ( + type(self).__name__, + self.src_expansion, + self.tgt_expansion, + ) + + def __call__(self, queue, **kwargs): + """ + :arg src_expansions: + :arg src_box_starts: + :arg src_box_lists: + :arg src_rscale: + :arg tgt_rscale: + :arg centers: + """ + centers = kwargs.pop("centers") + # "1" may be passed for rscale, which won't have its type + # meaningfully inferred. Make the type of rscale explicit. + src_rscale = centers.dtype.type(kwargs.pop("src_rscale")) + tgt_rscale = centers.dtype.type(kwargs.pop("tgt_rscale")) + + knl = self.get_cached_optimized_kernel() + + return knl(queue, + centers=centers, + src_rscale=src_rscale, tgt_rscale=tgt_rscale, + **kwargs) + + +class M2LUsingTranslationClassesDependentData(E2EFromCSR): + """Implements translation from a "compressed sparse row"-like source box + list using M2L translation classes dependent data + """ + + default_name = "m2l_using_translation_classes_dependent_data" + + def get_translation_loopy_insns(self, result_dtype): + from sumpy.symbolic import make_sym_vector + dvec = make_sym_vector("d", self.dim) + + src_rscale = sym.Symbol("src_rscale") + tgt_rscale = sym.Symbol("tgt_rscale") + + m2l_translation_classes_dependent_ndata = \ self.tgt_expansion.m2l_translation_classes_dependent_ndata( self.src_expansion) - m2l_translation_classes_dependent_data = \ + m2l_translation_classes_dependent_data = \ [sym.Symbol("m2l_translation_classes_dependent_expr%d" % i) for i in range(m2l_translation_classes_dependent_ndata)] - extra_kwargs["m2l_translation_classes_dependent_data"] = \ - m2l_translation_classes_dependent_data - else: - m2l_translation_classes_dependent_ndata = 0 if self.use_preprocessing_for_m2l: ncoeff_src = self.tgt_expansion.m2l_preprocess_multipole_nexprs( @@ -194,7 +340,9 @@ class E2EFromCSR(E2EBase): for i, coeff_i in enumerate( self.tgt_expansion.translate_from( self.src_expansion, src_coeff_exprs, src_rscale, - dvec, tgt_rscale, sac, **extra_kwargs))] + dvec, tgt_rscale, sac, + m2l_translation_classes_dependent_data=( + m2l_translation_classes_dependent_data)))] sac.run_global_cse() @@ -219,12 +367,9 @@ class E2EFromCSR(E2EBase): """ ncoeff_tgt = len(self.tgt_expansion) - if self.m2l_use_translation_classes_dependent_data: - m2l_translation_classes_dependent_ndata = \ + m2l_translation_classes_dependent_ndata = \ self.tgt_expansion.m2l_translation_classes_dependent_ndata( self.src_expansion) - else: - m2l_translation_classes_dependent_ndata = 0 if self.use_preprocessing_for_m2l: ncoeff_tgt = m2l_translation_classes_dependent_ndata @@ -268,12 +413,9 @@ class E2EFromCSR(E2EBase): return insns def get_kernel(self, result_dtype): - if self.m2l_use_translation_classes_dependent_data: - m2l_translation_classes_dependent_ndata = \ + m2l_translation_classes_dependent_ndata = \ self.tgt_expansion.m2l_translation_classes_dependent_ndata( self.src_expansion) - else: - m2l_translation_classes_dependent_ndata = 0 if self.use_preprocessing_for_m2l: # number of expressions given as input to M2L after preprocessing @@ -316,12 +458,11 @@ class E2EFromCSR(E2EBase): <> src_center[idim] = centers[idim, src_ibox] {dup=idim} <> d[idim] = tgt_center[idim] - src_center[idim] \ {dup=idim} - """] + ([""" <> translation_class = \ m2l_translation_classes_lists[isrc_box] <> translation_class_rel = translation_class - \ translation_classes_level_start - """] if m2l_translation_classes_dependent_ndata != 0 else []) + [""" + """] + [""" <> m2l_translation_classes_dependent_expr{idx} = \ m2l_translation_classes_dependent_data[ \ translation_class_rel, {idx}] @@ -359,8 +500,6 @@ class E2EFromCSR(E2EBase): shape=("nsrc_level_boxes", ncoeff_src), offset=lp.auto), lp.GlobalArg("tgt_expansions", None, shape=("ntgt_level_boxes", ncoeff_tgt), offset=lp.auto), - "..." - ] + ([ lp.ValueArg("translation_classes_level_start", np.int32), lp.GlobalArg("m2l_translation_classes_dependent_data", None, @@ -372,8 +511,9 @@ class E2EFromCSR(E2EBase): offset=lp.auto), lp.ValueArg("ntranslation_classes, ntranslation_classes_lists", np.int32), - ] if m2l_translation_classes_dependent_ndata != 0 else []) - + gather_loopy_arguments([self.src_expansion, self.tgt_expansion]), + "..." + ] + gather_loopy_arguments([self.src_expansion, + self.tgt_expansion]), name=self.name, assumptions="ntgt_boxes>=1", silenced_warnings="write_race(write_expn*)", @@ -400,14 +540,6 @@ class E2EFromCSR(E2EBase): return knl - def get_cache_key(self): - return ( - type(self).__name__, - self.src_expansion, - self.tgt_expansion, - self.m2l_use_translation_classes_dependent_data, - ) - def __call__(self, queue, **kwargs): """ :arg src_expansions: diff --git a/sumpy/fmm.py b/sumpy/fmm.py index 6948cc9d..26acf713 100644 --- a/sumpy/fmm.py +++ b/sumpy/fmm.py @@ -36,7 +36,8 @@ from sumpy import ( P2EFromSingleBox, P2EFromCSR, E2PFromSingleBox, E2PFromCSR, P2PFromCSR, - E2EFromCSR, E2EFromChildren, E2EFromParent, + E2EFromCSR, M2LUsingTranslationClassesDependentData, + E2EFromChildren, E2EFromParent, M2LGenerateTranslationClassesDependentData, M2LPreprocessMultipole) from sumpy.tools import to_complex_dtype @@ -124,11 +125,13 @@ class SumpyExpansionWranglerCodeContainer: @memoize_method def m2l(self, src_order, tgt_order, m2l_use_translation_classes_dependent_data=False): - return E2EFromCSR(self.cl_context, + if m2l_use_translation_classes_dependent_data: + m2l_class = M2LUsingTranslationClassesDependentData + else: + m2l_class = E2EFromCSR + return m2l_class(self.cl_context, self.multipole_expansion(src_order), - self.local_expansion(tgt_order), - m2l_use_translation_classes_dependent_data=( - m2l_use_translation_classes_dependent_data)) + self.local_expansion(tgt_order)) @memoize_method def m2l_translation_class_dependent_data_kernel(self, src_order, tgt_order): -- GitLab