diff --git a/loopy/codegen/expression.py b/loopy/codegen/expression.py index 1406ad995d6a4a49d5c0b877996743c8874650a8..ae1802a34280bbc50f8f4457b43fdc374cbb2744 100644 --- a/loopy/codegen/expression.py +++ b/loopy/codegen/expression.py @@ -52,6 +52,9 @@ class TypeInferenceMapper(CombineMapper): self.temporary_variables = temporary_variables + # /!\ Introduce caches with care--numpy.float32(x) and numpy.float64(x) + # are Python-equal. + def combine(self, dtypes): dtypes = list(dtypes) @@ -163,16 +166,6 @@ class TypeInferenceMapper(CombineMapper): def map_reduction(self, expr): return expr.operation.result_dtype(self.rec(expr.expr), expr.inames) - # {{{ use caching - - @memoize_method - def __call__(self, expr): - return CombineMapper.__call__(self, expr) - - rec = __call__ - - # }}} - # }}} # {{{ C code mapper @@ -261,7 +254,7 @@ class LoopyCCodeMapper(RecursiveMapper): else: return s - def rec(self, expr, prec, type_context, needed_dtype=None): + def rec(self, expr, prec, type_context=None, needed_dtype=None): if needed_dtype is None: return RecursiveMapper.rec(self, expr, prec, type_context) diff --git a/loopy/kernel.py b/loopy/kernel.py index f944c3e6c88e32954ac853a26adb2320b911cdfc..d9fd7c274a5c4a0c715b2cdb435705d241573ad6 100644 --- a/loopy/kernel.py +++ b/loopy/kernel.py @@ -535,6 +535,9 @@ def opencl_function_mangler(name, arg_dtypes): "sinh", "cosh", "tanh"]: return arg_dtype, "%s_%s" % (tpname, name) + if name in ["real", "imag"]: + return np.dtype(arg_dtype.type(0).real), "%s_%s" % (tpname, name) + if name == "dot": scalar_dtype, offset, field_name = arg_dtypes[0].fields["s0"] return scalar_dtype, name diff --git a/loopy/symbolic.py b/loopy/symbolic.py index a7571512a90de985e5a04086de0f5827e3809c68..cbb664a272f81bfa4e95fd77b5ebd493b8fb46a6 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -55,6 +55,9 @@ from pymbolic.parser import Parser as ParserBase import islpy as isl from islpy import dim_type +import re +import numpy as np + @@ -345,12 +348,30 
@@ class FunctionToPrimitiveMapper(IdentityMapper):
 _open_dbl_bracket = intern("open_dbl_bracket")
 _close_dbl_bracket = intern("close_dbl_bracket")
 
+TRAILING_FLOAT_TAG_RE = re.compile(r"^(.*?)([a-zA-Z]*)$")
+
 class LoopyParser(ParserBase):
     lex_table = [
             (_open_dbl_bracket, pytools.lex.RE(r"\[\[")),
             (_close_dbl_bracket, pytools.lex.RE(r"\]\]")),
             ] + ParserBase.lex_table
 
+    def parse_float(self, s):
+        """Parse float literal *s*; tag f/d/j/jf picks dtype, others raise."""
+        val, suffix = TRAILING_FLOAT_TAG_RE.match(s).groups()
+        tag = frozenset(suffix)  # set compare: "jf" and "fj" are the same tag
+        if tag == frozenset("j"):  # double-precision imaginary
+            return np.float64(val)*np.complex128(1j)
+        if tag == frozenset("jf"):  # single-precision imaginary
+            return np.float32(val)*np.complex64(1j)
+        if tag == frozenset("f"):
+            return np.float32(val)
+        if tag == frozenset("d"):
+            return np.float64(val)
+        if tag:  # unknown tag: fail loudly rather than silently drop it
+            raise ValueError("unknown float literal tag %r in %r" % (suffix, s))
+        return float(val)  # untagged: generic float
+
     def parse_postfix(self, pstate, min_precedence, left_exp):
         from pymbolic.parser import _PREC_CALL
         if pstate.next_tag() is _open_dbl_bracket and _PREC_CALL > min_precedence:
diff --git a/test/test_loopy.py b/test/test_loopy.py
index e3d2880befba9382288105923271267982344b48..84f8e0c21dd8c53d3308ed091f3d1b80c4bc6c7c 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -63,6 +63,32 @@ def test_type_inference_no_artificial_doubles(ctx_factory):
 
 
 
+def test_sized_and_complex_literals(ctx_factory):
+    ctx = ctx_factory()
+
+    knl = lp.make_kernel(ctx.devices[0],
+            "{[i]: 0<=i<n}",
+            """
+            <> aa = 5jf
+            <> bb = 5j
+            a[i] = imag(aa)
+            b[i] = imag(bb)
+            c[i] = 5f
+            """,
+            [
+                lp.GlobalArg("a", np.float32, shape=("n",)),
+                lp.GlobalArg("b", np.float32, shape=("n",)),
+                lp.GlobalArg("c", np.float32, shape=("n",)),
+                lp.ValueArg("n", np.int32),
+                ],
+            assumptions="n>=1")
+
+    lp.auto_test_vs_ref(knl, ctx, lp.generate_loop_schedules(knl),
+            parameters=dict(n=5))
+
+
+
+
 
 def test_simple_side_effect(ctx_factory):
     ctx = ctx_factory()