From 29ce9e11626bf4050924fa910b34f04c0a22dfec Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 11 May 2016 23:20:31 -0500
Subject: [PATCH] Type inference fixes for multiple assignment

---
 loopy/expression.py | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/loopy/expression.py b/loopy/expression.py
index 42c54af71..6fd49661d 100644
--- a/loopy/expression.py
+++ b/loopy/expression.py
@@ -213,17 +213,19 @@ class TypeInferenceMapper(CombineMapper):
 
         mangle_result = self.kernel.mangle_function(identifier, arg_dtypes)
         if multiple_types_ok:
-            return mangle_result.result_dtypes
+            if mangle_result is not None:
+                return mangle_result.result_dtypes
         else:
-            if len(mangle_result.result_dtypes) != 1 and not multiple_types_ok:
-                raise LoopyError("functions with more or fewer than one "
-                        "return value may only be used in direct assignments")
-
             if mangle_result is not None:
+                if len(mangle_result.result_dtypes) != 1 and not multiple_types_ok:
+                    raise LoopyError("functions with more or fewer than one "
+                            "return value may only be used in direct assignments")
+
                 return mangle_result.result_dtypes[0]
 
-        raise RuntimeError("no type inference information on "
-                "function '%s'" % identifier)
+        raise RuntimeError("unable to resolve "
+                "function '%s' with %d given arguments"
+                % (identifier, len(arg_dtypes)))
 
     def map_variable(self, expr):
         if expr.name in self.kernel.all_inames():
@@ -274,7 +276,8 @@ class TypeInferenceMapper(CombineMapper):
 
     def map_lookup(self, expr):
         agg_result = self.rec(expr.aggregate)
-        dtype, offset = agg_result.numpy_dtype.fields[expr.name]
+        field = agg_result.numpy_dtype.fields[expr.name]
+        dtype = field[0]
         return NumpyType(dtype)
 
     def map_comparison(self, expr):
-- 
GitLab