From 29a2fbd5d83fed8598fbaede1063c9145906f322 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Andreas=20Kl=C3=B6ckner?= <inform@tiker.net>
Date: Mon, 7 Jun 2021 14:07:10 -0500
Subject: [PATCH] Adapt codegen to loopy kernel callables (#62)

* Adapt codegen to loopy kernel callables

* Factor Bessel callable registration into a separate function, use in Yukawa

* Add transfer_requirements_git_urls from sumpy to downstream projects

* Point req.txt for loopy back to main
---
 .github/workflows/ci.yml |   9 +-
 sumpy/codegen.py         | 423 ++++++++++++++++++++-------------------
 sumpy/kernel.py          |  18 +-
 3 files changed, 231 insertions(+), 219 deletions(-)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index e99a76b0..e0f76370 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -88,19 +88,20 @@ jobs:
             env:
                 DOWNSTREAM_PROJECT: ${{ matrix.downstream_project }}
             run: |
-                if [[ "$DOWNSTREAM_PROJECT" = "pytential" ]] && [[ "$GITHUB_HEAD_REF" = "multiplier" ]]; then
-                  git clone "https://github.com/isuruf/$DOWNSTREAM_PROJECT.git" -b "remover"
+                curl -L -O https://tiker.net/ci-support-v0
+                . ./ci-support-v0
+                if [[ "$DOWNSTREAM_PROJECT" = "pytential" ]] && [[ "$GITHUB_HEAD_REF" = "derivtaker" ]]; then
+                  git clone "https://github.com/isuruf/$DOWNSTREAM_PROJECT.git" -b "$GITHUB_HEAD_REF"
                 else
                   git clone "https://github.com/inducer/$DOWNSTREAM_PROJECT.git"
                 fi
                 cd "$DOWNSTREAM_PROJECT"
                 echo "*** $DOWNSTREAM_PROJECT version: $(git rev-parse --short HEAD)"
+                transfer_requirements_git_urls ../requirements.txt ./requirements.txt
                 sed -i "/egg=sumpy/ c git+file://$(readlink -f ..)#egg=sumpy" requirements.txt
                 export CONDA_ENVIRONMENT=.test-conda-env-py3.yml
                 # Avoid slow or complicated tests in downstream projects
                 export PYTEST_ADDOPTS="-k 'not (slowtest or octave or mpi)'"
-                curl -L -O -k https://gitlab.tiker.net/inducer/ci-support/raw/main/ci-support.sh
-                . ./ci-support.sh
                 build_py_project_in_conda_env
                 test_py_project
 
diff --git a/sumpy/codegen.py b/sumpy/codegen.py
index 8a49c0c7..4da7b637 100644
--- a/sumpy/codegen.py
+++ b/sumpy/codegen.py
@@ -21,18 +21,15 @@ THE SOFTWARE.
 """
 
 
+import re
+
 import numpy as np
-import pyopencl as cl
-import pyopencl.tools  # noqa
 import loopy as lp
-
-import re
+from loopy.kernel.instruction import make_assignment
 
 from pymbolic.mapper import IdentityMapper, CSECachingMapperMixin
 import pymbolic.primitives as prim
 
-from loopy.types import NumpyType
-
 from pytools import memoize_method
 
 from sumpy.symbolic import (SympyToPymbolicMapper as SympyToPymbolicMapperBase)
@@ -73,136 +70,149 @@ class SympyToPymbolicMapper(SympyToPymbolicMapperBase):
 # }}}
 
 
-# {{{ bessel handling
+# {{{ bessel -> loopy codegen
 
 BESSEL_PREAMBLE = """//CL//
 #include <pyopencl-bessel-j.cl>
 #include <pyopencl-bessel-y.cl>
 #include <pyopencl-bessel-j-complex.cl>
 
-typedef struct bessel_j_two_result_str
-{
-    cdouble_t jv, jvp1;
-} bessel_j_two_result;
-
-bessel_j_two_result bessel_jv_two(int v, double z)
+double bessel_jv_two(int v, double z, double *jvp1)
 {
-    bessel_j_two_result result;
-    result.jv = cdouble_fromreal(bessel_jv(v, z));
-    result.jvp1 = cdouble_fromreal(bessel_jv(v+1, z));
-    return result;
+    *jvp1 = bessel_jv(v+1, z);
+    return bessel_jv(v, z);
 }
 
-bessel_j_two_result bessel_jv_two_complex(int v, cdouble_t z)
+cdouble_t bessel_jv_two_complex(int v, cdouble_t z, cdouble_t *jvp1)
 {
-    bessel_j_two_result result;
-    bessel_j_complex(v, z, &result.jv, &result.jvp1);
-    return result;
+    cdouble_t jv;
+    bessel_j_complex(v, z, &jv, jvp1);
+    return jv;
 }
 """
 
 HANKEL_PREAMBLE = """//CL//
+#include <pyopencl-bessel-j.cl>
+#include <pyopencl-bessel-y.cl>
 #include <pyopencl-hankel-complex.cl>
 
-typedef struct hank1_01_result_str
-{
-    cdouble_t order0, order1;
-} hank1_01_result;
-
-hank1_01_result hank1_01(double z)
+cdouble_t hank1_01(double z, cdouble_t *order1)
 {
-    hank1_01_result result;
-    result.order0 = cdouble_new(bessel_j0(z), bessel_y0(z));
-    result.order1 = cdouble_new(bessel_j1(z), bessel_y1(z));
-    return result;
+    *order1 = cdouble_new(bessel_j1(z), bessel_y1(z));
+    return cdouble_new(bessel_j0(z), bessel_y0(z));
 }
 
-hank1_01_result hank1_01_complex(cdouble_t z)
+cdouble_t hank1_01_complex(cdouble_t z, cdouble_t *order1)
 {
-    hank1_01_result result;
-    hankel_01_complex(z, &result.order0, &result.order1, 1);
-    return result;
+    cdouble_t order0;
+    hankel_01_complex(z, &order0, order1, 1);
+    return order0;
 }
 """
 
 
-def bessel_preamble_generator(preamble_info):
-    from loopy.target.pyopencl import PyOpenCLTarget
-    if not isinstance(preamble_info.kernel.target, PyOpenCLTarget):
-        raise NotImplementedError("Only the PyOpenCLTarget is supported as of now")
+class BesselJvvp1(lp.ScalarCallable):
+    def with_types(self, arg_id_to_dtype, clbl_inf_ctx):
+        from loopy.types import NumpyType
 
-    require_bessel = False
-    if any(func.name == "hank1_01" for func in preamble_info.seen_functions):
-        yield ("50-sumpy-hankel", HANKEL_PREAMBLE)
-        require_bessel = True
-    if (require_bessel
-            or any(func.name == "bessel_jv_two"
-                for func in preamble_info.seen_functions)):
-        yield ("40-sumpy-bessel", BESSEL_PREAMBLE)
+        for i in arg_id_to_dtype:
+            if not (-2 <= i <= 1):
+                raise TypeError(f"{self.name} can only take 2 arguments.")
 
+        if (arg_id_to_dtype.get(0) is None) or (arg_id_to_dtype.get(1) is None):
+            # not enough info about input types
+            return self, clbl_inf_ctx
 
-hank1_01_result_dtype = cl.tools.get_or_register_dtype("hank1_01_result",
-        NumpyType(np.dtype([
-            ("order0", np.complex128),
-            ("order1", np.complex128),
-            ])),
-        )
+        n_dtype = arg_id_to_dtype[0]
+        z_dtype = arg_id_to_dtype[1]
 
-bessel_j_two_result_dtype = cl.tools.get_or_register_dtype("bessel_j_two_result",
-        NumpyType(np.dtype([
-            ("jv", np.complex128),
-            ("jvp1", np.complex128),
-            ])),
-        )
+        # *technically* takes a float, but let's not worry about that.
+        if n_dtype.numpy_dtype.kind != "i":
+            raise TypeError(f"{self.name} expects an integer first argument")
 
+        if z_dtype.numpy_dtype.kind == "c":
+            return (self.copy(name_in_target="bessel_jv_two_complex",
+                              arg_id_to_dtype={
+                                  -2: NumpyType(np.complex128),
+                                  -1: NumpyType(np.complex128),
+                                  0: NumpyType(np.int32),
+                                  1: NumpyType(np.complex128),
+                                  }),
+                    clbl_inf_ctx)
+        else:
+            return (self.copy(name_in_target="bessel_jv_two",
+                              arg_id_to_dtype={
+                                  -2: NumpyType(np.float64),
+                                  -1: NumpyType(np.float64),
+                                  0: NumpyType(np.int32),
+                                  1: NumpyType(np.float64),
+                                  }),
+                    clbl_inf_ctx)
+
+    def generate_preambles(self, target):
+        from loopy import PyOpenCLTarget
+        if not isinstance(target, PyOpenCLTarget):
+            raise NotImplementedError("Only the PyOpenCLTarget is supported as"
+                                      " of now.")
 
-def bessel_mangler(kernel, identifier, arg_dtypes):
-    """A function "mangler" to make Bessel functions
-    digestible for :mod:`loopy`.
+        yield ("40-sumpy-bessel", BESSEL_PREAMBLE)
 
-    See argument *function_manglers* to :func:`loopy.make_kernel`.
-    """
 
-    from loopy.target.pyopencl import PyOpenCLTarget
-    if not isinstance(kernel.target, PyOpenCLTarget):
-        raise NotImplementedError("Only the PyOpenCLTarget is supported as of now")
-
-    if identifier == "hank1_01":
-        if arg_dtypes[0].is_complex():
-            identifier = "hank1_01_complex"
-            return lp.CallMangleInfo(
-                    target_name=identifier,
-                    result_dtypes=(NumpyType(np.dtype(hank1_01_result_dtype)),),
-                    arg_dtypes=(
-                        NumpyType(np.dtype(np.complex128)),
-                        ))
-        else:
-            return lp.CallMangleInfo(
-                    target_name=identifier,
-                    result_dtypes=(NumpyType(np.dtype(hank1_01_result_dtype)),),
-                    arg_dtypes=(
-                        NumpyType(np.dtype(np.float64)),
-                        ))
-
-    elif identifier == "bessel_jv_two":
-        if arg_dtypes[1].is_complex():
-            identifier = "bessel_jv_two_complex"
-            return lp.CallMangleInfo(
-                    target_name=identifier,
-                    result_dtypes=(NumpyType(np.dtype(bessel_j_two_result_dtype)),),
-                    arg_dtypes=(
-                        NumpyType(np.dtype(np.int32)),
-                        NumpyType(np.dtype(np.complex128)),))
+class Hankel1_01(lp.ScalarCallable):  # noqa: N801
+    def with_types(self, arg_id_to_dtype, clbl_inf_ctx):
+        from loopy.types import NumpyType
+
+        for i in arg_id_to_dtype:
+            if not (-2 <= i <= 0):
+                raise TypeError(f"{self.name} can only take one argument.")
+
+        if arg_id_to_dtype.get(0) is None:
+            # not enough info about input types
+            return self, clbl_inf_ctx
+
+        z_dtype = arg_id_to_dtype[0]
+
+        if z_dtype.numpy_dtype.kind == "c":
+            return (self.copy(name_in_target="hank1_01_complex",
+                              arg_id_to_dtype={
+                                  -2: NumpyType(np.complex128),
+                                  -1: NumpyType(np.complex128),
+                                  0: NumpyType(np.complex128),
+                                  }),
+                    clbl_inf_ctx)
         else:
-            return lp.CallMangleInfo(
-                    target_name=identifier,
-                    result_dtypes=(NumpyType(np.dtype(bessel_j_two_result_dtype)),),
-                    arg_dtypes=(
-                        NumpyType(np.dtype(np.int32)),
-                        NumpyType(np.dtype(np.float64)),))
+            return (self.copy(name_in_target="hank1_01",
+                              arg_id_to_dtype={
+                                  -2: NumpyType(np.complex128),
+                                  -1: NumpyType(np.complex128),
+                                  0: NumpyType(np.float64),
+                                  }),
+                    clbl_inf_ctx)
+
+    def generate_preambles(self, target):
+        from loopy import PyOpenCLTarget
+        if not isinstance(target, PyOpenCLTarget):
+            raise NotImplementedError("Only the PyOpenCLTarget is supported as"
+                                      " of now.")
+
+        yield ("50-sumpy-hankel", HANKEL_PREAMBLE)
+
+
+def register_bessel_callables(loopy_knl):
+    from sumpy.codegen import BesselJvvp1, Hankel1_01
+    loopy_knl = lp.register_callable(loopy_knl, "bessel_jvvp1",
+            BesselJvvp1("bessel_jvvp1"))
+    loopy_knl = lp.register_callable(loopy_knl, "hank1_01",
+            Hankel1_01("hank1_01"))
+    return loopy_knl
+
+# }}}
 
-    else:
-        return None
+
+# {{{ custom mapper base classes
+
+class CSECachingIdentityMapper(IdentityMapper, CSECachingMapperMixin):
+    pass
 
 
 class CallExternalRecMapper(IdentityMapper):
@@ -212,8 +222,12 @@ class CallExternalRecMapper(IdentityMapper):
         else:
             return super().rec(expr, *args, **kwargs)
 
+# }}}
+
 
-class BesselTopOrderGatherer(CSECachingMapperMixin, CallExternalRecMapper):
+# {{{ bessel handling
+
+class BesselTopOrderGatherer(CSECachingIdentityMapper, CallExternalRecMapper):
     """This mapper walks the expression tree to find the highest-order
     Bessel J being used, so that all other Js can be computed by the
     (stable) downward recurrence.
@@ -230,13 +244,13 @@ class BesselTopOrderGatherer(CSECachingMapperMixin, CallExternalRecMapper):
             self.bessel_j_arg_to_top_order[arg] = max(
                     self.bessel_j_arg_to_top_order.get(arg, 0),
                     abs(order))
-        return IdentityMapper.map_call(rec_self if rec_self else self,
+        return CSECachingIdentityMapper.map_call(rec_self or self,
                 expr, rec_self, *args)
 
     map_common_subexpression_uncached = IdentityMapper.map_common_subexpression
 
 
-class BesselDerivativeReplacer(CSECachingMapperMixin, CallExternalRecMapper):
+class BesselDerivativeReplacer(CSECachingIdentityMapper, CallExternalRecMapper):
     def map_call(self, expr, rec_self=None, *args):
         call = expr
 
@@ -262,98 +276,56 @@ class BesselDerivativeReplacer(CSECachingMapperMixin, CallExternalRecMapper):
                         for idx, i in enumerate(range(order-k, order+k+1, 2))),
                     "d%d_%s_%s" % (n_derivs, function.name, order_str))
         else:
