diff --git a/sumpy/codegen.py b/sumpy/codegen.py index 61ed73014968d58a5455f006f4405f36d0d41572..bef7fc6bf93a171a96626f9db88c312c23167bf2 100644 --- a/sumpy/codegen.py +++ b/sumpy/codegen.py @@ -581,6 +581,28 @@ class BigIntegerKiller(CSECachingMapperMixin, IdentityMapper): # }}} +# {{{ convert 123000000j to 123000000 * 1j + +class ComplexRewriter(CSECachingMapperMixin, IdentityMapper): + + def __init__(self, float_type=np.float32): + IdentityMapper.__init__(self) + self.float_type = float_type + + def map_constant(self, expr): + """Convert complex values not within complex64 to a product for loopy + """ + if not isinstance(expr, complex): + return IdentityMapper.map_constant(self, expr) + + if complex(self.float_type(expr.imag)) == expr.imag: + return IdentityMapper.map_constant(self, expr) + + return expr.real + prim.Product((expr.imag, 1j)) + + map_common_subexpression_uncached = IdentityMapper.map_common_subexpression + + # {{{ vector component rewriter INDEXED_VAR_RE = re.compile("^([a-zA-Z_]+)([0-9]+)$") @@ -684,6 +706,7 @@ def to_loopy_insns(assignments, vector_names=set(), pymbolic_expr_maps=[], ssg = SumSignGrouper() fck = FractionKiller() bik = BigIntegerKiller() + cmr = ComplexRewriter() def convert_expr(name, expr): logger.debug("generate expression for: %s" % name) @@ -694,6 +717,7 @@ def to_loopy_insns(assignments, vector_names=set(), pymbolic_expr_maps=[], expr = fck(expr) expr = ssg(expr) expr = bik(expr) + expr = cmr(expr) #expr = cse_tag(expr) for m in pymbolic_expr_maps: expr = m(expr)