From 0eb507fe119e5458658798143a5c6dd437ee9bbc Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Tue, 21 May 2013 16:16:37 -0400
Subject: [PATCH] Generalize TypeInferenceMapper.

---
 loopy/codegen/expression.py | 72 +++++++++++++++++++++++--------------
 1 file changed, 45 insertions(+), 27 deletions(-)

diff --git a/loopy/codegen/expression.py b/loopy/codegen/expression.py
index 4700e1df7..c87723659 100644
--- a/loopy/codegen/expression.py
+++ b/loopy/codegen/expression.py
@@ -34,7 +34,6 @@ from pymbolic.mapper import CombineMapper
 import islpy as isl
 import pyopencl as cl
 import pyopencl.array
-from pytools import memoize_method
 
 # {{{ type inference
 
@@ -45,12 +44,18 @@ class DependencyTypeInferenceFailure(TypeInferenceFailure):
     pass
 
 class TypeInferenceMapper(CombineMapper):
-    def __init__(self, kernel, temporary_variables=None):
+    def __init__(self, kernel, new_assignments=None):
+        """
+        :arg new_assignments: mapping from names to either
+            :class:`loopy.kernel.data.TemporaryVariable`
+            or
+            :class:`loopy.kernel.data.KernelArgument`
+            instances
+        """
         self.kernel = kernel
-        if temporary_variables is None:
-            temporary_variables = kernel.temporary_variables
-
-        self.temporary_variables = temporary_variables
+        if new_assignments is None:
+            new_assignments = {}
+        self.new_assignments = new_assignments
 
     # /!\ Introduce caches with care--numpy.float32(x) and numpy.float64(x)
     # are Python-equal.
@@ -156,24 +161,6 @@ class TypeInferenceMapper(CombineMapper):
                 "function '%s'" % identifier)
 
     def map_variable(self, expr):
-        try:
-            return self.kernel.arg_dict[expr.name].dtype
-        except KeyError:
-            pass
-
-        try:
-            tv = self.temporary_variables[expr.name]
-        except KeyError:
-            # name is not a temporary variable, ok
-            pass
-        else:
-            import loopy as lp
-            if tv.dtype is lp.auto:
-                raise DependencyTypeInferenceFailure("attempted type inference on "
-                        "variable requiring type inference")
-
-            return tv.dtype
-
         if expr.name in self.kernel.all_inames():
             return self.kernel.index_dtype
 
@@ -183,7 +170,39 @@ class TypeInferenceMapper(CombineMapper):
                 result_dtype, _ = result
                 return result_dtype
 
-        raise TypeInferenceFailure("nothing known about '%s'" % expr.name)
+        obj = self.new_assignments.get(expr.name)
+
+        if obj is None:
+            obj = self.kernel.arg_dict.get(expr.name)
+
+        if obj is None:
+            obj = self.kernel.temporary_variables.get(expr.name)
+
+        if obj is None:
+            raise TypeInferenceFailure("name not known in type inference: %s"
+                    % expr.name)
+
+        from loopy.kernel.data import TemporaryVariable, KernelArgument
+        import loopy as lp
+        if isinstance(obj, TemporaryVariable):
+            result = obj.dtype
+            if result is lp.auto:
+                raise DependencyTypeInferenceFailure(
+                        "temporary variable '%s'" % expr.name)
+            else:
+                return result
+
+        elif isinstance(obj, KernelArgument):
+            result = obj.dtype
+            if result is None:
+                raise DependencyTypeInferenceFailure(
+                        "argument '%s'" % expr.name)
+            else:
+                return result
+
+        else:
+            raise RuntimeError("unexpected type inference "
+                    "object type for '%s'" % expr.name)
 
     map_tagged_variable = map_variable
 
@@ -214,8 +233,7 @@ def dtype_to_type_context(dtype):
         return 'd'
     if dtype in [np.float32, np.complex64]:
         return 'f'
-    from pyopencl.array import vec
-    if dtype in vec.types.values():
+    if dtype in cl.array.vec.types.values():
         return dtype_to_type_context(dtype.fields["x"][0])
 
     return None
-- 
GitLab