From 53fa902132a05d37beb61caa8e666fe7f6d44e4b Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Mon, 22 Oct 2012 00:15:56 -0400
Subject: [PATCH] Add support for complex and float literals, real(), imag().

---
 loopy/codegen/expression.py | 15 ++++-----------
 loopy/kernel.py             |  3 +++
 loopy/symbolic.py           | 21 +++++++++++++++++++++
 test/test_loopy.py          | 26 ++++++++++++++++++++++++++
 4 files changed, 54 insertions(+), 11 deletions(-)

diff --git a/loopy/codegen/expression.py b/loopy/codegen/expression.py
index 1406ad995..ae1802a34 100644
--- a/loopy/codegen/expression.py
+++ b/loopy/codegen/expression.py
@@ -52,6 +52,9 @@ class TypeInferenceMapper(CombineMapper):
 
         self.temporary_variables = temporary_variables
 
+    # /!\ Introduce caches with care--numpy.float32(x) and numpy.float64(x)
+    # are Python-equal.
+
     def combine(self, dtypes):
         dtypes = list(dtypes)
 
@@ -163,16 +166,6 @@ class TypeInferenceMapper(CombineMapper):
     def map_reduction(self, expr):
         return expr.operation.result_dtype(self.rec(expr.expr), expr.inames)
 
-    # {{{ use caching
-
-    @memoize_method
-    def __call__(self, expr):
-        return CombineMapper.__call__(self, expr)
-
-    rec = __call__
-
-    # }}}
-
 # }}}
 
 # {{{ C code mapper
@@ -261,7 +254,7 @@ class LoopyCCodeMapper(RecursiveMapper):
         else:
             return s
 
-    def rec(self, expr, prec, type_context, needed_dtype=None):
+    def rec(self, expr, prec, type_context=None, needed_dtype=None):
         if needed_dtype is None:
             return RecursiveMapper.rec(self, expr, prec, type_context)
 
diff --git a/loopy/kernel.py b/loopy/kernel.py
index f944c3e6c..d9fd7c274 100644
--- a/loopy/kernel.py
+++ b/loopy/kernel.py
@@ -535,6 +535,9 @@ def opencl_function_mangler(name, arg_dtypes):
                     "sinh", "cosh", "tanh"]:
                 return arg_dtype, "%s_%s" % (tpname, name)
 
+            if name in ["real", "imag"]:
+                return np.dtype(arg_dtype.type(0).real), "%s_%s" % (tpname, name)
+
     if name == "dot":
         scalar_dtype, offset, field_name = arg_dtypes[0].fields["s0"]
         return scalar_dtype, name
diff --git a/loopy/symbolic.py b/loopy/symbolic.py
index a7571512a..cbb664a27 100644
--- a/loopy/symbolic.py
+++ b/loopy/symbolic.py
@@ -55,6 +55,9 @@ from pymbolic.parser import Parser as ParserBase
 import islpy as isl
 from islpy import dim_type
 
+import re
+import numpy as np
+
 
 
 
@@ -345,12 +348,30 @@ class FunctionToPrimitiveMapper(IdentityMapper):
 _open_dbl_bracket = intern("open_dbl_bracket")
 _close_dbl_bracket = intern("close_dbl_bracket")
 
+TRAILING_FLOAT_TAG_RE = re.compile("^(.*?)([a-zA-Z]*)$")
+
 class LoopyParser(ParserBase):
     lex_table = [
             (_open_dbl_bracket, pytools.lex.RE(r"\[\[")),
             (_close_dbl_bracket, pytools.lex.RE(r"\]\]")),
             ] + ParserBase.lex_table
 
+    def parse_float(self, s):
+        match = TRAILING_FLOAT_TAG_RE.match(s)
+
+        val = match.group(1)
+        tag = frozenset(match.group(2))
+        if tag == frozenset("j"):
+            return np.float64(val)*np.complex128(1j)
+        elif tag == frozenset("jf"):
+            return np.float32(val)*np.complex64(1j)
+        elif tag == frozenset("f"):
+            return np.float32(val)
+        elif tag == frozenset("d"):
+            return np.float64(val)
+        else:
+            return float(val) # generic float
+
     def parse_postfix(self, pstate, min_precedence, left_exp):
         from pymbolic.parser import _PREC_CALL
         if pstate.next_tag() is _open_dbl_bracket and _PREC_CALL > min_precedence:
diff --git a/test/test_loopy.py b/test/test_loopy.py
index e3d2880be..84f8e0c21 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -63,6 +63,32 @@ def test_type_inference_no_artificial_doubles(ctx_factory):
 
 
 
+def test_sized_and_complex_literals(ctx_factory):
+    ctx = ctx_factory()
+
+    knl = lp.make_kernel(ctx.devices[0],
+            "{[i]: 0<=i<n}",
+            """
+                <> aa = 5jf
+                <> bb = 5j
+                a[i] = imag(aa)
+                b[i] = imag(bb)
+                c[i] = 5f
+                """,
+            [
+                lp.GlobalArg("a", np.float32, shape=("n",)),
+                lp.GlobalArg("b", np.float32, shape=("n",)),
+                lp.GlobalArg("c", np.float32, shape=("n",)),
+                lp.ValueArg("n", np.int32),
+                ],
+            assumptions="n>=1")
+
+    lp.auto_test_vs_ref(knl, ctx, lp.generate_loop_schedules(knl),
+            parameters=dict(n=5))
+
+
+
+
 def test_simple_side_effect(ctx_factory):
     ctx = ctx_factory()
 
-- 
GitLab