Skip to content
Snippets Groups Projects
Commit 009b410f authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Allow ()-shaped global args that work properly.

parent 4342e827
No related branches found
No related tags found
No related merge requests found
......@@ -63,15 +63,17 @@ class TypeInferenceMapper(CombineMapper):
pass
try:
result = self.temporary_variables[expr.name].dtype
tv = self.temporary_variables[expr.name]
except KeyError:
# name is not a temporary variable, ok
pass
else:
from loopy import infer_type
if result is infer_type:
if tv.dtype is infer_type:
raise TypeInferenceFailure("attempted type inference on "
"variable requiring type inference")
return result
return tv.dtype
if expr.name in self.kernel.all_inames():
return np.dtype(np.int16) # don't force single-precision upcast
......@@ -135,6 +137,12 @@ class LoopyCCodeMapper(CCodeMapper):
else:
return str(self.rec(self.var_subst_map[expr.name], prec))
else:
if expr.name in self.kernel.arg_dict:
arg = self.kernel.arg_dict[expr.name]
from loopy.kernel import _ShapedArg
if isinstance(arg, _ShapedArg) and arg.shape == ():
return "*"+expr.name
return CCodeMapper.map_variable(self, expr, prec)
def map_tagged_variable(self, expr, enclosing_prec):
......@@ -183,6 +191,9 @@ class LoopyCCodeMapper(CCodeMapper):
expr.aggregate.name, expr,
len(index_expr), len(ary_strides)))
if len(index_expr) == 0:
return "*" + expr.aggregate.name
from pymbolic.primitives import Subscript
return CCodeMapper.map_subscript(self,
Subscript(expr.aggregate, arg.offset+sum(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment