From 009b410f0c453cb29eca8701c39edd36d536baae Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Tue, 17 Apr 2012 20:44:16 -0400 Subject: [PATCH] Allow ()-shaped global args that work properly. --- loopy/codegen/expression.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/loopy/codegen/expression.py b/loopy/codegen/expression.py index 5435a3b4f..54bb64f48 100644 --- a/loopy/codegen/expression.py +++ b/loopy/codegen/expression.py @@ -63,15 +63,17 @@ class TypeInferenceMapper(CombineMapper): pass try: - result = self.temporary_variables[expr.name].dtype + tv = self.temporary_variables[expr.name] except KeyError: + # name is not a temporary variable, ok pass else: from loopy import infer_type - if result is infer_type: + if tv.dtype is infer_type: raise TypeInferenceFailure("attempted type inference on " "variable requiring type inference") - return result + + return tv.dtype if expr.name in self.kernel.all_inames(): return np.dtype(np.int16) # don't force single-precision upcast @@ -135,6 +137,12 @@ class LoopyCCodeMapper(CCodeMapper): else: return str(self.rec(self.var_subst_map[expr.name], prec)) else: + if expr.name in self.kernel.arg_dict: + arg = self.kernel.arg_dict[expr.name] + from loopy.kernel import _ShapedArg + if isinstance(arg, _ShapedArg) and arg.shape == (): + return "*"+expr.name + return CCodeMapper.map_variable(self, expr, prec) def map_tagged_variable(self, expr, enclosing_prec): @@ -183,6 +191,9 @@ class LoopyCCodeMapper(CCodeMapper): expr.aggregate.name, expr, len(index_expr), len(ary_strides))) + if len(index_expr) == 0: + return "*" + expr.aggregate.name + from pymbolic.primitives import Subscript return CCodeMapper.map_subscript(self, Subscript(expr.aggregate, arg.offset+sum( -- GitLab