-            return IdentityMapper.map_call(
-                    rec_self if rec_self else self, expr, rec_self, *args)
-
-    map_common_subexpression_uncached = IdentityMapper.map_common_subexpression
-
-
-class HankelSubstitutor(CSECachingMapperMixin, CallExternalRecMapper):
-    def map_call(self, expr, rec_self=None, *args):
-        if isinstance(expr.function, prim.Variable):
-            name = expr.function.name
-            if name == "hankel_1":
-                order, arg = expr.parameters
-                return self.hankel_1(order, self.rec(arg,
-                    rec_self, *args))
-
-        return IdentityMapper.map_call(rec_self if rec_self else self, expr)
-
-    def hank1_01(self, arg):
-        return prim.Variable("hank1_01")(arg)
-
-    def wrap_in_cse(self, expr, prefix):
-        return prim.wrap_in_cse(expr, prefix)
-
-    @memoize_method
-    def hankel_1(self, order, arg):
-        if order == 0:
-            return self.wrap_in_cse(
-                    prim.Lookup(self.hank1_01(arg), "order0"),
-                    "hank1_01_result")
-        elif order == 1:
-            return self.wrap_in_cse(
-                    prim.Lookup(self.hank1_01(arg), "order1"),
-                    "hank1_01_result")
-        elif order < 0:
-            # AS (9.1.6)
-            nu = -order
-            return self.wrap_in_cse(
-                    (-1) ** nu * self.hankel_1(nu, arg),
-                    "hank1_neg%d" % nu)
-        elif order > 1:
-            # AS (9.1.27)
-            nu = order-1
-            return self.wrap_in_cse(
-                    2*nu/arg*self.hankel_1(nu, arg)
-                    - self.hankel_1(nu-1, arg),
-                    "hank1_%d" % order)
-        else:
-            assert False
+            return CSECachingIdentityMapper.map_call(
+                    rec_self or self, expr, rec_self, *args)
 
     map_common_subexpression_uncached = IdentityMapper.map_common_subexpression
 
 
-class BesselSubstitutor(CSECachingMapperMixin, IdentityMapper):
-    def __init__(self, bessel_j_arg_to_top_order):
+class BesselSubstitutor(CSECachingIdentityMapper):
+    def __init__(self, name_gen, bessel_j_arg_to_top_order, assignments):
+        self.name_gen = name_gen
         self.bessel_j_arg_to_top_order = bessel_j_arg_to_top_order
         self.cse_cache = {}
+        self.assignments = assignments
 
-    def __call__(self, expr, *args, **kwargs):
-        if not self.bessel_j_arg_to_top_order:
-            return expr
-        return super().__call__(expr, *args, **kwargs)
-
-    def map_call(self, expr, rec_self=None, *args):
+    def map_call(self, expr, *args):
         if isinstance(expr.function, prim.Variable):
             name = expr.function.name
             if name == "bessel_j":
                 order, arg = expr.parameters
-                return self.bessel_j(order, self.rec(arg,
-                    rec_self, *args))
-
-        return IdentityMapper.map_call(rec_self if rec_self else self, expr)
+                return self.bessel_j(order, self.rec(arg, *args))
+            elif name == "hankel_1":
+                order, arg = expr.parameters
+                return self.hankel_1(order, self.rec(arg, *args))
 
-    @memoize_method
-    def bessel_jv_two(self, order, arg):
-        return prim.Variable("bessel_jv_two")(order, arg)
+        return super().map_call(expr)
 
     def wrap_in_cse(self, expr, prefix):
         cse = prim.wrap_in_cse(expr, prefix)
         return self.cse_cache.setdefault(expr, cse)
 
