From 791820ca8dff17d4c1878fee9c8ae1c083782ed6 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 18 Apr 2012 15:46:22 -0400
Subject: [PATCH] Fix type inference for temporaries.

---
 loopy/codegen/expression.py |  5 ++++-
 loopy/preprocess.py         | 28 +++++++++++++++++++---------
 2 files changed, 23 insertions(+), 10 deletions(-)

diff --git a/loopy/codegen/expression.py b/loopy/codegen/expression.py
index 3ccb5d1ed..d77cdbd2a 100644
--- a/loopy/codegen/expression.py
+++ b/loopy/codegen/expression.py
@@ -11,6 +11,9 @@ from pymbolic.mapper import CombineMapper
 class TypeInferenceFailure(RuntimeError):
     pass
 
+class DependencyTypeInferenceFailure(TypeInferenceFailure):
+    pass
+
 class TypeInferenceMapper(CombineMapper):
     def __init__(self, kernel, temporary_variables=None):
         self.kernel = kernel
@@ -77,7 +80,7 @@ class TypeInferenceMapper(CombineMapper):
         else:
             from loopy import infer_type
             if tv.dtype is infer_type:
-                raise TypeInferenceFailure("attempted type inference on "
+                raise DependencyTypeInferenceFailure("attempted type inference on "
                         "variable requiring type inference")
 
             return tv.dtype
diff --git a/loopy/preprocess.py b/loopy/preprocess.py
index 3f00f85d6..e010b2e78 100644
--- a/loopy/preprocess.py
+++ b/loopy/preprocess.py
@@ -17,26 +17,35 @@ def infer_types_of_temporaries(kernel):
     for tv in kernel.temporary_variables.itervalues():
         if tv.dtype is infer_type:
             queue.append(tv)
+            new_temp_vars[tv.name] = tv
         else:
             new_temp_vars[tv.name] = tv
 
     from loopy.codegen.expression import (
-            TypeInferenceMapper, TypeInferenceFailure)
-    tim = TypeInferenceMapper(kernel)
+            TypeInferenceMapper, DependencyTypeInferenceFailure)
+    tim = TypeInferenceMapper(kernel, new_temp_vars)
 
     first_failure = None
     while queue:
         tv = queue.pop(0)
 
-        try:
-            dtypes = [
-                    tim(kernel.id_to_insn[writer].expression)
-                    for writer in kernel.writer_map()[tv.name]]
-        except TypeInferenceFailure:
+        dtypes = []
+
+        writers = kernel.writer_map()[tv.name]
+        exprs = [kernel.id_to_insn[w].expression for w in writers]
+
+        for expr in exprs:
+            try:
+                dtypes.append(tim(expr))
+            except DependencyTypeInferenceFailure:
+                pass
+
+        if not dtypes:
             if tv is first_failure:
                 # this has failed before, give up.
-                raise RuntimeError("could not determine type of '%s'"
-                        % tv.name)
+                raise RuntimeError("could not determine type of '%s' from expression(s) '%s'"
+                        % (tv.name,
+                            ", ".join(str(e) for e in exprs)))
 
             if first_failure is None:
                 # remember the first failure for this round through the queue
@@ -44,6 +53,7 @@ def infer_types_of_temporaries(kernel):
 
             # can't infer type yet, put back into queue
             queue.append(tv)
+            continue
 
         from pytools import is_single_valued
         if not is_single_valued(dtypes):
-- 
GitLab