diff --git a/loopy/kernel.py b/loopy/kernel.py index a1cfe80596513e9666f32d5f38246711fccd3c92..1e74ac2e7cbba91052156c128bcbb6b6a5fb2760 100644 --- a/loopy/kernel.py +++ b/loopy/kernel.py @@ -392,49 +392,159 @@ class Instruction(Record): # {{{ reduction operations class ReductionOperation(object): - """ - :ivar neutral_element: - :ivar dtype: - """ + def dtype(self, inames): + raise NotImplementedError + + def get_preambles(self, inames, c_code_mapper): + return [] + + def neutral_element(self, inames): + raise NotImplementedError def __call__(self, operand1, operand2): raise NotImplementedError -class TypedReductionOperation(ReductionOperation): +class ScalarReductionOperation(ReductionOperation): def __init__(self, dtype): - self.dtype = dtype + self.scalar_dtype = dtype + + def dtype(self, inames): + return self.scalar_dtype def __str__(self): return (type(self).__name__.replace("ReductionOperation", "").lower() + "_" + str(self.dtype)) -class SumReductionOperation(TypedReductionOperation): - neutral_element = 0 +class SumReductionOperation(ScalarReductionOperation): + def neutral_element(self, inames): + return 0 - def __call__(self, operand1, operand2): + def __call__(self, operand1, operand2, index): return operand1 + operand2 -class ProductReductionOperation(TypedReductionOperation): - neutral_element = 1 +class ProductReductionOperation(ScalarReductionOperation): + def neutral_element(self, inames): + return 1 - def __call__(self, operand1, operand2): + def __call__(self, operand1, operand2, index): return operand1 * operand2 -class FloatingPointMaxOperation(TypedReductionOperation): - # OpenCL 1.1, section 6.11.2 - neutral_element = -var("INFINITY") +def get_le_neutral(dtype): + """Return a number y that satisfies (x <= y) for all y.""" - def __call__(self, operand1, operand2): - from pymbolic.primitives import FunctionSymbol - return FunctionSymbol("max")(operand1, operand2) + if dtype.kind == "f": + # OpenCL 1.1, section 6.11.2 + return var("INFINITY") + else: + raise NotImplementedError("less") + +class MaxReductionOperation(ScalarReductionOperation): + def neutral_element(self, inames): + return get_le_neutral(self.dtype) + + def __call__(self, operand1, operand2, index): + return var("max")(operand1, operand2) + +class MinReductionOperation(ScalarReductionOperation): + @property + def neutral_element(self, inames): + return -get_le_neutral(self.dtype) + + def __call__(self, operand1, operand2, index): + return var("min")(operand1, operand2) + + + + +class _ArgExtremumReductionOperation(ReductionOperation): + def __init__(self, dtype): + self.scalar_dtype = dtype + + self.struct_dtype = np.dtype( + [("value", self.scalar_dtype), + ("index", np.int32)]) + + self.prefix = "loopy_arg%s_%s" % (self.which, self.scalar_dtype.type.__name__) + self.type_name = self.prefix + "_result" + from pyopencl.tools import register_dtype + register_dtype(self.struct_dtype, self.type_name, alias_ok=True) + + def dtype(self, inames): + return self.struct_dtype + + # No need to make type inference aware of our functions: Their results + # always get assigned directly to typed temporaries without any arithmetic. + + def get_preambles(self, inames): + """Returns a tuple (preamble_key, preamble), where *preamble* is a string + that goes into the kernel preamble, and *preamble_key* is a unique + identifier. No two preambles with the same key will be emitted. + """ + + if len(inames) != 1: + raise RuntimeError("arg%s must be used with exactly one iname" + % self.which) + + from pyopencl.tools import dtype_to_ctype + from pymbolic.mapper.c_code import CCodeMapper + + c_code_mapper = CCodeMapper() + + return [(self.prefix, """ + typedef struct { + %(scalar_type)s value; + int index; + } %(type_name)s; + + inline %(type_name)s %(prefix)s_init() + { + %(type_name)s result; + result.value = %(neutral)s; + result.index = INT_MIN; + return result; + } + + inline %(type_name)s %(prefix)s_update( + %(type_name)s state, %(scalar_type)s op2, int index) + { + %(type_name)s result; + if (op2 %(comp)s state.value) + { + result.value = op2; + result.index = index; + return result; + } + else return state; + } + """ % dict( + type_name=self.type_name, + scalar_type=dtype_to_ctype(self.scalar_dtype), + prefix=self.prefix, + neutral=c_code_mapper( + self.neutral_sign*get_le_neutral(self.scalar_dtype)), + comp=self.update_comparison, + ))] + + def neutral_element(self, inames): + return var(self.prefix+"_init")() + + def __call__(self, operand1, operand2, inames): + iname, = inames + + return var(self.prefix+"_update")( + operand1, operand2, var(iname)) + +class ArgMaxReductionOperation(_ArgExtremumReductionOperation): + which = "max" + update_comparison = ">=" + neutral_sign = -1 + +class ArgMinReductionOperation(_ArgExtremumReductionOperation): + which = "min" + update_comparison = "<=" + neutral_sign = +1 -class FloatingPointMinOperation(TypedReductionOperation): - # OpenCL 1.1, section 6.11.2 - neutral_element = var("INFINITY") - def __call__(self, operand1, operand2): - from pymbolic.primitives import FunctionSymbol - return FunctionSymbol("min")(operand1, operand2) @@ -442,8 +552,10 @@ class FloatingPointMinOperation(TypedReductionOperation): _REDUCTION_OPS = { "sum": SumReductionOperation, "product": ProductReductionOperation, - "fpmax": FloatingPointMaxOperation, - "fpmin": FloatingPointMinOperation, + "max": MaxReductionOperation, + "min": MinReductionOperation, + "argmax": ArgMaxReductionOperation, + "argmin": ArgMinReductionOperation, } _REDUCTION_OP_PARSERS = [ diff --git a/loopy/preprocess.py b/loopy/preprocess.py index 0030fba0633ce996b0e075c4b50ddc8bb2fae953..3f00f85d6e7302b14c6781c1b67151b21c44e008 100644 --- a/loopy/preprocess.py +++ b/loopy/preprocess.py @@ -258,6 +258,7 @@ def realize_reduction(kernel, insn_id_filter=None): new_insns = [] new_temporary_variables = kernel.temporary_variables.copy() + new_preambles = kernel.preambles from loopy.kernel import IlpBaseTag @@ -287,6 +288,8 @@ def realize_reduction(kernel, insn_id_filter=None): # }}} + new_preambles.extend(expr.operation.get_preambles(expr.inames)) + from pymbolic import var target_var_name = kernel.make_unique_var_name("acc_"+"_".join(expr.inames), @@ -302,7 +305,7 @@ def realize_reduction(kernel, insn_id_filter=None): from loopy.kernel import TemporaryVariable new_temporary_variables[target_var_name] = TemporaryVariable( name=target_var_name, - dtype=expr.operation.dtype, + dtype=expr.operation.dtype(expr.inames), shape=tuple(ilp_iname_lengths), is_local=False) @@ -314,7 +317,7 @@ def realize_reduction(kernel, insn_id_filter=None): id=new_id, assignee=target_var, forced_iname_deps=temp_kernel.insn_inames(insn) - set(expr.inames), - expression=expr.operation.neutral_element) + expression=expr.operation.neutral_element(expr.inames)) generated_insns.append(init_insn) @@ -325,7 +328,7 @@ def realize_reduction(kernel, insn_id_filter=None): reduction_insn = Instruction( id=new_id, assignee=target_var, - expression=expr.operation(target_var, expr.expr), + expression=expr.operation(target_var, expr.expr, expr.inames), insn_deps=set([init_insn.id]) | insn.insn_deps, forced_iname_deps=temp_kernel.insn_inames(insn) | set(expr.inames)) diff --git a/test/test_loopy.py b/test/test_loopy.py index d412b48617a4bc36eda4b4fe31d15635730eccfc..717ec31f197f1f73ca6b51841b222c42cd9cf887 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -150,6 +150,37 @@ def test_eq_constraint(ctx_factory): +def test_argmax(ctx_factory): + dtype = np.dtype(np.float32) + ctx = ctx_factory() + order = "C" + + n = 10000 + + knl = lp.make_kernel(ctx.devices[0], + "{[i]: 0<=i<%d}" % n, + [ + "<> result = argmax_float32(i, fabs(a[i]))", + "max_idx = result.index", + "max_val = result.value", + ], + [ + lp.GlobalArg("a", dtype, shape=(n,), order=order), + lp.GlobalArg("max_idx", np.int32, shape=(), order=order), + lp.GlobalArg("max_val", dtype, shape=(), order=order), + ]) + + seq_knl = knl + + kernel_gen = lp.generate_loop_schedules(knl) + kernel_gen = lp.check_kernels(kernel_gen, {}) + + lp.auto_test_vs_ref(seq_knl, ctx, kernel_gen, + codegen_kwargs=dict(allow_complex=True)) + + + + if __name__ == "__main__": import sys if len(sys.argv) > 1: