diff --git a/sumpy/codegen.py b/sumpy/codegen.py index b5d1c4ef83d3a4a156334b6007245efa1feb96a0..9c8a4958713800d37c18babd57f2426340a36222 100644 --- a/sumpy/codegen.py +++ b/sumpy/codegen.py @@ -80,6 +80,7 @@ class SympyToPymbolicMapper(SympyToPymbolicMapperBase): # }}} + # {{{ bessel handling BESSEL_PREAMBLE = """//CL// @@ -215,6 +216,7 @@ def bessel_mangler(kernel, identifier, arg_dtypes): class BesselGetter(object): def __init__(self, bessel_j_arg_to_top_order): self.bessel_j_arg_to_top_order = bessel_j_arg_to_top_order + self.cse_cache = {} @memoize_method def hank1_01(self, arg): @@ -224,26 +226,30 @@ class BesselGetter(object): def bessel_jv_two(self, order, arg): return prim.Variable("bessel_jv_two")(order, arg) + def wrap_in_cse(self, expr, prefix): + cse = prim.wrap_in_cse(expr, prefix) + return self.cse_cache.setdefault(expr, cse) + @memoize_method def hankel_1(self, order, arg): if order == 0: - return prim.Lookup( - prim.CommonSubexpression(self.hank1_01(arg), "hank1_01_result"), - "order0") + return self.wrap_in_cse( + prim.Lookup(self.hank1_01(arg), "order0"), + "hank1_01_result") elif order == 1: - return prim.Lookup( - prim.CommonSubexpression(self.hank1_01(arg), "hank1_01_result"), - "order1") + return self.wrap_in_cse( + prim.Lookup(self.hank1_01(arg), "order1"), + "hank1_01_result") elif order < 0: # AS (9.1.6) nu = -order - return prim.wrap_in_cse( - (-1)**nu * self.hankel_1(nu, arg), + return self.wrap_in_cse( + (-1) ** nu * self.hankel_1(nu, arg), "hank1_neg%d" % nu) elif order > 1: # AS (9.1.27) nu = order-1 - return prim.CommonSubexpression( + return self.wrap_in_cse( 2*nu/arg*self.hankel_1(nu, arg) - self.hankel_1(nu-1, arg), "hank1_%d" % order) @@ -254,22 +260,24 @@ class BesselGetter(object): def bessel_j(self, order, arg): top_order = self.bessel_j_arg_to_top_order[arg] - bessel_two = ( - prim.CommonSubexpression(self.bessel_jv_two(top_order-1, arg), - "bessel_jv_two_result")) - if order == top_order: - return prim.Lookup(bessel_two, "jvp1") + return self.wrap_in_cse( + prim.Lookup(self.bessel_jv_two(top_order-1, arg), "jvp1"), + "bessel_jv_two_result") elif order == top_order-1: - return prim.Lookup(bessel_two, "jv") + return self.wrap_in_cse( + prim.Lookup(self.bessel_jv_two(top_order-1, arg), "jv"), + "bessel_jv_two_result") elif order < 0: - return (-1)**order*self.bessel_j(-order, arg) + return self.wrap_in_cse( + (-1)**order*self.bessel_j(-order, arg), + "bessel_j_neg%d" % -order) else: assert abs(order) < top_order # AS (9.1.27) nu = order+1 - return prim.CommonSubexpression( + return self.wrap_in_cse( 2*nu/arg*self.bessel_j(nu, arg) - self.bessel_j(nu+1, arg), "bessel_j_%d" % order)