From 009b410f0c453cb29eca8701c39edd36d536baae Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Tue, 17 Apr 2012 20:44:16 -0400
Subject: [PATCH] Allow ()-shaped global args that work properly.

---
 loopy/codegen/expression.py | 17 ++++++++++++++---
 1 file changed, 14 insertions(+), 3 deletions(-)

diff --git a/loopy/codegen/expression.py b/loopy/codegen/expression.py
index 5435a3b4f..54bb64f48 100644
--- a/loopy/codegen/expression.py
+++ b/loopy/codegen/expression.py
@@ -63,15 +63,17 @@ class TypeInferenceMapper(CombineMapper):
             pass
 
         try:
-            result = self.temporary_variables[expr.name].dtype
+            tv = self.temporary_variables[expr.name]
         except KeyError:
+            # name is not a temporary variable, ok
             pass
         else:
             from loopy import infer_type
-            if result is infer_type:
+            if tv.dtype is infer_type:
                 raise TypeInferenceFailure("attempted type inference on "
                         "variable requiring type inference")
-            return result
+
+            return tv.dtype
 
         if expr.name in self.kernel.all_inames():
             return np.dtype(np.int16) # don't force single-precision upcast
@@ -135,6 +137,12 @@ class LoopyCCodeMapper(CCodeMapper):
             else:
                 return str(self.rec(self.var_subst_map[expr.name], prec))
         else:
+            if expr.name in self.kernel.arg_dict:
+                arg = self.kernel.arg_dict[expr.name]
+                from loopy.kernel import _ShapedArg
+                if isinstance(arg, _ShapedArg) and arg.shape == ():
+                    return "*"+expr.name
+
             return CCodeMapper.map_variable(self, expr, prec)
 
     def map_tagged_variable(self, expr, enclosing_prec):
@@ -183,6 +191,9 @@ class LoopyCCodeMapper(CCodeMapper):
                                 expr.aggregate.name, expr,
                                 len(index_expr), len(ary_strides)))
 
+                if len(index_expr) == 0:
+                    return "*" + expr.aggregate.name
+
                 from pymbolic.primitives import Subscript
                 return CCodeMapper.map_subscript(self,
                         Subscript(expr.aggregate, arg.offset+sum(
-- 
GitLab