diff --git a/loopy/codegen/expression.py b/loopy/codegen/expression.py index 4700e1df716e82ac392790b3cf9e670d1cd33598..c877236595d55f02c2357aeae25a62f6c1f957a3 100644 --- a/loopy/codegen/expression.py +++ b/loopy/codegen/expression.py @@ -34,7 +34,6 @@ from pymbolic.mapper import CombineMapper import islpy as isl import pyopencl as cl import pyopencl.array -from pytools import memoize_method # {{{ type inference @@ -45,12 +44,18 @@ class DependencyTypeInferenceFailure(TypeInferenceFailure): pass class TypeInferenceMapper(CombineMapper): - def __init__(self, kernel, temporary_variables=None): + def __init__(self, kernel, new_assignments=None): + """ + :arg new_assignments: mapping from names to either + :class:`loopy.kernel.data.TemporaryVariable` + or + :class:`loopy.kernel.data.KernelArgument` + instances + """ self.kernel = kernel - if temporary_variables is None: - temporary_variables = kernel.temporary_variables - - self.temporary_variables = temporary_variables + if new_assignments is None: + new_assignments = {} + self.new_assignments = new_assignments # /!\ Introduce caches with care--numpy.float32(x) and numpy.float64(x) # are Python-equal. @@ -156,24 +161,6 @@ class TypeInferenceMapper(CombineMapper): "function '%s'" % identifier) def map_variable(self, expr): - try: - return self.kernel.arg_dict[expr.name].dtype - except KeyError: - pass - - try: - tv = self.temporary_variables[expr.name] - except KeyError: - # name is not a temporary variable, ok - pass - else: - import loopy as lp - if tv.dtype is lp.auto: - raise DependencyTypeInferenceFailure("attempted type inference on " - "variable requiring type inference") - - return tv.dtype - if expr.name in self.kernel.all_inames(): return self.kernel.index_dtype @@ -183,7 +170,39 @@ class TypeInferenceMapper(CombineMapper): result_dtype, _ = result return result_dtype - raise TypeInferenceFailure("nothing known about '%s'" % expr.name) + obj = self.new_assignments.get(expr.name) + + if obj is None: + obj = self.kernel.arg_dict.get(expr.name) + + if obj is None: + obj = self.kernel.temporary_variables.get(expr.name) + + if obj is None: + raise TypeInferenceFailure("name not known in type inference: %s" + % expr.name) + + from loopy.kernel.data import TemporaryVariable, KernelArgument + import loopy as lp + if isinstance(obj, TemporaryVariable): + result = obj.dtype + if result is lp.auto: + raise DependencyTypeInferenceFailure( + "temporary variable '%s'" % expr.name) + else: + return result + + elif isinstance(obj, KernelArgument): + result = obj.dtype + if result is None: + raise DependencyTypeInferenceFailure( + "argument '%s'" % expr.name) + else: + return result + + else: + raise RuntimeError("unexpected type inference " + "object type for '%s'" % expr.name) map_tagged_variable = map_variable @@ -214,8 +233,7 @@ def dtype_to_type_context(dtype): return 'd' if dtype in [np.float32, np.complex64]: return 'f' - from pyopencl.array import vec - if dtype in vec.types.values(): + if dtype in cl.array.vec.types.values(): return dtype_to_type_context(dtype.fields["x"][0]) return None