+    # {{{ bessel implementation
+
+    @memoize_method
+    def bessel_jv_two(self, order, arg):
+        name_om1 = self.name_gen("bessel_%d" % (order-1))
+        name_o = self.name_gen("bessel_%d" % order)
+        self.assignments.append(
+                make_assignment(
+                    (prim.Variable(name_om1), prim.Variable(name_o),),
+                    prim.Variable("bessel_jvvp1")(order, arg),
+                    temp_var_types=(lp.Optional(None),)*2))
+
+        return prim.Variable(name_om1), prim.Variable(name_o)
+
     @memoize_method
     def bessel_j(self, order, arg):
         top_order = self.bessel_j_arg_to_top_order[arg]
-
         if order == top_order:
-            return self.wrap_in_cse(
-                    prim.Lookup(self.bessel_jv_two(top_order-1, arg), "jvp1"),
-                    "bessel_jv_two_result")
+            return self.bessel_jv_two(top_order-1, arg)[1]
         elif order == top_order-1:
-            return self.wrap_in_cse(
-                    prim.Lookup(self.bessel_jv_two(top_order-1, arg), "jv"),
-                    "bessel_jv_two_result")
+            return self.bessel_jv_two(top_order-1, arg)[0]
         elif order < 0:
             return self.wrap_in_cse(
                     (-1)**order*self.bessel_j(-order, arg),
@@ -368,6 +340,45 @@ class BesselSubstitutor(CSECachingMapperMixin, IdentityMapper):
                     - self.bessel_j(nu+1, arg),
                     "bessel_j_%d" % order)
 
