diff --git a/loopy/kernel.py b/loopy/kernel.py
index 43d73114ec524f96b7e1a0f0f54204eac7bfcc6a..4bd9750b33e0448ee80f420d9d550ae90144153d 100644
--- a/loopy/kernel.py
+++ b/loopy/kernel.py
@@ -396,7 +396,13 @@ class ReductionOperation(object):
     def dtype(self, inames):
         raise NotImplementedError
 
-    def get_preambles(self, inames, c_code_mapper):
+    def get_function_result_dtype_getter(self):
+        """If the reduction declares any functions, return a getter that
+        makes their return types known to type inference.
+        """
+        return None
+
+    def get_preambles(self, inames):
         return []
 
     def neutral_element(self, inames):
@@ -473,8 +479,16 @@ class _ArgExtremumReductionOperation(ReductionOperation):
     def dtype(self, inames):
         return self.struct_dtype
 
-    # No need to make type inference aware of our functions: Their results
-    # always get assigned directly to typed temporaries without any arithmetic.
+    def get_function_result_dtype_getter(self):
+        names = [self.prefix+"_init", self.prefix+"_update"]
+
+        def getter(name, arg_dtypes):
+            if name in names:
+                return self.struct_dtype
+
+            return None
+
+        return getter
 
     def get_preambles(self, inames):
         """Returns a tuple (preamble_key, preamble), where *preamble* is a string
diff --git a/loopy/preprocess.py b/loopy/preprocess.py
index f49832ed80008533b5bcbcc9b41234e42d5180ee..82a67450efe2ca558d5e7c103dd4fe3139553e9a 100644
--- a/loopy/preprocess.py
+++ b/loopy/preprocess.py
@@ -6,6 +6,33 @@ import pyopencl.characterize as cl_char
 
 
 
+# {{{ gather dtype getters from reduction operations
+
+def gather_dtype_getters(kernel):
+    """Return a copy of *kernel* whose ``function_result_dtype_getters``
+    additionally holds the getters declared by the reduction operations
+    found in the kernel's instructions.
+    """
+    # Work on a copy: appending to the kernel's own list would modify the
+    # input kernel in place and alias the list into the returned copy.
+    dtype_getters = list(kernel.function_result_dtype_getters)
+
+    def gather_from_reduction(expr, rec):
+        red_getter = expr.operation.get_function_result_dtype_getter()
+        if red_getter is not None:
+            dtype_getters.append(red_getter)
+
+        rec(expr.expr)
+
+    from loopy.symbolic import ReductionCallbackMapper
+    rcm = ReductionCallbackMapper(gather_from_reduction)
+    for insn in kernel.instructions:
+        rcm(insn.expression)
+
+    return kernel.copy(function_result_dtype_getters=dtype_getters)
+
+# }}}
+
 # {{{ infer types of temporaries
 
 def infer_types_of_temporaries(kernel):
@@ -762,6 +789,8 @@ def adjust_local_temp_var_storage(kernel):
 
 
 def preprocess_kernel(kernel):
+    kernel = gather_dtype_getters(kernel)
+
     # all type inference must happen *after* this point (because only then all
     # the functions return dtype getters are available.)
 
diff --git a/test/test_loopy.py b/test/test_loopy.py
index 717ec31f197f1f73ca6b51841b222c42cd9cf887..b799e88036a4e10e9af89a330777b11448f8737e 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -153,6 +153,7 @@ def test_eq_constraint(ctx_factory):
 def test_argmax(ctx_factory):
     dtype = np.dtype(np.float32)
     ctx = ctx_factory()
+    queue = cl.CommandQueue(ctx)
     order = "C"
 
     n = 10000
@@ -170,13 +171,12 @@ def test_argmax(ctx_factory):
                 lp.GlobalArg("max_val", dtype, shape=(), order=order),
                 ])
 
-    seq_knl = knl
+    a = np.random.randn(10000).astype(dtype)
+    cknl = lp.CompiledKernel(ctx, knl)
+    evt, (max_idx, max_val) = cknl(queue, a=a)
+    assert max_val == np.max(np.abs(a))
+    assert max_idx == np.where(np.abs(a)==max_val)[-1]
 
-    kernel_gen = lp.generate_loop_schedules(knl)
-    kernel_gen = lp.check_kernels(kernel_gen, {})
-
-    lp.auto_test_vs_ref(seq_knl, ctx, kernel_gen,
-            codegen_kwargs=dict(allow_complex=True))
 
 
 