From 29ce9e11626bf4050924fa910b34f04c0a22dfec Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Wed, 11 May 2016 23:20:31 -0500 Subject: [PATCH] Type inference fixes for multiple assignment --- loopy/expression.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/loopy/expression.py b/loopy/expression.py index 42c54af71..6fd49661d 100644 --- a/loopy/expression.py +++ b/loopy/expression.py @@ -213,17 +213,19 @@ class TypeInferenceMapper(CombineMapper): mangle_result = self.kernel.mangle_function(identifier, arg_dtypes) if multiple_types_ok: - return mangle_result.result_dtypes + if mangle_result is not None: + return mangle_result.result_dtypes else: - if len(mangle_result.result_dtypes) != 1 and not multiple_types_ok: - raise LoopyError("functions with more or fewer than one " - "return value may only be used in direct assignments") - if mangle_result is not None: + if len(mangle_result.result_dtypes) != 1 and not multiple_types_ok: + raise LoopyError("functions with more or fewer than one " + "return value may only be used in direct assignments") + return mangle_result.result_dtypes[0] - raise RuntimeError("no type inference information on " - "function '%s'" % identifier) + raise RuntimeError("unable to resolve " + "function '%s' with %d given arguments" + % (identifier, len(arg_dtypes))) def map_variable(self, expr): if expr.name in self.kernel.all_inames(): @@ -274,7 +276,8 @@ class TypeInferenceMapper(CombineMapper): def map_lookup(self, expr): agg_result = self.rec(expr.aggregate) - dtype, offset = agg_result.numpy_dtype.fields[expr.name] + field = agg_result.numpy_dtype.fields[expr.name] + dtype = field[0] return NumpyType(dtype) def map_comparison(self, expr): -- GitLab