+    # }}}
+
+    # {{{ hankel implementation
+
+    @memoize_method
+    def hank1_01(self, arg):
+        name_0 = self.name_gen("hank1_0")
+        name_1 = self.name_gen("hank1_1")
+        self.assignments.append(
+                make_assignment(
+                    (prim.Variable(name_0), prim.Variable(name_1),),
+                    prim.Variable("hank1_01")(arg),
+                    temp_var_types=(lp.Optional(None),)*2))
+        return prim.Variable(name_0), prim.Variable(name_1)
+
+    @memoize_method
+    def hankel_1(self, order, arg):
+        if order == 0:
+            return self.hank1_01(arg)[0]
+        elif order == 1:
+            return self.hank1_01(arg)[1]
+        elif order < 0:
+            # AS (9.1.6)
+            nu = -order
+            return self.wrap_in_cse(
+                    (-1) ** nu * self.hankel_1(nu, arg),
+                    "hank1_neg%d" % nu)
+        elif order > 1:
+            # AS (9.1.27)
+            nu = order-1
+            return self.wrap_in_cse(
+                    2*nu/arg*self.hankel_1(nu, arg)
+                    - self.hankel_1(nu-1, arg),
+                    "hank1_%d" % order)
+        else:
+            raise AssertionError()
+
+    # }}}
+
     map_common_subexpression_uncached = IdentityMapper.map_common_subexpression
 
 # }}}
