diff --git a/loopy/codegen/expression.py b/loopy/codegen/expression.py index a270944016c9c69163be5a94df7e9fc7cfbd4feb..3ccb5d1edc2874953ee4808d649b9649d9317472 100644 --- a/loopy/codegen/expression.py +++ b/loopy/codegen/expression.py @@ -52,7 +52,14 @@ class TypeInferenceMapper(CombineMapper): if isinstance(expr.function, Variable): name = expr.function.name arg_dtypes = tuple(self.rec(par) for par in expr.parameters) - return self.kernel.get_function_result_dtype(name, arg_dtypes) + for rdg in self.kernel.function_result_dtype_getters: + result = rdg(name, arg_dtypes) + if result is not None: + return result + + raise RuntimeError("no type inference information on " + "function '%s'" % name) + else: return CombineMapper.map_call(self, expr) diff --git a/loopy/kernel.py b/loopy/kernel.py index d74a9d7ba8f0713b00940a9e91b76fd331752dc2..a1cfe80596513e9666f32d5f38246711fccd3c92 100644 --- a/loopy/kernel.py +++ b/loopy/kernel.py @@ -495,9 +495,8 @@ def _default_get_function_result_dtype(name, arg_dtypes): if len(arg_dtypes) == 1: dtype, = arg_dtypes return dtype - else: - raise RuntimeError("no type inference information on " - "function '%s'" % name) + + return None class LoopKernel(Record): """ @@ -526,7 +525,8 @@ class LoopKernel(Record): :ivar applied_substitutions: A list of past substitution dictionaries that were applied to the kernel. These are stored so that they may be repeated on expressions the user specifies later. - :ivar get_function_result_dtype: + :ivar function_result_dtype_getters: list of functions of signature (name, arg_dtypes) + returning the result dtype of that function. :ivar cache_manager: @@ -545,7 +545,7 @@ class LoopKernel(Record): iname_to_tag={}, iname_to_tag_requests=None, substitutions={}, cache_manager=None, lowest_priority_inames=[], breakable_inames=set(), applied_substitutions=[], - get_function_result_dtype=_default_get_function_result_dtype): + function_result_dtype_getters=[_default_get_function_result_dtype]): """ :arg domain: a :class:`islpy.BasicSet`, or a string parseable to a basic set by the isl. Example: "{[i,j]: 0<=i < 10 and 0<= j < 9}" @@ -750,7 +750,7 @@ class LoopKernel(Record): lowest_priority_inames=lowest_priority_inames, breakable_inames=breakable_inames, applied_substitutions=applied_substitutions, - get_function_result_dtype=get_function_result_dtype) + function_result_dtype_getters=function_result_dtype_getters) def make_unique_instruction_id(self, insns=None, based_on="insn", extra_used_ids=set()): if insns is None: @@ -973,7 +973,7 @@ class LoopKernel(Record): if not all_inames_by_insns <= self.all_inames(): raise RuntimeError("inames collected from instructions (%s) " "that are not present in domain (%s)" - % (", ".join(sorted(all_inames_by_insns)), + % (", ".join(sorted(all_inames_by_insns)), ", ".join(sorted(self.all_inames())))) global_sizes = {}