From 29a2fbd5d83fed8598fbaede1063c9145906f322 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20Kl=C3=B6ckner?= <inform@tiker.net> Date: Mon, 7 Jun 2021 14:07:10 -0500 Subject: [PATCH] Adapt codegen to loopy kernel callables (#62) * Adapt codegen to loopy kernel callables * Factor Bessel callable registration into a separate function, use in Yukawa * Add transfer_requirements_git_urls from sumpy to downstream projects * Point req.txt for loopy back to main --- .github/workflows/ci.yml | 9 +- sumpy/codegen.py | 423 ++++++++++++++++++++------------------- sumpy/kernel.py | 18 +- 3 files changed, 231 insertions(+), 219 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e99a76b0..e0f76370 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -88,19 +88,20 @@ jobs: env: DOWNSTREAM_PROJECT: ${{ matrix.downstream_project }} run: | - if [[ "$DOWNSTREAM_PROJECT" = "pytential" ]] && [[ "$GITHUB_HEAD_REF" = "multiplier" ]]; then - git clone "https://github.com/isuruf/$DOWNSTREAM_PROJECT.git" -b "remover" + curl -L -O https://tiker.net/ci-support-v0 + . ./ci-support-v0 + if [[ "$DOWNSTREAM_PROJECT" = "pytential" ]] && [[ "$GITHUB_HEAD_REF" = "derivtaker" ]]; then + git clone "https://github.com/isuruf/$DOWNSTREAM_PROJECT.git" -b "$GITHUB_HEAD_REF" else git clone "https://github.com/inducer/$DOWNSTREAM_PROJECT.git" fi cd "$DOWNSTREAM_PROJECT" echo "*** $DOWNSTREAM_PROJECT version: $(git rev-parse --short HEAD)" + transfer_requirements_git_urls ../requirements.txt ./requirements.txt sed -i "/egg=sumpy/ c git+file://$(readlink -f ..)#egg=sumpy" requirements.txt export CONDA_ENVIRONMENT=.test-conda-env-py3.yml # Avoid slow or complicated tests in downstream projects export PYTEST_ADDOPTS="-k 'not (slowtest or octave or mpi)'" - curl -L -O -k https://gitlab.tiker.net/inducer/ci-support/raw/main/ci-support.sh - . 
./ci-support.sh build_py_project_in_conda_env test_py_project diff --git a/sumpy/codegen.py b/sumpy/codegen.py index 8a49c0c7..4da7b637 100644 --- a/sumpy/codegen.py +++ b/sumpy/codegen.py @@ -21,18 +21,15 @@ THE SOFTWARE. """ +import re + import numpy as np -import pyopencl as cl -import pyopencl.tools # noqa import loopy as lp - -import re +from loopy.kernel.instruction import make_assignment from pymbolic.mapper import IdentityMapper, CSECachingMapperMixin import pymbolic.primitives as prim -from loopy.types import NumpyType - from pytools import memoize_method from sumpy.symbolic import (SympyToPymbolicMapper as SympyToPymbolicMapperBase) @@ -73,136 +70,149 @@ class SympyToPymbolicMapper(SympyToPymbolicMapperBase): # }}} -# {{{ bessel handling +# {{{ bessel -> loopy codegen BESSEL_PREAMBLE = """//CL// #include <pyopencl-bessel-j.cl> #include <pyopencl-bessel-y.cl> #include <pyopencl-bessel-j-complex.cl> -typedef struct bessel_j_two_result_str -{ - cdouble_t jv, jvp1; -} bessel_j_two_result; - -bessel_j_two_result bessel_jv_two(int v, double z) +double bessel_jv_two(int v, double z, double *jvp1) { - bessel_j_two_result result; - result.jv = cdouble_fromreal(bessel_jv(v, z)); - result.jvp1 = cdouble_fromreal(bessel_jv(v+1, z)); - return result; + *jvp1 = bessel_jv(v+1, z); + return bessel_jv(v, z); } -bessel_j_two_result bessel_jv_two_complex(int v, cdouble_t z) +cdouble_t bessel_jv_two_complex(int v, cdouble_t z, cdouble_t *jvp1) { - bessel_j_two_result result; - bessel_j_complex(v, z, &result.jv, &result.jvp1); - return result; + cdouble_t jv; + bessel_j_complex(v, z, &jv, jvp1); + return jv; } """ HANKEL_PREAMBLE = """//CL// +#include <pyopencl-bessel-j.cl> +#include <pyopencl-bessel-y.cl> #include <pyopencl-hankel-complex.cl> -typedef struct hank1_01_result_str -{ - cdouble_t order0, order1; -} hank1_01_result; - -hank1_01_result hank1_01(double z) +cdouble_t hank1_01(double z, cdouble_t *order1) { - hank1_01_result result; - result.order0 = 
cdouble_new(bessel_j0(z), bessel_y0(z)); - result.order1 = cdouble_new(bessel_j1(z), bessel_y1(z)); - return result; + *order1 = cdouble_new(bessel_j1(z), bessel_y1(z)); + return cdouble_new(bessel_j0(z), bessel_y0(z)); } -hank1_01_result hank1_01_complex(cdouble_t z) +cdouble_t hank1_01_complex(cdouble_t z, cdouble_t *order1) { - hank1_01_result result; - hankel_01_complex(z, &result.order0, &result.order1, 1); - return result; + cdouble_t order0; + hankel_01_complex(z, &order0, order1, 1); + return order0; } """ -def bessel_preamble_generator(preamble_info): - from loopy.target.pyopencl import PyOpenCLTarget - if not isinstance(preamble_info.kernel.target, PyOpenCLTarget): - raise NotImplementedError("Only the PyOpenCLTarget is supported as of now") +class BesselJvvp1(lp.ScalarCallable): + def with_types(self, arg_id_to_dtype, clbl_inf_ctx): + from loopy.types import NumpyType - require_bessel = False - if any(func.name == "hank1_01" for func in preamble_info.seen_functions): - yield ("50-sumpy-hankel", HANKEL_PREAMBLE) - require_bessel = True - if (require_bessel - or any(func.name == "bessel_jv_two" - for func in preamble_info.seen_functions)): - yield ("40-sumpy-bessel", BESSEL_PREAMBLE) + for i in arg_id_to_dtype: + if not (-2 <= i <= 1): + raise TypeError(f"{self.name} can only take 2 arguments.") + if (arg_id_to_dtype.get(0) is None) or (arg_id_to_dtype.get(1) is None): + # not enough info about input types + return self, clbl_inf_ctx -hank1_01_result_dtype = cl.tools.get_or_register_dtype("hank1_01_result", - NumpyType(np.dtype([ - ("order0", np.complex128), - ("order1", np.complex128), - ])), - ) + n_dtype = arg_id_to_dtype[0] + z_dtype = arg_id_to_dtype[1] -bessel_j_two_result_dtype = cl.tools.get_or_register_dtype("bessel_j_two_result", - NumpyType(np.dtype([ - ("jv", np.complex128), - ("jvp1", np.complex128), - ])), - ) + # *technically* takes a float, but let's not worry about that. 
+ if n_dtype.numpy_dtype.kind != "i": + raise TypeError(f"{self.name} expects an integer first argument") + if z_dtype.numpy_dtype.kind == "c": + return (self.copy(name_in_target="bessel_jv_two_complex", + arg_id_to_dtype={ + -2: NumpyType(np.complex128), + -1: NumpyType(np.complex128), + 0: NumpyType(np.int32), + 1: NumpyType(np.complex128), + }), + clbl_inf_ctx) + else: + return (self.copy(name_in_target="bessel_jv_two", + arg_id_to_dtype={ + -2: NumpyType(np.float64), + -1: NumpyType(np.float64), + 0: NumpyType(np.int32), + 1: NumpyType(np.float64), + }), + clbl_inf_ctx) + + def generate_preambles(self, target): + from loopy import PyOpenCLTarget + if not isinstance(target, PyOpenCLTarget): + raise NotImplementedError("Only the PyOpenCLTarget is supported as" + " of now.") -def bessel_mangler(kernel, identifier, arg_dtypes): - """A function "mangler" to make Bessel functions - digestible for :mod:`loopy`. + yield ("40-sumpy-bessel", BESSEL_PREAMBLE) - See argument *function_manglers* to :func:`loopy.make_kernel`. 
- """ - from loopy.target.pyopencl import PyOpenCLTarget - if not isinstance(kernel.target, PyOpenCLTarget): - raise NotImplementedError("Only the PyOpenCLTarget is supported as of now") - - if identifier == "hank1_01": - if arg_dtypes[0].is_complex(): - identifier = "hank1_01_complex" - return lp.CallMangleInfo( - target_name=identifier, - result_dtypes=(NumpyType(np.dtype(hank1_01_result_dtype)),), - arg_dtypes=( - NumpyType(np.dtype(np.complex128)), - )) - else: - return lp.CallMangleInfo( - target_name=identifier, - result_dtypes=(NumpyType(np.dtype(hank1_01_result_dtype)),), - arg_dtypes=( - NumpyType(np.dtype(np.float64)), - )) - - elif identifier == "bessel_jv_two": - if arg_dtypes[1].is_complex(): - identifier = "bessel_jv_two_complex" - return lp.CallMangleInfo( - target_name=identifier, - result_dtypes=(NumpyType(np.dtype(bessel_j_two_result_dtype)),), - arg_dtypes=( - NumpyType(np.dtype(np.int32)), - NumpyType(np.dtype(np.complex128)),)) +class Hankel1_01(lp.ScalarCallable): # noqa: N801 + def with_types(self, arg_id_to_dtype, clbl_inf_ctx): + from loopy.types import NumpyType + + for i in arg_id_to_dtype: + if not (-2 <= i <= 0): + raise TypeError(f"{self.name} can only take one argument.") + + if arg_id_to_dtype.get(0) is None: + # not enough info about input types + return self, clbl_inf_ctx + + z_dtype = arg_id_to_dtype[0] + + if z_dtype.numpy_dtype.kind == "c": + return (self.copy(name_in_target="hank1_01_complex", + arg_id_to_dtype={ + -2: NumpyType(np.complex128), + -1: NumpyType(np.complex128), + 0: NumpyType(np.complex128), + }), + clbl_inf_ctx) else: - return lp.CallMangleInfo( - target_name=identifier, - result_dtypes=(NumpyType(np.dtype(bessel_j_two_result_dtype)),), - arg_dtypes=( - NumpyType(np.dtype(np.int32)), - NumpyType(np.dtype(np.float64)),)) + return (self.copy(name_in_target="hank1_01", + arg_id_to_dtype={ + -2: NumpyType(np.complex128), + -1: NumpyType(np.complex128), + 0: NumpyType(np.float64), + }), + clbl_inf_ctx) + + def 
generate_preambles(self, target): + from loopy import PyOpenCLTarget + if not isinstance(target, PyOpenCLTarget): + raise NotImplementedError("Only the PyOpenCLTarget is supported as" + " of now.") + + yield ("50-sumpy-hankel", HANKEL_PREAMBLE) + + +def register_bessel_callables(loopy_knl): + from sumpy.codegen import BesselJvvp1, Hankel1_01 + loopy_knl = lp.register_callable(loopy_knl, "bessel_jvvp1", + BesselJvvp1("bessel_jvvp1")) + loopy_knl = lp.register_callable(loopy_knl, "hank1_01", + Hankel1_01("hank1_01")) + return loopy_knl + +# }}} - else: - return None + +# {{{ custom mapper base classes + +class CSECachingIdentityMapper(IdentityMapper, CSECachingMapperMixin): + pass class CallExternalRecMapper(IdentityMapper): @@ -212,8 +222,12 @@ class CallExternalRecMapper(IdentityMapper): else: return super().rec(expr, *args, **kwargs) +# }}} + -class BesselTopOrderGatherer(CSECachingMapperMixin, CallExternalRecMapper): +# {{{ bessel handling + +class BesselTopOrderGatherer(CSECachingIdentityMapper, CallExternalRecMapper): """This mapper walks the expression tree to find the highest-order Bessel J being used, so that all other Js can be computed by the (stable) downward recurrence. 
@@ -230,13 +244,13 @@ class BesselTopOrderGatherer(CSECachingMapperMixin, CallExternalRecMapper): self.bessel_j_arg_to_top_order[arg] = max( self.bessel_j_arg_to_top_order.get(arg, 0), abs(order)) - return IdentityMapper.map_call(rec_self if rec_self else self, + return CSECachingIdentityMapper.map_call(rec_self or self, expr, rec_self, *args) map_common_subexpression_uncached = IdentityMapper.map_common_subexpression -class BesselDerivativeReplacer(CSECachingMapperMixin, CallExternalRecMapper): +class BesselDerivativeReplacer(CSECachingIdentityMapper, CallExternalRecMapper): def map_call(self, expr, rec_self=None, *args): call = expr @@ -262,98 +276,56 @@ class BesselDerivativeReplacer(CSECachingMapperMixin, CallExternalRecMapper): for idx, i in enumerate(range(order-k, order+k+1, 2))), "d%d_%s_%s" % (n_derivs, function.name, order_str)) else: - return IdentityMapper.map_call( - rec_self if rec_self else self, expr, rec_self, *args) - - map_common_subexpression_uncached = IdentityMapper.map_common_subexpression - - -class HankelSubstitutor(CSECachingMapperMixin, CallExternalRecMapper): - def map_call(self, expr, rec_self=None, *args): - if isinstance(expr.function, prim.Variable): - name = expr.function.name - if name == "hankel_1": - order, arg = expr.parameters - return self.hankel_1(order, self.rec(arg, - rec_self, *args)) - - return IdentityMapper.map_call(rec_self if rec_self else self, expr) - - def hank1_01(self, arg): - return prim.Variable("hank1_01")(arg) - - def wrap_in_cse(self, expr, prefix): - return prim.wrap_in_cse(expr, prefix) - - @memoize_method - def hankel_1(self, order, arg): - if order == 0: - return self.wrap_in_cse( - prim.Lookup(self.hank1_01(arg), "order0"), - "hank1_01_result") - elif order == 1: - return self.wrap_in_cse( - prim.Lookup(self.hank1_01(arg), "order1"), - "hank1_01_result") - elif order < 0: - # AS (9.1.6) - nu = -order - return self.wrap_in_cse( - (-1) ** nu * self.hankel_1(nu, arg), - "hank1_neg%d" % nu) - elif order > 
1: - # AS (9.1.27) - nu = order-1 - return self.wrap_in_cse( - 2*nu/arg*self.hankel_1(nu, arg) - - self.hankel_1(nu-1, arg), - "hank1_%d" % order) - else: - assert False + return CSECachingIdentityMapper.map_call( + rec_self or self, expr, rec_self, *args) map_common_subexpression_uncached = IdentityMapper.map_common_subexpression -class BesselSubstitutor(CSECachingMapperMixin, IdentityMapper): - def __init__(self, bessel_j_arg_to_top_order): +class BesselSubstitutor(CSECachingIdentityMapper): + def __init__(self, name_gen, bessel_j_arg_to_top_order, assignments): + self.name_gen = name_gen self.bessel_j_arg_to_top_order = bessel_j_arg_to_top_order self.cse_cache = {} + self.assignments = assignments - def __call__(self, expr, *args, **kwargs): - if not self.bessel_j_arg_to_top_order: - return expr - return super().__call__(expr, *args, **kwargs) - - def map_call(self, expr, rec_self=None, *args): + def map_call(self, expr, *args): if isinstance(expr.function, prim.Variable): name = expr.function.name if name == "bessel_j": order, arg = expr.parameters - return self.bessel_j(order, self.rec(arg, - rec_self, *args)) - - return IdentityMapper.map_call(rec_self if rec_self else self, expr) + return self.bessel_j(order, self.rec(arg, *args)) + elif name == "hankel_1": + order, arg = expr.parameters + return self.hankel_1(order, self.rec(arg, *args)) - @memoize_method - def bessel_jv_two(self, order, arg): - return prim.Variable("bessel_jv_two")(order, arg) + return super().map_call(expr) def wrap_in_cse(self, expr, prefix): cse = prim.wrap_in_cse(expr, prefix) return self.cse_cache.setdefault(expr, cse) + # {{{ bessel implementation + + @memoize_method + def bessel_jv_two(self, order, arg): + name_om1 = self.name_gen("bessel_%d" % (order-1)) + name_o = self.name_gen("bessel_%d" % order) + self.assignments.append( + make_assignment( + (prim.Variable(name_om1), prim.Variable(name_o),), + prim.Variable("bessel_jvvp1")(order, arg), + temp_var_types=(lp.Optional(None),)*2)) 
+ + return prim.Variable(name_om1), prim.Variable(name_o) + @memoize_method def bessel_j(self, order, arg): top_order = self.bessel_j_arg_to_top_order[arg] - if order == top_order: - return self.wrap_in_cse( - prim.Lookup(self.bessel_jv_two(top_order-1, arg), "jvp1"), - "bessel_jv_two_result") + return self.bessel_jv_two(top_order-1, arg)[1] elif order == top_order-1: - return self.wrap_in_cse( - prim.Lookup(self.bessel_jv_two(top_order-1, arg), "jv"), - "bessel_jv_two_result") + return self.bessel_jv_two(top_order-1, arg)[0] elif order < 0: return self.wrap_in_cse( (-1)**order*self.bessel_j(-order, arg), @@ -368,6 +340,45 @@ class BesselSubstitutor(CSECachingMapperMixin, IdentityMapper): - self.bessel_j(nu+1, arg), "bessel_j_%d" % order) + # }}} + + # {{{ hankel implementation + + @memoize_method + def hank1_01(self, arg): + name_0 = self.name_gen("hank1_0") + name_1 = self.name_gen("hank1_1") + self.assignments.append( + make_assignment( + (prim.Variable(name_0), prim.Variable(name_1),), + prim.Variable("hank1_01")(arg), + temp_var_types=(lp.Optional(None),)*2)) + return prim.Variable(name_0), prim.Variable(name_1) + + @memoize_method + def hankel_1(self, order, arg): + if order == 0: + return self.hank1_01(arg)[0] + elif order == 1: + return self.hank1_01(arg)[1] + elif order < 0: + # AS (9.1.6) + nu = -order + return self.wrap_in_cse( + (-1) ** nu * self.hankel_1(nu, arg), + "hank1_neg%d" % nu) + elif order > 1: + # AS (9.1.27) + nu = order-1 + return self.wrap_in_cse( + 2*nu/arg*self.hankel_1(nu, arg) + - self.hankel_1(nu-1, arg), + "hank1_%d" % order) + else: + raise AssertionError() + + # }}} + map_common_subexpression_uncached = IdentityMapper.map_common_subexpression # }}} @@ -375,7 +386,7 @@ class BesselSubstitutor(CSECachingMapperMixin, IdentityMapper): # {{{ power rewriter -class PowerRewriter(CSECachingMapperMixin, CallExternalRecMapper): +class PowerRewriter(CSECachingIdentityMapper, CallExternalRecMapper): def map_power(self, expr, rec_self=None, 
*args): exp = expr.exponent if isinstance(exp, int): @@ -418,7 +429,7 @@ class PowerRewriter(CSECachingMapperMixin, CallExternalRecMapper): return self.rec(new_base**p, rec_self, *args) - return IdentityMapper.map_power(rec_self if rec_self else self, expr) + return CSECachingIdentityMapper.map_power(rec_self or self, expr) map_common_subexpression_uncached = IdentityMapper.map_common_subexpression @@ -430,11 +441,11 @@ class PowerRewriter(CSECachingMapperMixin, CallExternalRecMapper): from loopy.tools import is_integer -class BigIntegerKiller(CSECachingMapperMixin, CallExternalRecMapper): +class BigIntegerKiller(CSECachingIdentityMapper, CallExternalRecMapper): def __init__(self, warn_on_digit_loss=True, int_type=np.int64, float_type=np.float64): - IdentityMapper.__init__(self) + super().__init__() self.warn = warn_on_digit_loss self.float_type = float_type self.iinfo = np.iinfo(int_type) @@ -468,22 +479,22 @@ class BigIntegerKiller(CSECachingMapperMixin, CallExternalRecMapper): # {{{ convert 123000000j to 123000000 * 1j -class ComplexRewriter(CSECachingMapperMixin, CallExternalRecMapper): +class ComplexRewriter(CSECachingIdentityMapper, CallExternalRecMapper): def __init__(self, float_type=np.float32): - IdentityMapper.__init__(self) + super().__init__() self.float_type = float_type def map_constant(self, expr, rec_self=None, *args, **kwargs): """Convert complex values not within complex64 to a product for loopy """ if not isinstance(expr, complex): - return IdentityMapper.map_constant( - rec_self if rec_self else self, expr) + return CSECachingIdentityMapper.map_constant( + rec_self or self, expr) if complex(self.float_type(expr.imag)) == expr.imag: - return IdentityMapper.map_constant( - rec_self if rec_self else self, expr) + return CSECachingIdentityMapper.map_constant( + rec_self or self, expr) # avoid cycles if expr == 1j: @@ -501,10 +512,10 @@ class ComplexRewriter(CSECachingMapperMixin, CallExternalRecMapper): INDEXED_VAR_RE = 
re.compile("^([a-zA-Z_]+)([0-9]+)$") -class VectorComponentRewriter(CSECachingMapperMixin, CallExternalRecMapper): +class VectorComponentRewriter(CSECachingIdentityMapper, CallExternalRecMapper): """For names in name_whitelist, turn ``a3`` into ``a[3]``.""" - def __init__(self, name_whitelist=set()): + def __init__(self, name_whitelist=frozenset()): self.name_whitelist = name_whitelist def map_variable(self, expr, *args): @@ -526,7 +537,7 @@ class VectorComponentRewriter(CSECachingMapperMixin, CallExternalRecMapper): # {{{ sum sign grouper -class SumSignGrouper(CSECachingMapperMixin, CallExternalRecMapper): +class SumSignGrouper(CSECachingIdentityMapper, CallExternalRecMapper): """Anti-cancellation cargo-cultism.""" def map_sum(self, expr, *args): @@ -564,7 +575,7 @@ class SumSignGrouper(CSECachingMapperMixin, CallExternalRecMapper): # }}} -class MathConstantRewriter(CSECachingMapperMixin, CallExternalRecMapper): +class MathConstantRewriter(CSECachingIdentityMapper, CallExternalRecMapper): def map_variable(self, expr, *args): if expr.name == "pi": return prim.Variable("M_PI") @@ -609,10 +620,11 @@ def combine_mappers(*mappers): continue all_methods[method_name].append((mapper, method)) - class CombinedMapper(CSECachingMapperMixin, IdentityMapper): + class CombinedMapper(CSECachingIdentityMapper): def __init__(self, all_methods): self.all_methods = all_methods - map_common_subexpression_uncached = IdentityMapper.map_common_subexpression + map_common_subexpression_uncached = \ + IdentityMapper.map_common_subexpression def _map(method_name, self, expr, rec_self=None, *args): if method_name not in self.all_methods: @@ -632,9 +644,13 @@ def combine_mappers(*mappers): types.MethodType(partial(_map, method_name), combine_mapper)) return combine_mapper +# }}} + -def to_loopy_insns(assignments, vector_names=set(), pymbolic_expr_maps=[], - complex_dtype=None, retain_names=set()): +# {{{ to-loopy conversion + +def to_loopy_insns(assignments, vector_names=frozenset(), 
pymbolic_expr_maps=(), + complex_dtype=None, retain_names=frozenset()): logger.info("loopy instruction generation: start") assignments = list(assignments) @@ -649,11 +665,12 @@ def to_loopy_insns(assignments, vector_names=set(), pymbolic_expr_maps=[], ssg = SumSignGrouper() bik = BigIntegerKiller() cmr = ComplexRewriter() - hks = HankelSubstitutor() + + cmb_mapper = combine_mappers(bdr, btog, vcr, pwr, ssg, bik, cmr) if 0: # https://github.com/inducer/sumpy/pull/40#issuecomment-852635444 - cmb_mapper = combine_mappers(bdr, btog, vcr, pwr, ssg, bik, cmr, hks) + cmb_mapper = combine_mappers(bdr, btog, vcr, pwr, ssg, bik, cmr) else: def cmb_mapper(expr): expr = bdr(expr) @@ -662,7 +679,6 @@ def to_loopy_insns(assignments, vector_names=set(), pymbolic_expr_maps=[], expr = ssg(expr) expr = bik(expr) expr = cmr(expr) - expr = hks(expr) expr = btog(expr) return expr @@ -674,16 +690,21 @@ def to_loopy_insns(assignments, vector_names=set(), pymbolic_expr_maps=[], return expr assignments = [(name, convert_expr(name, expr)) for name, expr in assignments] - bessel_sub = BesselSubstitutor(btog.bessel_j_arg_to_top_order) + from pytools import UniqueNameGenerator + name_gen = UniqueNameGenerator(set([name for name, expr in assignments])) + + result = [] + bessel_sub = BesselSubstitutor( + name_gen, btog.bessel_j_arg_to_top_order, + result) import loopy as lp from pytools import MinRecursionLimit with MinRecursionLimit(3000): - result = [ - lp.Assignment(id=None, + for name, expr in assignments: + result.append(lp.Assignment(id=None, assignee=name, expression=bessel_sub(expr), - temp_var_type=lp.Optional(None)) - for name, expr in assignments] + temp_var_type=lp.Optional(None))) logger.info("loopy instruction generation: done") return result diff --git a/sumpy/kernel.py b/sumpy/kernel.py index af771ceb..635f7699 100644 --- a/sumpy/kernel.py +++ b/sumpy/kernel.py @@ -542,13 +542,8 @@ class HelmholtzKernel(ExpressionKernel): self.dim, self.helmholtz_k_name) def 
prepare_loopy_kernel(self, loopy_knl): - from sumpy.codegen import (bessel_preamble_generator, bessel_mangler) - loopy_knl = lp.register_function_manglers(loopy_knl, - [bessel_mangler]) - loopy_knl = lp.register_preamble_generators(loopy_knl, - [bessel_preamble_generator]) - - return loopy_knl + from sumpy.codegen import register_bessel_callables + return register_bessel_callables(loopy_knl) def get_args(self): if self.allow_evanescent: @@ -635,13 +630,8 @@ class YukawaKernel(ExpressionKernel): self.dim, self.yukawa_lambda_name) def prepare_loopy_kernel(self, loopy_knl): - from sumpy.codegen import (bessel_preamble_generator, bessel_mangler) - loopy_knl = lp.register_function_manglers(loopy_knl, - [bessel_mangler]) - loopy_knl = lp.register_preamble_generators(loopy_knl, - [bessel_preamble_generator]) - - return loopy_knl + from sumpy.codegen import register_bessel_callables + return register_bessel_callables(loopy_knl) def get_args(self): return [ -- GitLab