@@ -375,7 +386,7 @@ class BesselSubstitutor(CSECachingMapperMixin, IdentityMapper):
 
 # {{{ power rewriter
 
-class PowerRewriter(CSECachingMapperMixin, CallExternalRecMapper):
+class PowerRewriter(CSECachingIdentityMapper, CallExternalRecMapper):
     def map_power(self, expr, rec_self=None, *args):
         exp = expr.exponent
         if isinstance(exp, int):
@@ -418,7 +429,7 @@ class PowerRewriter(CSECachingMapperMixin, CallExternalRecMapper):
 
                 return self.rec(new_base**p, rec_self, *args)
 
-        return IdentityMapper.map_power(rec_self if rec_self else self, expr)
+        return CSECachingIdentityMapper.map_power(rec_self or self, expr)
 
     map_common_subexpression_uncached = IdentityMapper.map_common_subexpression
 
@@ -430,11 +441,11 @@ class PowerRewriter(CSECachingMapperMixin, CallExternalRecMapper):
 from loopy.tools import is_integer
 
 
-class BigIntegerKiller(CSECachingMapperMixin, CallExternalRecMapper):
+class BigIntegerKiller(CSECachingIdentityMapper, CallExternalRecMapper):
 
     def __init__(self, warn_on_digit_loss=True, int_type=np.int64,
             float_type=np.float64):
-        IdentityMapper.__init__(self)
+        super().__init__()
         self.warn = warn_on_digit_loss
         self.float_type = float_type
         self.iinfo = np.iinfo(int_type)
