Add type inference for function calls.

TypeInferenceMapper gains map_call, which asks the kernel's
get_function_result_dtype callback for the result dtype of a call to a
named function; LoopKernel accepts and stores that callback.

NOTE(review): line breaks and indentation below were reconstructed from a
whitespace-collapsed copy of this patch. Per-hunk line counts match the
@@ headers; exact leading whitespace and wrap points of context lines
should be confirmed against the target files before applying.

diff --git a/loopy/codegen/expression.py b/loopy/codegen/expression.py
index adbd3e55afcdeaca669130e4f51cfac95afc8658..9ba97a48880294e07e77d89eae8b740056a0c1f4 100644
--- a/loopy/codegen/expression.py
+++ b/loopy/codegen/expression.py
@@ -22,10 +22,10 @@ class TypeInferenceMapper(CombineMapper):
             if result.isbuiltin and other.isbuiltin:
                 result = (np.empty(0, dtype=result) + np.empty(0, dtype=other)).dtype
             elif result.isbuiltin and not other.isbuiltin:
-                # assume the non-natiev type takes over
+                # assume the non-native type takes over
                 result = other
             elif not result.isbuiltin and other.isbuiltin:
-                # assume the non-natiev type takes over
+                # assume the non-native type takes over
                 pass
             else:
                 if not result is other:
@@ -40,6 +40,15 @@ class TypeInferenceMapper(CombineMapper):
     def map_subscript(self, expr):
         return self.rec(expr.aggregate)
 
+    def map_call(self, expr):
+        from pymbolic.primitives import Variable
+        if isinstance(expr.function, Variable):
+            name = expr.function.name
+            arg_dtypes = tuple(self.rec(par) for par in expr.parameters)
+            return self.kernel.get_function_result_dtype(name, arg_dtypes)
+        else:
+            return CombineMapper.map_call(self, expr)
+
     def map_variable(self, expr):
         try:
             return self.kernel.arg_dict[expr.name].dtype
diff --git a/loopy/kernel.py b/loopy/kernel.py
index bef788c0ded57803e6db4bd25380f8534ed80269..2cb427bd3e794f0bcfb070985ac9e24444400f61 100644
--- a/loopy/kernel.py
+++ b/loopy/kernel.py
@@ -477,6 +477,14 @@ def parse_reduction_op(name):
 
 # {{{ loop kernel object
 
+def _default_get_function_result_dtype(name, arg_dtypes):
+    if len(arg_dtypes) == 1:
+        dtype, = arg_dtypes
+        return dtype
+    else:
+        raise RuntimeError("no type inference information on "
+                "function '%s'" % name)
+
 class LoopKernel(Record):
     """
     :ivar device: :class:`pyopencl.Device`
@@ -504,6 +512,7 @@ class LoopKernel(Record):
     :ivar applied_substitutions: A list of past substitution dictionaries that
         were applied to the kernel. These are stored so that they may be repeated
         on expressions the user specifies later.
+    :ivar get_function_result_dtype:
 
     :ivar cache_manager:
 
@@ -521,7 +530,8 @@ class LoopKernel(Record):
             local_sizes={},
             iname_to_tag={}, iname_to_tag_requests=None, substitutions={},
             cache_manager=None, lowest_priority_inames=[], breakable_inames=set(),
-            applied_substitutions=[]):
+            applied_substitutions=[],
+            get_function_result_dtype=_default_get_function_result_dtype):
         """
         :arg domain: a :class:`islpy.BasicSet`, or a string parseable to a
             basic set by the isl. Example: "{[i,j]: 0<=i < 10 and 0<= j < 9}"
@@ -718,7 +728,8 @@ class LoopKernel(Record):
                 cache_manager=cache_manager,
                 lowest_priority_inames=lowest_priority_inames,
                 breakable_inames=breakable_inames,
-                applied_substitutions=applied_substitutions)
+                applied_substitutions=applied_substitutions,
+                get_function_result_dtype=get_function_result_dtype)
 
     def make_unique_instruction_id(self, insns=None, based_on="insn", extra_used_ids=set()):
         if insns is None: