From 48bdc5ef3fd17272b5824221d17337ec30780b9d Mon Sep 17 00:00:00 2001 From: Isuru Fernando <idf2@illinois.edu> Date: Sun, 27 Mar 2022 22:28:35 -0700 Subject: [PATCH] Direct loopy kernel (#95) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * post process kernel * Fix postprocesslocal * use preprocess/postprocess everywhere * Fix event management for m2l * Use a loopy kernel * return a loopy kernel * fix formatting * reduce diff * Fix typo * Fix typo * use_preprocessing_for_m2l -> use_fft_for_m2l * supports_optimized_m2l -> supports_translation_classes * restore non fft optimized code path * events -> timing_events * More descriptive exception message * use lp.Assignment instead of strings * Fix typo * Add a comment about symbolic sum Co-authored-by: Andreas Klöckner <inform@tiker.net> * Fix error message * split too long line * Make use_preprocessing_for_m2l an option * Add an comment about the options * fix syntax errors * add missing defines * Fix hashing * Fix returning nexprs * fix use_preprocessing_for_m2l * Fix ncoeff_src and ncoeff_tgt * fix use_preprocessing_for_m2l for bessel based * Fix logic Co-authored-by: Andreas Klöckner <inform@tiker.net> --- sumpy/__init__.py | 5 +- sumpy/e2e.py | 330 +++++++++++++++++++++++---------------- sumpy/expansion/local.py | 237 ++++++++++++++++++++-------- sumpy/fmm.py | 112 +++++++++---- test/test_fmm.py | 2 +- 5 files changed, 458 insertions(+), 228 deletions(-) diff --git a/sumpy/__init__.py b/sumpy/__init__.py index 30b0cd66..b39ead2e 100644 --- a/sumpy/__init__.py +++ b/sumpy/__init__.py @@ -26,7 +26,8 @@ from sumpy.p2e import P2EFromSingleBox, P2EFromCSR from sumpy.e2p import E2PFromSingleBox, E2PFromCSR from sumpy.e2e import (E2EFromCSR, E2EFromChildren, E2EFromParent, M2LUsingTranslationClassesDependentData, - M2LGenerateTranslationClassesDependentData, M2LPreprocessMultipole) + M2LGenerateTranslationClassesDependentData, M2LPreprocessMultipole, + M2LPostprocessLocal) from sumpy.version import VERSION_TEXT from pytools.persistent_dict import WriteOncePersistentDict @@ -37,7 +38,7 @@ __all__ = [ "E2EFromCSR", "E2EFromChildren", "E2EFromParent", "M2LUsingTranslationClassesDependentData", "M2LGenerateTranslationClassesDependentData", - "M2LPreprocessMultipole"] + "M2LPreprocessMultipole", "M2LPostprocessLocal"] code_cache = WriteOncePersistentDict("sumpy-code-cache-v6-"+VERSION_TEXT) diff --git a/sumpy/e2e.py b/sumpy/e2e.py index 7774b62a..550cb173 100644 --- a/sumpy/e2e.py +++ b/sumpy/e2e.py @@ -23,6 +23,7 @@ THE SOFTWARE. import numpy as np import loopy as lp import sumpy.symbolic as sym +import pymbolic from loopy.version import MOST_RECENT_LANGUAGE_VERSION from sumpy.tools import KernelCacheWrapper, to_complex_dtype @@ -84,8 +85,6 @@ class E2EBase(KernelCacheWrapper): self.tgt_expansion = tgt_expansion self.name = name or self.default_name self.device = device - self.use_preprocessing_for_m2l = getattr(self.tgt_expansion, - "use_preprocessing_for_m2l", False) if src_expansion.dim != tgt_expansion.dim: raise ValueError("source and target expansions must have " @@ -272,13 +271,6 @@ class E2EFromCSR(E2EBase): return knl - def get_cache_key(self): - return ( - type(self).__name__, - self.src_expansion, - self.tgt_expansion, - ) - def __call__(self, queue, **kwargs): """ :arg src_expansions: @@ -320,22 +312,18 @@ class M2LUsingTranslationClassesDependentData(E2EFromCSR): self.tgt_expansion.m2l_translation_classes_dependent_ndata( self.src_expansion) m2l_translation_classes_dependent_data = \ - [sym.Symbol("m2l_translation_classes_dependent_expr%d" % i) + [sym.Symbol("data%d" % i) for i in range(m2l_translation_classes_dependent_ndata)] - if self.use_preprocessing_for_m2l: - ncoeff_src = self.tgt_expansion.m2l_preprocess_multipole_nexprs( - self.src_expansion) - else: - ncoeff_src = len(self.src_expansion) + ncoeff_src = len(self.src_expansion) - src_coeff_exprs = [sym.Symbol("src_coeff%d" % i) + src_coeff_exprs = [sym.Symbol("src_coeffs%d" % i) for i in range(ncoeff_src)] from sumpy.assignment_collection import SymbolicAssignmentCollection sac = SymbolicAssignmentCollection() tgt_coeff_names = [ - sac.assign_unique("coeff%d" % i, coeff_i) + sac.assign_unique("tgt_coeff%d" % i, coeff_i) for i, coeff_i in enumerate( self.tgt_expansion.translate_from( self.src_expansion, src_coeff_exprs, src_rscale, @@ -348,87 +336,65 @@ class M2LUsingTranslationClassesDependentData(E2EFromCSR): from sumpy.codegen import to_loopy_insns return to_loopy_insns( sac.assignments.items(), - vector_names=set(["d"]), + vector_names=set(["d", "src_coeffs", "data"]), pymbolic_expr_maps=[self.tgt_expansion.get_code_transformer()], retain_names=tgt_coeff_names, complex_dtype=to_complex_dtype(result_dtype), ) - def get_postprocess_loopy_insns(self, result_dtype): - """Loopy instructions that happen only once for each target box. - - :arg result_dtype: A numpy dtype for the result. This is important - because depending on the input to the M2L being real or not - the code needs to be different. If the input was real, the - result should be the real part of the complex values from the - inverse FFT. In the input was complex, the result matches the - values from the inverse FFT. - """ + def get_inner_loopy_kernel(self, result_dtype): + try: + return self.tgt_expansion.loopy_translate_from( + self.src_expansion) + except NotImplementedError: + pass - ncoeff_tgt = len(self.tgt_expansion) - m2l_translation_classes_dependent_ndata = \ - self.tgt_expansion.m2l_translation_classes_dependent_ndata( + ndata = self.tgt_expansion.m2l_translation_classes_dependent_ndata( self.src_expansion) - - if self.use_preprocessing_for_m2l: - ncoeff_tgt = m2l_translation_classes_dependent_ndata - - from sumpy.assignment_collection import SymbolicAssignmentCollection - sac = SymbolicAssignmentCollection() - - tgt_coeff_exprs = [ - sym.Symbol("coeff_sum%d" % i) for i in range(ncoeff_tgt) - ] - - src_rscale = sym.Symbol("src_rscale") - tgt_rscale = sym.Symbol("tgt_rscale") - - if self.use_preprocessing_for_m2l: - tgt_coeff_post_exprs = self.tgt_expansion.m2l_postprocess_local_exprs( - self.src_expansion, tgt_coeff_exprs, src_rscale, tgt_rscale, - sac=sac) - if result_dtype in (np.float32, np.float64): - real_func = sym.Function("real") - tgt_coeff_post_exprs = [real_func(expr) for expr in - tgt_coeff_post_exprs] + if self.tgt_expansion.use_preprocessing_for_m2l: + ncoeff_src = self.tgt_expansion.m2l_preprocess_multipole_nexprs( + self.src_expansion) + ncoeff_tgt = self.tgt_expansion.m2l_postprocess_local_nexprs( + self.src_expansion) else: - tgt_coeff_post_exprs = tgt_coeff_exprs - - tgt_coeff_post_names = [ - sac.assign_unique("coeff_post%d" % i, coeff) - for i, coeff in enumerate(tgt_coeff_post_exprs) - ] - - sac.run_global_cse() - - from sumpy.codegen import to_loopy_insns - insns = to_loopy_insns( - sac.assignments.items(), - vector_names=set(["d"]), - pymbolic_expr_maps=[self.tgt_expansion.get_code_transformer()], - retain_names=tgt_coeff_post_names, - complex_dtype=to_complex_dtype(result_dtype), - ) - return insns + ncoeff_src = len(self.src_expansion) + ncoeff_tgt = len(self.tgt_expansion) + + domains = [] + insns = self.get_translation_loopy_insns(result_dtype) + coeff = pymbolic.var("coeff") + for i in range(ncoeff_tgt): + expr = pymbolic.var(f"tgt_coeff{i}") + insn = lp.Assignment(assignee=coeff[i], + expression=coeff[i] + expr) + insns.append(insn) + + return lp.make_function(domains, insns, + kernel_data=[ + lp.GlobalArg("coeff", shape=(ncoeff_tgt,), + is_output=True, is_input=True), + lp.GlobalArg("src_coeffs", shape=(ncoeff_src,)), + lp.GlobalArg("data", shape=(ndata,)), + lp.ValueArg("src_rscale"), + lp.ValueArg("tgt_rscale"), + ...], + name="e2e", + lang_version=lp.MOST_RECENT_LANGUAGE_VERSION, + ) def get_kernel(self, result_dtype): m2l_translation_classes_dependent_ndata = \ self.tgt_expansion.m2l_translation_classes_dependent_ndata( self.src_expansion) - if self.use_preprocessing_for_m2l: - # number of expressions given as input to M2L after preprocessing + if self.tgt_expansion.use_preprocessing_for_m2l: ncoeff_src = self.tgt_expansion.m2l_preprocess_multipole_nexprs( self.src_expansion) - # number of expressions given as input to postprocessing - ncoeff_tgt_before_postprocess = \ - self.tgt_expansion.m2l_postprocess_local_nexprs( - self.src_expansion) + ncoeff_tgt = self.tgt_expansion.m2l_postprocess_local_nexprs( + self.src_expansion) else: ncoeff_src = len(self.src_expansion) - ncoeff_tgt_before_postprocess = len(self.tgt_expansion) - - ncoeff_tgt = len(self.tgt_expansion) + ncoeff_tgt = len(self.tgt_expansion) # To clarify terminology: # @@ -437,53 +403,44 @@ class M2LUsingTranslationClassesDependentData(E2EFromCSR): # # (same for itgt_box, tgt_ibox) + translation_knl = self.get_inner_loopy_kernel(result_dtype) + from sumpy.tools import gather_loopy_arguments loopy_knl = lp.make_kernel( [ "{[itgt_box]: 0<=itgt_box<ntgt_boxes}", "{[isrc_box]: isrc_start<=isrc_box<isrc_stop}", - "{[idim]: 0<=idim<dim}", - ], + "{[icoeff_tgt]: 0<=icoeff_tgt<ncoeff_tgt}", + "{[icoeff_src]: 0<=icoeff_src<ncoeff_src}", + "{[idep]: 0<=idep<m2l_translation_classes_dependent_ndata}", + ], [""" for itgt_box <> tgt_ibox = target_boxes[itgt_box] - <> tgt_center[idim] = centers[idim, tgt_ibox] <> isrc_start = src_box_starts[itgt_box] <> isrc_stop = src_box_starts[itgt_box+1] - + for icoeff_tgt + <> coeffs[icoeff_tgt] = 0 {id=init_coeffs, dup=icoeff_tgt} + end for isrc_box <> src_ibox = src_box_lists[isrc_box] \ {id=read_src_ibox} - <> src_center[idim] = centers[idim, src_ibox] {dup=idim} - <> d[idim] = tgt_center[idim] - src_center[idim] \ - {dup=idim} <> translation_class = \ m2l_translation_classes_lists[isrc_box] <> translation_class_rel = translation_class - \ translation_classes_level_start - """] + [""" - <> m2l_translation_classes_dependent_expr{idx} = \ - m2l_translation_classes_dependent_data[ \ - translation_class_rel, {idx}] - """.format(idx=idx) for idx in range( - m2l_translation_classes_dependent_ndata)] + [""" - <> src_coeff{coeffidx} = \ - src_expansions[src_ibox - src_base_ibox, {coeffidx}] \ - {{dep=read_src_ibox}} - """.format(coeffidx=i) for i in range(ncoeff_src)] + [ - - ] + self.get_translation_loopy_insns(result_dtype) + [""" + [icoeff_tgt]: coeffs[icoeff_tgt] = e2e( + [icoeff_tgt]: coeffs[icoeff_tgt], + [icoeff_src]: src_expansions[src_ibox - src_base_ibox, + icoeff_src], + [idep]: m2l_translation_classes_dependent_data[ + translation_class_rel, idep], + src_rscale, + tgt_rscale, + ) {dep=init_coeffs,id=update_coeffs} end - - """] + [""" - <> coeff_sum{coeffidx} = \ - simul_reduce(sum, isrc_box, coeff{coeffidx}) - """.format(coeffidx=i) for i in - range(ncoeff_tgt_before_postprocess)] + [ - ] + self.get_postprocess_loopy_insns(result_dtype) + [f""" - tgt_expansions[tgt_ibox - tgt_base_ibox, {coeffidx}] = \ - coeff_post{coeffidx} {{id_prefix=write_expn}} - """ for coeffidx in range(ncoeff_tgt)] + [""" + tgt_expansions[tgt_ibox - tgt_base_ibox, icoeff_tgt] = \ + coeffs[icoeff_tgt] {dep=update_coeffs, dup=icoeff_tgt} end """], [ @@ -498,7 +455,8 @@ class M2LUsingTranslationClassesDependentData(E2EFromCSR): lp.GlobalArg("src_expansions", None, shape=("nsrc_level_boxes", ncoeff_src), offset=lp.auto), lp.GlobalArg("tgt_expansions", None, - shape=("ntgt_level_boxes", ncoeff_tgt), offset=lp.auto), + shape=("ntgt_level_boxes", ncoeff_tgt), + offset=lp.auto), lp.ValueArg("translation_classes_level_start", np.int32), lp.GlobalArg("m2l_translation_classes_dependent_data", None, @@ -515,26 +473,26 @@ class M2LUsingTranslationClassesDependentData(E2EFromCSR): self.tgt_expansion]), name=self.name, assumptions="ntgt_boxes>=1", - silenced_warnings="write_race(write_expn*)", default_offset=lp.auto, fixed_parameters=dict(dim=self.dim, m2l_translation_classes_dependent_ndata=( - m2l_translation_classes_dependent_ndata)), + m2l_translation_classes_dependent_ndata), + ncoeff_tgt=ncoeff_tgt, + ncoeff_src=ncoeff_src), lang_version=MOST_RECENT_LANGUAGE_VERSION ) + loopy_knl = lp.merge([translation_knl, loopy_knl]) + loopy_knl = lp.inline_callable_kernel(loopy_knl, "e2e") + for knl in [self.src_expansion.kernel, self.tgt_expansion.kernel]: loopy_knl = knl.prepare_loopy_kernel(loopy_knl) - loopy_knl = lp.tag_inames(loopy_knl, "idim*:unr") - loopy_knl = lp.set_options(loopy_knl, - enforce_variable_access_ordered="no_check") - return loopy_knl def get_optimized_kernel(self, result_dtype): - # FIXME knl = self.get_kernel(result_dtype) + # FIXME knl = lp.split_iname(knl, "itgt_box", 16, outer_tag="g.0") return knl @@ -553,17 +511,12 @@ class M2LUsingTranslationClassesDependentData(E2EFromCSR): # meaningfully inferred. Make the type of rscale explicit. src_rscale = centers.dtype.type(kwargs.pop("src_rscale")) tgt_rscale = centers.dtype.type(kwargs.pop("tgt_rscale")) + src_expansions = kwargs.pop("src_expansions") - if "tgt_expansions" in kwargs: - tgt_expansions = kwargs["tgt_expansions"] - result_dtype = tgt_expansions.dtype - else: - src_expansions = kwargs["src_expansions"] - result_dtype = src_expansions.dtype - - knl = self.get_cached_optimized_kernel(result_dtype=result_dtype) + knl = self.get_cached_optimized_kernel(result_dtype=src_expansions.dtype) return knl(queue, + src_expansions=src_expansions, centers=centers, src_rscale=src_rscale, tgt_rscale=tgt_rscale, **kwargs) @@ -788,6 +741,118 @@ class M2LPreprocessMultipole(E2EBase): # }}} +# {{{ M2LPostprocessLocal + +class M2LPostprocessLocal(E2EBase): + """Postprocesses locals expansions for accelerated M2L""" + + default_name = "m2l_postprocess_local" + + def get_loopy_insns(self, result_dtype): + ncoeffs_before_postprocessing = \ + self.tgt_expansion.m2l_postprocess_local_nexprs(self.tgt_expansion) + + tgt_coeff_exprs_before_postprocessing = [ + sym.Symbol("tgt_coeff_before_postprocessing%d" % i) + for i in range(ncoeffs_before_postprocessing)] + + src_rscale = sym.Symbol("src_rscale") + tgt_rscale = sym.Symbol("tgt_rscale") + + from sumpy.assignment_collection import SymbolicAssignmentCollection + sac = SymbolicAssignmentCollection() + + tgt_coeff_exprs = self.tgt_expansion.m2l_postprocess_local_exprs( + self.tgt_expansion, tgt_coeff_exprs_before_postprocessing, + sac=sac, src_rscale=src_rscale, tgt_rscale=tgt_rscale) + + if result_dtype in (np.float32, np.float64): + real_func = sym.Function("real") + tgt_coeff_exprs = [real_func(expr) for expr in + tgt_coeff_exprs] + + tgt_coeff_names = [ + sac.assign_unique("tgt_coeff%d" % i, coeff_i) + for i, coeff_i in enumerate(tgt_coeff_exprs)] + + sac.run_global_cse() + + from sumpy.codegen import to_loopy_insns + return to_loopy_insns( + sac.assignments.items(), + vector_names=set(["d"]), + pymbolic_expr_maps=[self.tgt_expansion.get_code_transformer()], + retain_names=tgt_coeff_names, + complex_dtype=to_complex_dtype(result_dtype), + ) + + def get_kernel(self, result_dtype): + ntgt_coeffs = len(self.tgt_expansion) + ntgt_coeffs_before_postprocessing = \ + self.tgt_expansion.m2l_postprocess_local_nexprs(self.tgt_expansion) + from sumpy.tools import gather_loopy_arguments + loopy_knl = lp.make_kernel( + [ + "{[itgt_box]: 0<=itgt_box<ntgt_boxes}", + ], + [""" + for itgt_box + """] + [""" + <> tgt_coeff_before_postprocessing{idx} = \ + tgt_expansions_before_postprocessing[itgt_box, {idx}] + """.format(idx=i) for i in range( + ntgt_coeffs_before_postprocessing)] + + self.get_loopy_insns(result_dtype) + [""" + tgt_expansions[itgt_box, {idx}] = \ + tgt_coeff{idx} + """.format(idx=i) for i in range(ntgt_coeffs)] + [""" + end + """], + [ + lp.ValueArg("ntgt_boxes", np.int32), + lp.ValueArg("src_rscale", None), + lp.ValueArg("tgt_rscale", None), + lp.GlobalArg("tgt_expansions", result_dtype, + shape=("ntgt_boxes", ntgt_coeffs), offset=lp.auto), + lp.GlobalArg("tgt_expansions_before_postprocessing", None, + shape=("ntgt_boxes", ntgt_coeffs_before_postprocessing), + offset=lp.auto), + "..." + ] + gather_loopy_arguments([self.src_expansion, self.tgt_expansion]), + name=self.name, + assumptions="ntgt_boxes>=1", + default_offset=lp.auto, + fixed_parameters=dict(dim=self.dim), + lang_version=MOST_RECENT_LANGUAGE_VERSION + ) + + for expn in [self.src_expansion.kernel, self.tgt_expansion.kernel]: + loopy_knl = expn.prepare_loopy_kernel(loopy_knl) + + loopy_knl = lp.set_options(loopy_knl, + enforce_variable_access_ordered="no_check") + return loopy_knl + + def get_optimized_kernel(self, result_dtype): + # FIXME + knl = self.get_kernel(result_dtype) + knl = lp.split_iname(knl, "itgt_box", 16, outer_tag="g.0") + return knl + + def __call__(self, queue, **kwargs): + """ + :arg tgt_expansions + :arg tgt_expansions_before_postprocessing + """ + tgt_expansions = kwargs.pop("tgt_expansions") + result_dtype = tgt_expansions.dtype + knl = self.get_cached_optimized_kernel(result_dtype=result_dtype) + return knl(queue, + tgt_expansions=tgt_expansions, **kwargs) + +# }}} + + # {{{ translation from a box's children class E2EFromChildren(E2EBase): @@ -796,10 +861,11 @@ class E2EFromChildren(E2EBase): def get_kernel(self): if self.src_expansion is not self.tgt_expansion: raise RuntimeError("%s requires that the source " - "and target expansion are the same object" - % type(self).__name__) + "and target expansion are the same object" + % type(self).__name__) - ncoeffs = len(self.src_expansion) + ncoeffs_src = len(self.src_expansion) + ncoeffs_tgt = len(self.tgt_expansion) # To clarify terminology: # @@ -841,14 +907,14 @@ class E2EFromChildren(E2EBase): <> src_coeff{i} = \ src_expansions[src_ibox - src_base_ibox, {i}] \ {{id_prefix=read_coeff,dep=read_src_ibox}} - """.format(i=i) for i in range(ncoeffs)] + [ + """.format(i=i) for i in range(ncoeffs_src)] + [ ] + loopy_insns + [""" tgt_expansions[tgt_ibox - tgt_base_ibox, {i}] = \ tgt_expansions[tgt_ibox - tgt_base_ibox, {i}] \ + coeff{i} \ {{id_prefix=write_expn,dep=compute_coeff*, nosync=read_coeff*}} - """.format(i=i) for i in range(ncoeffs)] + [""" + """.format(i=i) for i in range(ncoeffs_tgt)] + [""" end end end @@ -861,9 +927,9 @@ class E2EFromChildren(E2EBase): lp.GlobalArg("box_child_ids", None, shape="nchildren, aligned_nboxes"), lp.GlobalArg("tgt_expansions", None, - shape=("ntgt_level_boxes", ncoeffs), offset=lp.auto), + shape=("ntgt_level_boxes", ncoeffs_tgt), offset=lp.auto), lp.GlobalArg("src_expansions", None, - shape=("nsrc_level_boxes", ncoeffs), offset=lp.auto), + shape=("nsrc_level_boxes", ncoeffs_src), offset=lp.auto), lp.ValueArg("src_base_ibox,tgt_base_ibox", np.int32), lp.ValueArg("ntgt_level_boxes,nsrc_level_boxes", np.int32), lp.ValueArg("aligned_nboxes", np.int32), diff --git a/sumpy/expansion/local.py b/sumpy/expansion/local.py index abdd0410..493e7045 100644 --- a/sumpy/expansion/local.py +++ b/sumpy/expansion/local.py @@ -29,6 +29,8 @@ from sumpy.expansion import ( from sumpy.tools import mi_increment_axis, matvec_toeplitz_upper_triangular from pytools import single_valued from typing import Tuple, Any +import pymbolic +import loopy as lp import logging logger = logging.getLogger(__name__) @@ -50,6 +52,7 @@ class LocalExpansionBase(ExpansionBase): .. attribute:: kernel .. attribute:: order .. attribute:: use_rscale + .. attribute:: use_fft_for_m2l .. attribute:: use_preprocessing_for_m2l .. automethod:: m2l_translation_classes_dependent_data @@ -60,19 +63,26 @@ class LocalExpansionBase(ExpansionBase): .. automethod:: m2l_postprocess_local_nexprs .. automethod:: translate_from """ - init_arg_names = ("kernel", "order", "use_rscale", "use_preprocessing_for_m2l") + init_arg_names = ("kernel", "order", "use_rscale", "use_fft_for_m2l", + "use_preprocessing_for_m2l") def __init__(self, kernel, order, use_rscale=None, - use_preprocessing_for_m2l=False): + use_fft_for_m2l=False, use_preprocessing_for_m2l=None): super().__init__(kernel, order, use_rscale) - self.use_preprocessing_for_m2l = use_preprocessing_for_m2l + self.use_fft_for_m2l = use_fft_for_m2l + if use_preprocessing_for_m2l is None: + self.use_preprocessing_for_m2l = use_fft_for_m2l + else: + self.use_preprocessing_for_m2l = use_preprocessing_for_m2l def with_kernel(self, kernel): return type(self)(kernel, self.order, self.use_rscale, + use_fft_for_m2l=self.use_fft_for_m2l, use_preprocessing_for_m2l=self.use_preprocessing_for_m2l) def update_persistent_hash(self, key_hash, key_builder): super().update_persistent_hash(key_hash, key_builder) + key_builder.rec(key_hash, self.use_fft_for_m2l) key_builder.rec(key_hash, self.use_preprocessing_for_m2l) def __eq__(self, other): @@ -81,6 +91,7 @@ class LocalExpansionBase(ExpansionBase): and self.kernel == other.kernel and self.order == other.order and self.use_rscale == other.use_rscale + and self.use_fft_for_m2l == other.use_fft_for_m2l and self.use_preprocessing_for_m2l == other.use_preprocessing_for_m2l ) @@ -308,13 +319,10 @@ class VolumeTaylorLocalExpansionBase(LocalExpansionBase): def m2l_translation_classes_dependent_ndata(self, src_expansion): """Returns number of expressions in M2L global precomputation step. """ - mis_with_dummy_rows, mis_without_dummy_rows, _ = \ + mis_with_dummy_rows, _, _ = \ self._m2l_translation_classes_dependent_data_mis(src_expansion) - if self.use_preprocessing_for_m2l: - return len(mis_with_dummy_rows) - else: - return len(mis_without_dummy_rows) + return len(mis_with_dummy_rows) def _m2l_translation_classes_dependent_data_mis(self, src_expansion): """We would like to compute the M2L by way of a circulant matrix below. @@ -431,12 +439,12 @@ class VolumeTaylorLocalExpansionBase(LocalExpansionBase): vector[i] = add_to_sac(sac, vector_full[srcplusderiv_ident_to_index[term]]) - if self.use_preprocessing_for_m2l: - # Add zero values needed to make the translation matrix circulant - derivatives_full = [0]*len(circulant_matrix_mis) - for expr, mi in zip(vector, needed_vector_terms): - derivatives_full[circulant_matrix_ident_to_index[mi]] = expr + # Add zero values needed to make the translation matrix circulant + derivatives_full = [0]*len(circulant_matrix_mis) + for expr, mi in zip(vector, needed_vector_terms): + derivatives_full[circulant_matrix_ident_to_index[mi]] = expr + if self.use_fft_for_m2l: # Note that the matrix we have now is a mirror image of a # circulant matrix. We reverse the first column to get the # first column for the circulant matrix and then finally @@ -444,7 +452,7 @@ class VolumeTaylorLocalExpansionBase(LocalExpansionBase): # matrix. return fft(list(reversed(derivatives_full)), sac=sac) - return vector + return derivatives_full def m2l_preprocess_multipole_exprs(self, src_expansion, src_coeff_exprs, sac, src_rscale): @@ -461,15 +469,16 @@ class VolumeTaylorLocalExpansionBase(LocalExpansionBase): input_vector[circulant_matrix_ident_to_index[term]] = \ add_to_sac(sac, coeff) - if self.use_preprocessing_for_m2l: + if self.use_fft_for_m2l: return fft(input_vector, sac=sac) else: - # When FFT is turned off, there is no preprocessing needed - # Therefore no copying is done and the multipole expansion is sent to - # the main M2L routine as it is. This method is used internally in the - # the main M2l routine to avoid code duplication. return input_vector + def m2l_preprocess_multipole_nexprs(self, src_expansion): + circulant_matrix_mis, _, _ = \ + self._m2l_translation_classes_dependent_data_mis(src_expansion) + return len(circulant_matrix_mis) + def m2l_postprocess_local_exprs(self, src_expansion, m2l_result, src_rscale, tgt_rscale, sac): circulant_matrix_mis, needed_vector_terms, max_mi = \ @@ -477,7 +486,7 @@ class VolumeTaylorLocalExpansionBase(LocalExpansionBase): circulant_matrix_ident_to_index = dict((ident, i) for i, ident in enumerate(circulant_matrix_mis)) - if self.use_preprocessing_for_m2l: + if self.use_fft_for_m2l: n = len(circulant_matrix_mis) m2l_result = fft(m2l_result, inverse=True, sac=sac) # since we reversed the M2L matrix, we reverse the result @@ -493,6 +502,9 @@ class VolumeTaylorLocalExpansionBase(LocalExpansionBase): return result + def m2l_postprocess_local_nexprs(self, src_expansion): + return self.m2l_translation_classes_dependent_ndata(src_expansion) + def translate_from(self, src_expansion, src_coeff_exprs, src_rscale, dvec, tgt_rscale, sac=None, _fast_version=True, m2l_translation_classes_dependent_data=None): @@ -514,8 +526,6 @@ class VolumeTaylorLocalExpansionBase(LocalExpansionBase): if isinstance(src_expansion, VolumeTaylorMultipoleExpansionBase): circulant_matrix_mis, needed_vector_terms, max_mi = \ self._m2l_translation_classes_dependent_data_mis(src_expansion) - circulant_matrix_ident_to_index = {ident: i for i, ident in - enumerate(circulant_matrix_mis)} if not m2l_translation_classes_dependent_data: derivatives = self.m2l_translation_classes_dependent_data( @@ -523,26 +533,23 @@ class VolumeTaylorLocalExpansionBase(LocalExpansionBase): else: derivatives = m2l_translation_classes_dependent_data - if self.use_preprocessing_for_m2l: - assert m2l_translation_classes_dependent_data is not None - assert len(src_coeff_exprs) == len( - m2l_translation_classes_dependent_data) - result = [a*b for a, b in zip(m2l_translation_classes_dependent_data, - src_coeff_exprs)] + if self.use_fft_for_m2l: + assert len(src_coeff_exprs) == len(derivatives) + result = [a*b for a, b in zip(derivatives, src_coeff_exprs)] else: - derivatives_full = [0]*len(circulant_matrix_mis) - for expr, mi in zip(derivatives, needed_vector_terms): - derivatives_full[circulant_matrix_ident_to_index[mi]] = expr - - input_vector = self.m2l_preprocess_multipole_exprs(src_expansion, - src_coeff_exprs, sac, src_rscale) + if not self.use_preprocessing_for_m2l: + src_coeff_exprs = self.m2l_preprocess_multipole_exprs( + src_expansion, src_coeff_exprs, sac, src_rscale) - # Do the matvec - output = matvec_toeplitz_upper_triangular(input_vector, - derivatives_full) + # Returns a big symbolic sum of matrix entries + # (FIXME? Though this is just the correctness-checking + # fallback for the FFT anyhow) + result = matvec_toeplitz_upper_triangular(src_coeff_exprs, + derivatives) - result = self.m2l_postprocess_local_exprs(src_expansion, output, - src_rscale, tgt_rscale, sac) + if not self.use_preprocessing_for_m2l: + result = self.m2l_postprocess_local_exprs(src_expansion, + result, src_rscale, tgt_rscale, sac) logger.info("building translation operator: done") return result @@ -682,15 +689,61 @@ class VolumeTaylorLocalExpansionBase(LocalExpansionBase): logger.info("building translation operator: done") return result + def loopy_translate_from(self, src_expansion): + from sumpy.expansion.multipole import VolumeTaylorMultipoleExpansionBase + + if isinstance(src_expansion, VolumeTaylorMultipoleExpansionBase): + if self.use_preprocessing_for_m2l: + ncoeff_src = self.m2l_preprocess_multipole_nexprs(src_expansion) + ncoeff_tgt = self.m2l_postprocess_local_nexprs(src_expansion) + icoeff_src = pymbolic.var("icoeff_src") + icoeff_tgt = pymbolic.var("icoeff_tgt") + domains = [f"{{[icoeff_tgt]: 0<=icoeff_tgt<{ncoeff_tgt} }}"] + + coeff = pymbolic.var("coeff") + src_coeffs = pymbolic.var("src_coeffs") + m2l_translation_classes_dependent_data = pymbolic.var("data") + + if self.use_fft_for_m2l: + expr = src_coeffs[icoeff_tgt] \ + * m2l_translation_classes_dependent_data[icoeff_tgt] + else: + toeplitz_first_row = src_coeffs[icoeff_src-icoeff_tgt] + vector = m2l_translation_classes_dependent_data[icoeff_src] + expr = toeplitz_first_row * vector + domains.append( + f"{{[icoeff_src]: icoeff_tgt<=icoeff_src<{ncoeff_src} }}") + + expr = src_coeffs[icoeff_tgt] \ + * m2l_translation_classes_dependent_data[icoeff_tgt] + + insns = [ + lp.Assignment( + assignee=coeff[icoeff_tgt], + expression=coeff[icoeff_tgt] + expr), + ] + return lp.make_function(domains, insns, + kernel_data=[ + lp.GlobalArg("coeff, src_coeffs, data", + shape=lp.auto), + lp.ValueArg("src_rscale, tgt_rscale"), + ...], + name="e2e", + lang_version=lp.MOST_RECENT_LANGUAGE_VERSION, + ) + raise NotImplementedError( + f"A direct loopy kernel for translation from " + f"{src_expansion} to {self} is not implemented.") + class VolumeTaylorLocalExpansion( VolumeTaylorExpansion, VolumeTaylorLocalExpansionBase): def __init__(self, kernel, order, use_rscale=None, - use_preprocessing_for_m2l=False): + use_fft_for_m2l=False, use_preprocessing_for_m2l=None): VolumeTaylorLocalExpansionBase.__init__(self, kernel, order, use_rscale, - use_preprocessing_for_m2l) + use_fft_for_m2l, use_preprocessing_for_m2l=None) VolumeTaylorExpansion.__init__(self, kernel, order, use_rscale) @@ -699,9 +752,9 @@ class LinearPDEConformingVolumeTaylorLocalExpansion( VolumeTaylorLocalExpansionBase): def __init__(self, kernel, order, use_rscale=None, - use_preprocessing_for_m2l=False): + use_fft_for_m2l=False, use_preprocessing_for_m2l=None): VolumeTaylorLocalExpansionBase.__init__(self, kernel, order, use_rscale, - use_preprocessing_for_m2l) + use_fft_for_m2l, use_preprocessing_for_m2l) LinearPDEConformingVolumeTaylorExpansion.__init__( self, kernel, order, use_rscale) @@ -745,9 +798,9 @@ class BiharmonicConformingVolumeTaylorLocalExpansion( class _FourierBesselLocalExpansion(LocalExpansionBase): def __init__(self, kernel, order, use_rscale=None, - use_preprocessing_for_m2l=False): + use_fft_for_m2l=False, use_preprocessing_for_m2l=None): - if use_preprocessing_for_m2l: + if use_fft_for_m2l: # FIXME: expansion with FFT is correct symbolically and can be verified # with sympy. However there are numerical issues that we have to deal # with. Greengard and Rokhlin 1988 attributes this to numerical @@ -758,6 +811,7 @@ class _FourierBesselLocalExpansion(LocalExpansionBase): "supported yet.") super().__init__(kernel, order, use_rscale, + use_fft_for_m2l=use_fft_for_m2l, use_preprocessing_for_m2l=use_preprocessing_for_m2l) def get_storage_index(self, k): @@ -829,7 +883,7 @@ class _FourierBesselLocalExpansion(LocalExpansionBase): Hankel1(m + j, arg_scale * dvec_len, 0) * sym.exp(sym.I * (m + j) * new_center_angle_rel_old_center)) - if self.use_preprocessing_for_m2l: + if self.use_fft_for_m2l: order = src_expansion.order # For this expansion, we have a mirror image of a Toeplitz matrix. # First, we have to take the mirror image of the M2L matrix. @@ -860,18 +914,20 @@ class _FourierBesselLocalExpansion(LocalExpansionBase): for m in src_expansion.get_coefficient_identifiers(): src_coeff_exprs[src_expansion.get_storage_index(m)] *= src_rscale**abs(m) - if self.use_preprocessing_for_m2l: + if self.use_fft_for_m2l: src_coeff_exprs = list(reversed(src_coeff_exprs)) src_coeff_exprs += [0] * (len(src_coeff_exprs) - 1) - res = fft(src_coeff_exprs, sac=sac) - return res + return fft(src_coeff_exprs, sac=sac) else: return src_coeff_exprs + def m2l_preprocess_multipole_nexprs(self, src_expansion): + return 2*src_expansion.order + 1 + def m2l_postprocess_local_exprs(self, src_expansion, m2l_result, src_rscale, tgt_rscale, sac): - if self.use_preprocessing_for_m2l: + if self.use_fft_for_m2l: m2l_result = fft(m2l_result, inverse=True, sac=sac) m2l_result = m2l_result[:2*self.order+1] @@ -883,6 +939,9 @@ class _FourierBesselLocalExpansion(LocalExpansionBase): return result + def m2l_postprocess_local_nexprs(self, src_expansion): + return 2*self.order + 1 + def translate_from(self, src_expansion, src_coeff_exprs, src_rscale, dvec, tgt_rscale, sac=None, m2l_translation_classes_dependent_data=None): from sumpy.symbolic import sym_real_norm_2, BesselJ @@ -915,40 +974,85 @@ class _FourierBesselLocalExpansion(LocalExpansionBase): else: derivatives = m2l_translation_classes_dependent_data - translated_coeffs = [] - if self.use_preprocessing_for_m2l: + if self.use_fft_for_m2l: assert m2l_translation_classes_dependent_data is not None assert len(derivatives) == len(src_coeff_exprs) - for a, b in zip(derivatives, src_coeff_exprs): - translated_coeffs.append(a * b) - return translated_coeffs - - src_coeff_exprs = self.m2l_preprocess_multipole_exprs(src_expansion, - src_coeff_exprs, sac, src_rscale) + translated_coeffs = [a * b for a, b in zip(derivatives, + src_coeff_exprs)] + else: + if not self.use_preprocessing_for_m2l: + src_coeff_exprs = self.m2l_preprocess_multipole_exprs( + src_expansion, src_coeff_exprs, sac, src_rscale) - for j in self.get_coefficient_identifiers(): - translated_coeffs.append( + translated_coeffs = [ sum(derivatives[m + j + self.order + src_expansion.order] - * src_coeff_exprs[src_expansion.get_storage_index(m)] - for m in src_expansion.get_coefficient_identifiers())) + * src_coeff_exprs[src_expansion.get_storage_index(m)] + for m in src_expansion.get_coefficient_identifiers()) + for j in self.get_coefficient_identifiers()] + + if not self.use_preprocessing_for_m2l: + translated_coeffs = self.m2l_postprocess_local_exprs( + src_expansion, translated_coeffs, src_rscale, tgt_rscale, + sac) - translated_coeffs = self.m2l_postprocess_local_exprs(src_expansion, - translated_coeffs, src_rscale, tgt_rscale, sac) return translated_coeffs raise RuntimeError("do not know how to translate %s to %s" % (type(src_expansion).__name__, type(self).__name__)) + def loopy_translate_from(self, src_expansion): + if isinstance(src_expansion, self.mpole_expn_class): + if self.use_preprocessing_for_m2l: + ncoeff_src = self.m2l_preprocess_multipole_nexprs(src_expansion) + ncoeff_tgt = self.m2l_postprocess_local_nexprs(src_expansion) + + icoeff_src = pymbolic.var("icoeff_src") + icoeff_tgt = pymbolic.var("icoeff_tgt") + domains = [f"{{[icoeff_tgt]: 0<=icoeff_tgt<{ncoeff_tgt} }}"] + + coeff = pymbolic.var("coeff") + src_coeffs = pymbolic.var("src_coeffs") + m2l_translation_classes_dependent_data = pymbolic.var("data") + + if self.use_fft_for_m2l: + expr = src_coeffs[icoeff_tgt] \ + * m2l_translation_classes_dependent_data[icoeff_tgt] + else: + expr = src_coeffs[icoeff_src] \ + * m2l_translation_classes_dependent_data[ + icoeff_tgt + icoeff_src] + domains.append( + f"{{[icoeff_src]: 0<=icoeff_src<{ncoeff_src} }}") + + insns = [ + lp.Assignment( + assignee=coeff[icoeff_tgt], + expression=coeff[icoeff_tgt] + expr), + ] + return lp.make_function(domains, insns, + kernel_data=[ + lp.GlobalArg("coeff, src_coeffs, data", + shape=lp.auto), + lp.ValueArg("src_rscale, tgt_rscale"), + ...], + name="e2e", + lang_version=lp.MOST_RECENT_LANGUAGE_VERSION, + ) + raise NotImplementedError( + f"A direct loopy kernel for translation from " + f"{src_expansion} to {self} is not implemented.") + class H2DLocalExpansion(_FourierBesselLocalExpansion): def __init__(self, kernel, order, use_rscale=None, - use_preprocessing_for_m2l=False): + use_fft_for_m2l=False, use_preprocessing_for_m2l=None): from sumpy.kernel import HelmholtzKernel assert (isinstance(kernel.get_base_kernel(), HelmholtzKernel) and kernel.dim == 2) super().__init__(kernel, order, use_rscale, + use_fft_for_m2l=use_fft_for_m2l, use_preprocessing_for_m2l=use_preprocessing_for_m2l) from sumpy.expansion.multipole import H2DMultipoleExpansion @@ -960,12 +1064,13 @@ class H2DLocalExpansion(_FourierBesselLocalExpansion): class Y2DLocalExpansion(_FourierBesselLocalExpansion): def __init__(self, kernel, order, use_rscale=None, - use_preprocessing_for_m2l=False): + use_fft_for_m2l=False, use_preprocessing_for_m2l=None): from sumpy.kernel import YukawaKernel assert (isinstance(kernel.get_base_kernel(), YukawaKernel) and kernel.dim == 2) super().__init__(kernel, order, use_rscale, + use_fft_for_m2l=use_fft_for_m2l, use_preprocessing_for_m2l=use_preprocessing_for_m2l) from sumpy.expansion.multipole import Y2DMultipoleExpansion diff --git a/sumpy/fmm.py b/sumpy/fmm.py index bfeadbe5..7403a1db 100644 --- a/sumpy/fmm.py +++ b/sumpy/fmm.py @@ -40,7 +40,7 @@ from sumpy import ( E2EFromCSR, M2LUsingTranslationClassesDependentData, E2EFromChildren, E2EFromParent, M2LGenerateTranslationClassesDependentData, - M2LPreprocessMultipole) + M2LPreprocessMultipole, M2LPostprocessLocal) from sumpy.tools import to_complex_dtype @@ -67,7 +67,7 @@ class SumpyTreeIndependentDataForWrangler(TreeIndependentDataForWrangler): local_expansion_factory, target_kernels, exclude_self=False, use_rscale=None, strength_usage=None, source_kernels=None, - use_preprocessing_for_m2l=False): + use_fft_for_m2l=False, use_preprocessing_for_m2l=None): """ :arg multipole_expansion_factory: a callable of a single argument (order) that returns a multipole expansion. @@ -77,6 +77,10 @@ class SumpyTreeIndependentDataForWrangler(TreeIndependentDataForWrangler): :arg exclude_self: whether the self contribution should be excluded :arg strength_usage: passed unchanged to p2l, p2m and p2p. :arg source_kernels: passed unchanged to p2l, p2m and p2p. + :arg use_fft_for_m2l: Use an FFT based multipole-to-local expansion. + :arg use_preprocessing_for_m2l: do preprocessing of the source multipole + expansion and postprocessing of the target local expansion for + multipole-to-local expansion. """ self.multipole_expansion_factory = multipole_expansion_factory self.local_expansion_factory = local_expansion_factory @@ -85,7 +89,11 @@ class SumpyTreeIndependentDataForWrangler(TreeIndependentDataForWrangler): self.exclude_self = exclude_self self.use_rscale = use_rscale self.strength_usage = strength_usage - self.use_preprocessing_for_m2l = use_preprocessing_for_m2l + self.use_fft_for_m2l = use_fft_for_m2l + if use_preprocessing_for_m2l is None: + self.use_preprocessing_for_m2l = use_fft_for_m2l + else: + self.use_preprocessing_for_m2l = use_preprocessing_for_m2l super().__init__() @@ -103,6 +111,7 @@ class SumpyTreeIndependentDataForWrangler(TreeIndependentDataForWrangler): @memoize_method def local_expansion(self, order): return self.local_expansion_factory(order, self.use_rscale, + use_fft_for_m2l=self.use_fft_for_m2l, use_preprocessing_for_m2l=self.use_preprocessing_for_m2l) @memoize_method @@ -148,6 +157,12 @@ class SumpyTreeIndependentDataForWrangler(TreeIndependentDataForWrangler): self.multipole_expansion(src_order), self.local_expansion(tgt_order)) + @memoize_method + def m2l_postprocess_local_kernel(self, src_order, tgt_order): + return M2LPostprocessLocal(self.cl_context, + self.multipole_expansion(src_order), + self.local_expansion(tgt_order)) + @memoize_method def l2l(self, src_order, tgt_order): return E2EFromParent(self.cl_context, @@ -310,7 +325,7 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface): self.dtype = dtype - if not self.tree_indep.use_preprocessing_for_m2l: + if not self.tree_indep.use_fft_for_m2l: # If not FFT, we don't need complex dtypes self.preprocessed_mpole_dtype = dtype elif preprocessed_mpole_dtype is not None: @@ -345,7 +360,7 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface): if base_kernel.is_translation_invariant: if translation_classes_data is None: from warnings import warn - if self.tree_indep.use_preprocessing_for_m2l: + if self.tree_indep.use_fft_for_m2l: raise NotImplementedError( "FFT based List 2 (multipole-to-local) translations " "without translation_classes_data argument is not " @@ -358,14 +373,14 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface): "to the wrangler for optimized List 2.", SumpyTranslationClassesDataNotSuppliedWarning, stacklevel=2) - self.supports_optimized_m2l = False + self.supports_translation_classes = False else: - self.supports_optimized_m2l = True + self.supports_translation_classes = True else: - self.supports_optimized_m2l = False + self.supports_translation_classes = False self.translation_classes_data = translation_classes_data - self.use_preprocessing_for_m2l = self.tree_indep.use_preprocessing_for_m2l + self.use_fft_for_m2l = self.tree_indep.use_fft_for_m2l # {{{ data vector utilities @@ -465,7 +480,7 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface): def order_to_size(order): mpole_expn = self.tree_indep.multipole_expansion(order) local_expn = self.tree_indep.local_expansion(order) - res = local_expn.m2l_translation_classes_dependent_ndata(mpole_expn) + res = local_expn.m2l_preprocess_multipole_nexprs(mpole_expn) return res return build_csr_level_starts(self.level_orders, order_to_size, @@ -485,6 +500,11 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface): return (box_start, mpole_exps[expn_start:expn_stop].reshape(box_stop-box_start, -1)) + m2l_work_array_view = m2l_preproc_mpole_expansions_view + m2l_work_array_zeros = m2l_preproc_mpole_expansion_zeros + m2l_work_array_level_starts = \ + m2l_preproc_mpole_expansions_level_starts + def output_zeros(self, template_ary): """Return a potentials array (which must support addition) capable of holding a potential value for each target in the tree. Note that @@ -707,16 +727,15 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface): m2l_translation_classes_dependent_data = \ m2l_translation_classes_dependent_data.with_queue(None) - return m2l_translation_classes_dependent_data def _add_m2l_precompute_kwargs(self, kwargs_for_m2l, lev): - """This method is used for addin the information needed for a + """This method is used for adding the information needed for a multipole-to-local translation with precomputation to the keywords passed to multipole-to-local translation. """ - if not self.supports_optimized_m2l: + if not self.supports_translation_classes: return m2l_translation_classes_dependent_data = \ self.multipole_to_local_precompute() @@ -736,18 +755,19 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface): target_boxes, src_box_starts, src_box_lists, mpole_exps): - precompute_evts = [] + preprocess_evts = [] queue = mpole_exps.queue + local_exps = self.local_expansion_zeros(mpole_exps) - if self.use_preprocessing_for_m2l: + if self.tree_indep.use_preprocessing_for_m2l: preprocessed_mpole_exps = \ - self.m2l_preproc_mpole_expansion_zeros(mpole_exps) + self.m2l_preproc_mpole_expansion_zeros(mpole_exps) for lev in range(self.tree.nlevels): order = self.level_orders[lev] preprocess_mpole_kernel = \ self.tree_indep.m2l_preprocess_mpole_kernel(order, order) - source_level_start_ibox, source_mpoles_view = \ + _, source_mpoles_view = \ self.multipole_expansions_view(mpole_exps, lev) _, preprocessed_source_mpoles_view = \ @@ -766,14 +786,17 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface): src_rscale=level_to_rscale(self.tree, lev), **self.kernel_extra_kwargs ) - precompute_evts.append(evt) + preprocess_evts.append(evt) mpole_exps = preprocessed_mpole_exps + m2l_work_array = self.m2l_work_array_zeros(local_exps) mpole_exps_view_func = self.m2l_preproc_mpole_expansions_view + local_exps_view_func = self.m2l_work_array_view else: + m2l_work_array = local_exps mpole_exps_view_func = self.multipole_expansions_view + local_exps_view_func = self.local_expansions_view - events = [] - local_exps = self.local_expansion_zeros(mpole_exps) + translate_evts = [] for lev in range(self.tree.nlevels): start, stop = level_start_target_box_nrs[lev:lev+2] @@ -781,17 +804,18 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface): continue order = self.level_orders[lev] - m2l = self.tree_indep.m2l(order, order, self.supports_optimized_m2l) + m2l = self.tree_indep.m2l(order, order, + self.supports_translation_classes) source_level_start_ibox, source_mpoles_view = \ mpole_exps_view_func(mpole_exps, lev) - target_level_start_ibox, target_local_exps_view = \ - self.local_expansions_view(local_exps, lev) + target_level_start_ibox, target_locals_view = \ + local_exps_view_func(m2l_work_array, lev) kwargs = dict( src_expansions=source_mpoles_view, src_base_ibox=source_level_start_ibox, - tgt_expansions=target_local_exps_view, + tgt_expansions=target_locals_view, tgt_base_ibox=target_level_start_ibox, target_boxes=target_boxes[start:stop], @@ -809,11 +833,45 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface): kwargs["m2l_translation_classes_dependent_data"].size == 0: # There is nothing to do for this level continue - evt, _ = m2l(queue, **kwargs, wait_for=precompute_evts) + evt, _ = m2l(queue, **kwargs, wait_for=preprocess_evts) - events.append(evt) + translate_evts.append(evt) - return (local_exps, SumpyTimingFuture(queue, events)) + postprocess_evts = [] + + if self.tree_indep.use_preprocessing_for_m2l: + for lev in range(self.tree.nlevels): + order = self.level_orders[lev] + postprocess_local_kernel = \ + self.tree_indep.m2l_postprocess_local_kernel(order, order) + + _, target_locals_view = \ + self.local_expansions_view(local_exps, lev) + + _, target_locals_before_postprocessing_view = \ + self.m2l_work_array_view( + m2l_work_array, lev) + + tr_classes = self.m2l_translation_class_level_start_box_nrs() + if tr_classes[lev] == tr_classes[lev + 1]: + # There is no M2L happening in this level + continue + + evt, _ = postprocess_local_kernel( + queue, + tgt_expansions=target_locals_view, + tgt_expansions_before_postprocessing=( + target_locals_before_postprocessing_view), + src_rscale=level_to_rscale(self.tree, lev), + tgt_rscale=level_to_rscale(self.tree, lev), + wait_for=translate_evts, + **self.kernel_extra_kwargs, + ) + postprocess_evts.append(evt) + + timing_events = preprocess_evts + translate_evts + postprocess_evts + + return (local_exps, SumpyTimingFuture(queue, timing_events)) def eval_multipoles(self, target_boxes_by_source_level, source_boxes_by_level, mpole_exps): diff --git a/test/test_fmm.py b/test/test_fmm.py index 3e2f84fd..d1fd9690 100644 --- a/test/test_fmm.py +++ b/test/test_fmm.py @@ -190,7 +190,7 @@ def test_sumpy_fmm(ctx_factory, knl, local_expn_class, mpole_expn_class, ctx, partial(mpole_expn_class, knl), partial(local_expn_class, knl), - target_kernels, use_preprocessing_for_m2l=use_fft) + target_kernels, use_fft_for_m2l=use_fft) with warnings.catch_warnings(): if not optimized_m2l: -- GitLab