@@ -468,22 +479,22 @@ class BigIntegerKiller(CSECachingMapperMixin, CallExternalRecMapper):
 
 # {{{ convert 123000000j to 123000000 * 1j
 
-class ComplexRewriter(CSECachingMapperMixin, CallExternalRecMapper):
+class ComplexRewriter(CSECachingIdentityMapper, CallExternalRecMapper):
 
     def __init__(self, float_type=np.float32):
-        IdentityMapper.__init__(self)
+        super().__init__()
         self.float_type = float_type
 
     def map_constant(self, expr, rec_self=None, *args, **kwargs):
         """Convert complex values not within complex64 to a product for loopy
         """
         if not isinstance(expr, complex):
-            return IdentityMapper.map_constant(
-                    rec_self if rec_self else self, expr)
+            return CSECachingIdentityMapper.map_constant(
+                    rec_self or self, expr)
 
         if complex(self.float_type(expr.imag)) == expr.imag:
-            return IdentityMapper.map_constant(
-                    rec_self if rec_self else self, expr)
+            return CSECachingIdentityMapper.map_constant(
+                    rec_self or self, expr)
 
         # avoid cycles
         if expr == 1j:
@@ -501,10 +512,10 @@ class ComplexRewriter(CSECachingMapperMixin, CallExternalRecMapper):
 INDEXED_VAR_RE = re.compile("^([a-zA-Z_]+)([0-9]+)$")
 
 
-class VectorComponentRewriter(CSECachingMapperMixin, CallExternalRecMapper):
+class VectorComponentRewriter(CSECachingIdentityMapper, CallExternalRecMapper):
     """For names in name_whitelist, turn ``a3`` into ``a[3]``."""
 
-    def __init__(self, name_whitelist=set()):
+    def __init__(self, name_whitelist=frozenset()):
         self.name_whitelist = name_whitelist
 
     def map_variable(self, expr, *args):
@@ -526,7 +537,7 @@ class VectorComponentRewriter(CSECachingMapperMixin, CallExternalRecMapper):
 
 # {{{ sum sign grouper
 
-class SumSignGrouper(CSECachingMapperMixin, CallExternalRecMapper):
+class SumSignGrouper(CSECachingIdentityMapper, CallExternalRecMapper):
     """Anti-cancellation cargo-cultism."""
 
     def map_sum(self, expr, *args):
@@ -564,7 +575,7 @@ class SumSignGrouper(CSECachingMapperMixin, CallExternalRecMapper):
 # }}}
 
 
-class MathConstantRewriter(CSECachingMapperMixin, CallExternalRecMapper):
+class MathConstantRewriter(CSECachingIdentityMapper, CallExternalRecMapper):
     def map_variable(self, expr, *args):
         if expr.name == "pi":
             return prim.Variable("M_PI")
@@ -609,10 +620,11 @@ def combine_mappers(*mappers):
                 continue
             all_methods[method_name].append((mapper, method))
 
-    class CombinedMapper(CSECachingMapperMixin, IdentityMapper):
+    class CombinedMapper(CSECachingIdentityMapper):
         def __init__(self, all_methods):
             self.all_methods = all_methods
-        map_common_subexpression_uncached = IdentityMapper.map_common_subexpression
+        map_common_subexpression_uncached = \
+                IdentityMapper.map_common_subexpression
 
     def _map(method_name, self, expr, rec_self=None, *args):
         if method_name not in self.all_methods:
@@ -632,9 +644,13 @@ def combine_mappers(*mappers):
                 types.MethodType(partial(_map, method_name), combine_mapper))
     return combine_mapper
 
+# }}}
+
 
-def to_loopy_insns(assignments, vector_names=set(), pymbolic_expr_maps=[],
-                   complex_dtype=None, retain_names=set()):
+# {{{ to-loopy conversion
+
+def to_loopy_insns(assignments, vector_names=frozenset(), pymbolic_expr_maps=(),
+                   complex_dtype=None, retain_names=frozenset()):
     logger.info("loopy instruction generation: start")
     assignments = list(assignments)
 
