diff --git a/sumpy/fmm.py b/sumpy/fmm.py index 39ed7f2093d44a32c24a5265cfb0ba844de2e98d..157bd5f53179cc5dc5451a8c5005548f5537297a 100644 --- a/sumpy/fmm.py +++ b/sumpy/fmm.py @@ -372,7 +372,7 @@ class SumpyExpansionWrangler: self.queue, source_boxes=source_boxes[start:stop], centers=self.tree.box_centers, - strengths=src_weights, + strengths=(src_weights,), tgt_expansions=mpoles_view, tgt_base_ibox=level_start_ibox, @@ -588,7 +588,7 @@ class SumpyExpansionWrangler: source_box_starts=starts[start:stop+1], source_box_lists=lists, centers=self.tree.box_centers, - strengths=src_weights, + strengths=(src_weights,), tgt_expansions=target_local_exps_view, tgt_base_ibox=target_level_start_ibox, diff --git a/sumpy/p2e.py b/sumpy/p2e.py index ff365e014feec9694b829360039d21ede6519c7b..f2782d6348f04d6f6bb02cc21a764ce0e76ac339 100644 --- a/sumpy/p2e.py +++ b/sumpy/p2e.py @@ -24,9 +24,8 @@ import numpy as np import loopy as lp from loopy.version import MOST_RECENT_LANGUAGE_VERSION import pymbolic -from pymbolic.mapper import WalkMapper -from sumpy.tools import KernelCacheWrapper +from sumpy.tools import KernelCacheWrapper, KernelComputation import logging logger = logging.getLogger(__name__) @@ -46,9 +45,11 @@ Particle-to-expansion # {{{ P2E base class -class P2EBase(KernelCacheWrapper): +class P2EBase(KernelComputation, KernelCacheWrapper): + """Common input processing for kernel computations.""" + def __init__(self, ctx, expansion, kernels=None, - options=[], name=None, device=None, strength_expr=None, nstrengths=1): + options=[], name=None, device=None, strength_usage=None): """ :arg expansion: a subclass of :class:`sympy.expansion.ExpansionBase` :arg strength_usage: A list of integers indicating which expression @@ -56,40 +57,20 @@ class P2EBase(KernelCacheWrapper): number of strength arrays that need to be passed. Default: all kernels use the same strength. """ - - if device is None: - device = ctx.devices[0] - if kernels is None: kernels = [expansion.kernel] - if strength_expr is None: - import pymbolic - strength_expr = pymbolic.parse("strength0 * kernel0") + KernelComputation.__init__(self, ctx=ctx, kernels=kernels, + strength_usage=strength_usage, value_dtypes=None, + name=name, options=options, device=device) from sumpy.kernel import TargetDerivativeRemover expansion = expansion.with_kernel( TargetDerivativeRemover()(expansion.kernel)) - self.ctx = ctx self.expansion = expansion - self.options = options - self.name = name or self.default_name - self.device = device - self.kernels = kernels - self.strength_expr, self.extra_source_variables = \ - process_strength_expr(strength_expr, nstrengths, len(kernels)) - self.nstrengths = nstrengths - self.dim = expansion.dim - def get_result_expr(self, icoeff): - subst_dict = dict(( - pymbolic.var(f"kernel{iknl}"), pymbolic.var(f"coeff{icoeff}_{iknl}") - ) for iknl in range(len(self.kernels))) - res = pymbolic.substitute(self.strength_expr, subst_dict) - return res - def get_loopy_instructions(self): from sumpy.symbolic import make_sym_vector avec = make_sym_vector("a", self.dim) @@ -123,6 +104,15 @@ class P2EBase(KernelCacheWrapper): def get_cache_key(self): return (type(self).__name__, self.name, self.expansion) + def get_result_expr(self, icoeff): + expr = 0 + isrc = pymbolic.var("isrc") + for i in range(len(self.kernels)): + strength_num = self.strength_usage[i] + expr += pymbolic.var(f"coeff{icoeff}_{i}") * \ + pymbolic.var("strengths")[strength_num, isrc] + return expr + # }}} @@ -162,7 +152,7 @@ class P2EFromSingleBox(P2EBase): [ lp.GlobalArg("sources", None, shape=(self.dim, "nsources"), dim_tags="sep,c"), - lp.GlobalArg("strengths", None, shape="nstrengths, nsources", + lp.GlobalArg("strengths", None, shape="strength_count, nsources", dim_tags="sep,C"), lp.GlobalArg("box_source_starts,box_source_counts_nonchild", None, shape=None), @@ -178,7 +168,8 @@ class P2EFromSingleBox(P2EBase): assumptions="nsrc_boxes>=1", silenced_warnings="write_race(write_expn*)", default_offset=lp.auto, - fixed_parameters=dict(dim=self.dim, nstrengths=self.nstrengths), + fixed_parameters=dict(dim=self.dim, + strength_count=self.strength_count), lang_version=MOST_RECENT_LANGUAGE_VERSION) loopy_knl = self.expansion.prepare_loopy_kernel(loopy_knl) @@ -229,7 +220,7 @@ class P2EFromCSR(P2EBase): [ lp.GlobalArg("sources", None, shape=(self.dim, "nsources"), dim_tags="sep,c"), - lp.GlobalArg("strengths", None, shape="nstrengths, nsources", + lp.GlobalArg("strengths", None, shape="strength_count, nsources", dim_tags="sep,C"), lp.GlobalArg("source_box_starts,source_box_lists", None, shape=None, offset=lp.auto), @@ -284,7 +275,8 @@ class P2EFromCSR(P2EBase): assumptions="ntgt_boxes>=1", silenced_warnings="write_race(write_expn*)", default_offset=lp.auto, - fixed_parameters=dict(dim=self.dim, nstrengths=self.nstrengths), + fixed_parameters=dict(dim=self.dim, + strength_count=self.strength_count), lang_version=MOST_RECENT_LANGUAGE_VERSION) loopy_knl = self.expansion.prepare_loopy_kernel(loopy_knl) @@ -296,6 +288,7 @@ class P2EFromCSR(P2EBase): # FIXME knl = self.get_kernel() knl = lp.split_iname(knl, "itgt_box", 16, outer_tag="g.0") + return knl def __call__(self, queue, **kwargs): @@ -320,39 +313,4 @@ class P2EFromCSR(P2EBase): # }}} - -# {{{ helper functions - -class CollectVariableMapper(WalkMapper): - def __init__(self): - self.variables = set() - - def post_visit(self, expr, *args, **kwargs): - if isinstance(expr, pymbolic.primitives.Variable): - self.variables.add(expr) - - -def process_strength_expr(expr, nstrengths, nkernels): - collect_variable_mapper = CollectVariableMapper() - collect_variable_mapper(expr) - variables = collect_variable_mapper.variables - - # Get variables that are not strengths nor kernels - source_variables = [var for var in variables if - not (var.name.startswith("strength") or var.name.startswith("kernel"))] - - # Use strengths array and index with isrc - subst = {} - for i in range(nstrengths): - old = pymbolic.var(f"strength{i}") - new = pymbolic.var("strengths")[0, pymbolic.var("isrc")] - subst[old] = new - print(subst) - expr = pymbolic.substitute(expr, subst) - - return expr, source_variables - - -#}}} - # vim: foldmethod=marker diff --git a/test/test_kernels.py b/test/test_kernels.py index 4db163f85cd79c0e01dea3fbf4c044595d11bce1..d6631f92840f3027a82612ba754d7439843b70c3 100644 --- a/test/test_kernels.py +++ b/test/test_kernels.py @@ -173,7 +173,7 @@ def test_p2e2p(ctx_factory, base_knl, expn_class, order, with_source_derivative) expn = expn_class(knl, order=order) from sumpy import P2EFromSingleBox, E2PFromSingleBox, P2P - p2e = P2EFromSingleBox(ctx, expn, out_kernels) + p2e = P2EFromSingleBox(ctx, expn, [knl]) e2p = E2PFromSingleBox(ctx, expn, out_kernels) p2p = P2P(ctx, out_kernels, exclude_self=False)