From 7529d1a45b2fff9a17464bb70044c3b3f2642e38 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Thu, 29 Mar 2012 01:34:55 -0400
Subject: [PATCH] Let TypeInferenceMapper take its own temporary_variables
 array. (Also muck with codegen for constants.)

---
 loopy/codegen/expression.py | 27 +++++++++++++++++----------
 1 file changed, 17 insertions(+), 10 deletions(-)

diff --git a/loopy/codegen/expression.py b/loopy/codegen/expression.py
index 9ba97a488..ea0cd4958 100644
--- a/loopy/codegen/expression.py
+++ b/loopy/codegen/expression.py
@@ -9,8 +9,12 @@ from pymbolic.mapper import CombineMapper
 # {{{ type inference
 
 class TypeInferenceMapper(CombineMapper):
-    def __init__(self, kernel):
+    def __init__(self, kernel, temporary_variables=None):
         self.kernel = kernel
+        if temporary_variables is None:
+            temporary_variables = kernel.temporary_variables
+
+        self.temporary_variables = temporary_variables
 
     def combine(self, dtypes):
         dtypes = list(dtypes)
@@ -56,7 +60,7 @@ class TypeInferenceMapper(CombineMapper):
             pass
 
         try:
-            return self.kernel.temporary_variables[expr.name].dtype
+            return self.temporary_variables[expr.name].dtype
         except KeyError:
             pass
 
@@ -72,15 +76,8 @@ class TypeInferenceMapper(CombineMapper):
 class LoopyCCodeMapper(CCodeMapper):
     def __init__(self, kernel, cse_name_list=[], var_subst_map={},
             with_annotation=False, allow_complex=False):
-        def constant_mapper(c):
-            if isinstance(c, float):
-                # FIXME: type-variable
-                return "%sf" % repr(c)
-            else:
-                return repr(c)
 
-        CCodeMapper.__init__(self, constant_mapper=constant_mapper,
-                cse_name_list=cse_name_list)
+        CCodeMapper.__init__(self, cse_name_list=cse_name_list)
         self.kernel = kernel
         self.infer_type = TypeInferenceMapper(kernel)
         self.allow_complex = allow_complex
@@ -217,6 +214,16 @@ class LoopyCCodeMapper(CCodeMapper):
 
     map_max = map_min
 
+    def map_constant(self, expr, enclosing_prec):
+        if isinstance(expr, complex):
+            # FIXME: type-variable
+            return "(cdouble_t) (%s, %s)" % (repr(expr.real), repr(expr.imag))
+        elif isinstance(expr, float):
+            # FIXME: type-variable
+            return "%s" % repr(expr)
+        else:
+            return CCodeMapper.map_constant(self, expr, enclosing_prec)
+
     # {{{ deal with complex-valued variables
 
     def complex_type_name(self, dtype):
-- 
GitLab