@@ -649,11 +665,12 @@ def to_loopy_insns(assignments, vector_names=set(), pymbolic_expr_maps=[],
     ssg = SumSignGrouper()
     bik = BigIntegerKiller()
     cmr = ComplexRewriter()
-    hks = HankelSubstitutor()
+
+    cmb_mapper = combine_mappers(bdr, btog, vcr, pwr, ssg, bik, cmr)
 
     if 0:
         # https://github.com/inducer/sumpy/pull/40#issuecomment-852635444
-        cmb_mapper = combine_mappers(bdr, btog, vcr, pwr, ssg, bik, cmr, hks)
+        cmb_mapper = combine_mappers(bdr, btog, vcr, pwr, ssg, bik, cmr)
     else:
         def cmb_mapper(expr):
             expr = bdr(expr)
@@ -662,7 +679,6 @@ def to_loopy_insns(assignments, vector_names=set(), pymbolic_expr_maps=[],
             expr = ssg(expr)
             expr = bik(expr)
             expr = cmr(expr)
-            expr = hks(expr)
             expr = btog(expr)
             return expr
 
@@ -674,16 +690,21 @@ def to_loopy_insns(assignments, vector_names=set(), pymbolic_expr_maps=[],
         return expr
 
     assignments = [(name, convert_expr(name, expr)) for name, expr in assignments]
-    bessel_sub = BesselSubstitutor(btog.bessel_j_arg_to_top_order)
+    from pytools import UniqueNameGenerator
+    name_gen = UniqueNameGenerator(set([name for name, expr in assignments]))
+
+    result = []
+    bessel_sub = BesselSubstitutor(
+            name_gen, btog.bessel_j_arg_to_top_order,
+            result)
 
     import loopy as lp
     from pytools import MinRecursionLimit
     with MinRecursionLimit(3000):
-        result = [
-                lp.Assignment(id=None,
+        for name, expr in assignments:
+            result.append(lp.Assignment(id=None,
                     assignee=name, expression=bessel_sub(expr),
-                    temp_var_type=lp.Optional(None))
-                for name, expr in assignments]
+                    temp_var_type=lp.Optional(None)))
 
     logger.info("loopy instruction generation: done")
     return result
diff --git a/sumpy/kernel.py b/sumpy/kernel.py
index af771ceb..635f7699 100644
--- a/sumpy/kernel.py
+++ b/sumpy/kernel.py
@@ -542,13 +542,8 @@ class HelmholtzKernel(ExpressionKernel):
                 self.dim, self.helmholtz_k_name)
 
     def prepare_loopy_kernel(self, loopy_knl):
-        from sumpy.codegen import (bessel_preamble_generator, bessel_mangler)
-        loopy_knl = lp.register_function_manglers(loopy_knl,
-                [bessel_mangler])
-        loopy_knl = lp.register_preamble_generators(loopy_knl,
-                [bessel_preamble_generator])
-
-        return loopy_knl
+        from sumpy.codegen import register_bessel_callables
+        return register_bessel_callables(loopy_knl)
 
     def get_args(self):
         if self.allow_evanescent:
@@ -635,13 +630,8 @@ class YukawaKernel(ExpressionKernel):
                 self.dim, self.yukawa_lambda_name)
 
     def prepare_loopy_kernel(self, loopy_knl):
-        from sumpy.codegen import (bessel_preamble_generator, bessel_mangler)
-        loopy_knl = lp.register_function_manglers(loopy_knl,
-                [bessel_mangler])
-        loopy_knl = lp.register_preamble_generators(loopy_knl,
-                [bessel_preamble_generator])
-
-        return loopy_knl
+        from sumpy.codegen import register_bessel_callables
+        return register_bessel_callables(loopy_knl)
 
     def get_args(self):
         return [
-- 
GitLab