From c6898ffa48da9ef24acdc65570e44c9aa95de707 Mon Sep 17 00:00:00 2001 From: Matt Wala <wala1@illinois.edu> Date: Sun, 9 Jul 2017 20:24:42 -0500 Subject: [PATCH] Fix argmin and segmented reductions. --- loopy/library/reduction.py | 213 +++++++++++++------------------------ test/test_loopy.py | 41 +++++++ 2 files changed, 115 insertions(+), 139 deletions(-) diff --git a/loopy/library/reduction.py b/loopy/library/reduction.py index f9648bde7..962b31681 100644 --- a/loopy/library/reduction.py +++ b/loopy/library/reduction.py @@ -123,7 +123,7 @@ class ScalarReductionOperation(ReductionOperation): class SumReductionOperation(ScalarReductionOperation): def neutral_element(self, dtype): - return 0 + return dtype.numpy_dtype.type(0) def __call__(self, dtype, operand1, operand2): return operand1 + operand2 @@ -131,7 +131,7 @@ class SumReductionOperation(ScalarReductionOperation): class ProductReductionOperation(ScalarReductionOperation): def neutral_element(self, dtype): - return 1 + return dtype.numpy_dtype.type(1) def __call__(self, dtype, operand1, operand2): return operand1 * operand2 @@ -189,8 +189,26 @@ class MinReductionOperation(ScalarReductionOperation): return var("min")(operand1, operand2) +# {{{ base class for symbolic reduction ops + +class ReductionOpFunction(FunctionIdentifier): + init_arg_names = ("reduction_op",) + + def __init__(self, reduction_op): + self.reduction_op = reduction_op + + def __getinitargs__(self): + return (self.reduction_op,) + +# }}} + + # {{{ segmented reduction +class SegmentedOp(ReductionOpFunction): + pass + + class _SegmentedScalarReductionOperation(ReductionOperation): def __init__(self, **kwargs): self.inner_reduction = self.base_reduction_class(**kwargs) @@ -205,7 +223,9 @@ class _SegmentedScalarReductionOperation(ReductionOperation): segment_flag_dtype.numpy_dtype.type.__name__) def neutral_element(self, scalar_dtype, segment_flag_dtype): - return SegmentedFunction(self, (scalar_dtype, segment_flag_dtype), "init")() + scalar_neutral_element = self.inner_reduction.neutral_element(scalar_dtype) + return var("make_tuple")(scalar_neutral_element, + segment_flag_dtype.numpy_dtype.type(0)) def result_dtypes(self, kernel, scalar_dtype, segment_flag_dtype): return (self.inner_reduction.result_dtypes(kernel, scalar_dtype) @@ -221,7 +241,7 @@ class _SegmentedScalarReductionOperation(ReductionOperation): return type(self) == type(other) def __call__(self, dtypes, operand1, operand2): - return SegmentedFunction(self, dtypes, "update")(*(operand1 + operand2)) + return SegmentedOp(self)(*(operand1 + operand2)) class SegmentedSumReductionOperation(_SegmentedScalarReductionOperation): @@ -236,45 +256,13 @@ class SegmentedProductReductionOperation(_SegmentedScalarReductionOperation): which = "product" -class SegmentedFunction(FunctionIdentifier): - init_arg_names = ("reduction_op", "dtypes", "name") - - def __init__(self, reduction_op, dtypes, name): - """ - :arg dtypes: A :class:`tuple` of `(scalar_dtype, segment_flag_dtype)` - """ - self.reduction_op = reduction_op - self.dtypes = dtypes - self.name = name - - @property - def scalar_dtype(self): - return self.dtypes[0] - - @property - def segment_flag_dtype(self): - return self.dtypes[1] - - def __getinitargs__(self): - return (self.reduction_op, self.dtypes, self.name) - - -def get_segmented_function_preamble(kernel, func_id): +def get_segmented_function_preamble(kernel, func_id, arg_dtypes): op = func_id.reduction_op - prefix = op.prefix(func_id.scalar_dtype, func_id.segment_flag_dtype) - - from pymbolic.mapper.c_code import CCodeMapper - - c_code_mapper = CCodeMapper() + scalar_dtype, segment_flag_dtype = arg_dtypes + prefix = op.prefix(scalar_dtype, segment_flag_dtype) return (prefix, """ - inline %(scalar_t)s %(prefix)s_init(%(segment_flag_t)s *segment_flag_out) - { - *segment_flag_out = 0; - return %(neutral)s; - } - - inline %(scalar_t)s %(prefix)s_update( + inline %(scalar_t)s %(prefix)s_op( %(scalar_t)s op1, %(segment_flag_t)s segment_flag1, %(scalar_t)s op2, %(segment_flag_t)s segment_flag2, %(segment_flag_t)s *segment_flag_out) @@ -283,32 +271,36 @@ def get_segmented_function_preamble(kernel, func_id): return segment_flag2 ? op2 : %(combined)s; } """ % dict( - scalar_t=kernel.target.dtype_to_typename(func_id.scalar_dtype), + scalar_t=kernel.target.dtype_to_typename(scalar_dtype), prefix=prefix, - segment_flag_t=kernel.target.dtype_to_typename( - func_id.segment_flag_dtype), - neutral=c_code_mapper( - op.inner_reduction.neutral_element(func_id.scalar_dtype)), + segment_flag_t=kernel.target.dtype_to_typename(segment_flag_dtype), combined=op.op % ("op1", "op2"), )) - # }}} # {{{ argmin/argmax +class ArgExtOp(ReductionOpFunction): + pass + + class _ArgExtremumReductionOperation(ReductionOperation): def prefix(self, scalar_dtype, index_dtype): return "loopy_arg%s_%s_%s" % (self.which, - index_dtype.numpy_dtype.type.__name__, - scalar_dtype.numpy_dtype.type.__name__) + scalar_dtype.numpy_dtype.type.__name__, + index_dtype.numpy_dtype.type.__name__) def result_dtypes(self, kernel, scalar_dtype, index_dtype): return (scalar_dtype, index_dtype) def neutral_element(self, scalar_dtype, index_dtype): - return ArgExtFunction(self, (scalar_dtype, index_dtype), "init")() + scalar_neutral_func = ( + get_ge_neutral if self.neutral_sign < 0 else get_le_neutral) + scalar_neutral_element = scalar_neutral_func(scalar_dtype) + return var("make_tuple")(scalar_neutral_element, + index_dtype.numpy_dtype.type(-1)) def __str__(self): return self.which @@ -324,7 +316,7 @@ class _ArgExtremumReductionOperation(ReductionOperation): return 2 def __call__(self, dtypes, operand1, operand2): - return ArgExtFunction(self, dtypes, "update")(*(operand1 + operand2)) + return ArgExtOp(self)(*(operand1 + operand2)) class ArgMaxReductionOperation(_ArgExtremumReductionOperation): @@ -339,44 +331,15 @@ class ArgMinReductionOperation(_ArgExtremumReductionOperation): neutral_sign = +1 -class ArgExtFunction(FunctionIdentifier): - init_arg_names = ("reduction_op", "dtypes", "name") - - def __init__(self, reduction_op, dtypes, name): - self.reduction_op = reduction_op - self.dtypes = dtypes - self.name = name - - @property - def scalar_dtype(self): - return self.dtypes[0] - - @property - def index_dtype(self): - return self.dtypes[1] - - def __getinitargs__(self): - return (self.reduction_op, self.dtypes, self.name) - - -def get_argext_preamble(kernel, func_id): +def get_argext_preamble(kernel, func_id, arg_dtypes): op = func_id.reduction_op - prefix = op.prefix(func_id.scalar_dtype, func_id.index_dtype) - - from pymbolic.mapper.c_code import CCodeMapper + scalar_dtype = arg_dtypes[0] + index_dtype = arg_dtypes[1] - c_code_mapper = CCodeMapper() - - neutral = get_ge_neutral if op.neutral_sign < 0 else get_le_neutral + prefix = op.prefix(scalar_dtype, index_dtype) return (prefix, """ - inline %(scalar_t)s %(prefix)s_init(%(index_t)s *index_out) - { - *index_out = INT_MIN; - return %(neutral)s; - } - - inline %(scalar_t)s %(prefix)s_update( + inline %(scalar_t)s %(prefix)s_op( %(scalar_t)s op1, %(index_t)s index1, %(scalar_t)s op2, %(index_t)s index2, %(index_t)s *index_out) @@ -393,10 +356,9 @@ def get_argext_preamble(kernel, func_id): } } """ % dict( - scalar_t=kernel.target.dtype_to_typename(func_id.scalar_dtype), + scalar_t=kernel.target.dtype_to_typename(scalar_dtype), prefix=prefix, - index_t=kernel.target.dtype_to_typename(func_id.index_dtype), - neutral=c_code_mapper(neutral(func_id.scalar_dtype)), + index_t=kernel.target.dtype_to_typename(index_dtype), comp=op.update_comparison, )) @@ -454,76 +416,47 @@ def parse_reduction_op(name): def reduction_function_mangler(kernel, func_id, arg_dtypes): - if isinstance(func_id, ArgExtFunction) and func_id.name == "init": + if isinstance(func_id, ArgExtOp): from loopy.target.opencl import CTarget if not isinstance(kernel.target, CTarget): raise LoopyError("%s: only C-like targets supported for now" % func_id) op = func_id.reduction_op + scalar_dtype = arg_dtypes[0] + index_dtype = arg_dtypes[1] from loopy.kernel.data import CallMangleInfo return CallMangleInfo( - target_name="%s_init" % op.prefix( - func_id.scalar_dtype, func_id.index_dtype), + target_name="%s_op" % op.prefix( + scalar_dtype, index_dtype), result_dtypes=op.result_dtypes( - kernel, func_id.scalar_dtype, func_id.index_dtype), - arg_dtypes=(), - ) - - elif isinstance(func_id, ArgExtFunction) and func_id.name == "update": - from loopy.target.opencl import CTarget - if not isinstance(kernel.target, CTarget): - raise LoopyError("%s: only C-like targets supported for now" % func_id) - - op = func_id.reduction_op - - from loopy.kernel.data import CallMangleInfo - return CallMangleInfo( - target_name="%s_update" % op.prefix( - func_id.scalar_dtype, func_id.index_dtype), - result_dtypes=op.result_dtypes( - kernel, func_id.scalar_dtype, func_id.index_dtype), + kernel, scalar_dtype, index_dtype), arg_dtypes=( - func_id.scalar_dtype, - kernel.index_dtype, - func_id.scalar_dtype, - kernel.index_dtype), - ) - - elif isinstance(func_id, SegmentedFunction) and func_id.name == "init": - from loopy.target.opencl import CTarget - if not isinstance(kernel.target, CTarget): - raise LoopyError("%s: only C-like targets supported for now" % func_id) - - op = func_id.reduction_op - - from loopy.kernel.data import CallMangleInfo - return CallMangleInfo( - target_name="%s_init" % op.prefix( - func_id.scalar_dtype, func_id.segment_flag_dtype), - result_dtypes=op.result_dtypes( - kernel, func_id.scalar_dtype, func_id.segment_flag_dtype), - arg_dtypes=(), + scalar_dtype, + index_dtype, + scalar_dtype, + index_dtype), ) - elif isinstance(func_id, SegmentedFunction) and func_id.name == "update": + elif isinstance(func_id, SegmentedOp): from loopy.target.opencl import CTarget if not isinstance(kernel.target, CTarget): raise LoopyError("%s: only C-like targets supported for now" % func_id) op = func_id.reduction_op + scalar_dtype, segment_flag_dtype = arg_dtypes from loopy.kernel.data import CallMangleInfo return CallMangleInfo( - target_name="%s_update" % op.prefix( - func_id.scalar_dtype, func_id.segment_flag_dtype), + target_name="%s_op" % op.prefix( + scalar_dtype, segment_flag_dtype), result_dtypes=op.result_dtypes( - kernel, func_id.scalar_dtype, func_id.segment_flag_dtype), + kernel, scalar_dtype, segment_flag_dtype), arg_dtypes=( - func_id.scalar_dtype, - func_id.segment_flag_dtype, - func_id.scalar_dtype, - func_id.segment_flag_dtype), + scalar_dtype, + segment_flag_dtype, + scalar_dtype, + segment_flag_dtype), ) return None @@ -533,16 +466,18 @@ def reduction_preamble_generator(preamble_info): from loopy.target.opencl import OpenCLTarget for func in preamble_info.seen_functions: - if isinstance(func.name, ArgExtFunction): + if isinstance(func.name, ArgExtOp): if not isinstance(preamble_info.kernel.target, OpenCLTarget): raise LoopyError("only OpenCL supported for now") - yield get_argext_preamble(preamble_info.kernel, func.name) + yield get_argext_preamble(preamble_info.kernel, func.name, + func.arg_dtypes) - elif isinstance(func.name, SegmentedFunction): + elif isinstance(func.name, SegmentedOp): if not isinstance(preamble_info.kernel.target, OpenCLTarget): raise LoopyError("only OpenCL supported for now") - yield get_segmented_function_preamble(preamble_info.kernel, func.name) + yield get_segmented_function_preamble(preamble_info.kernel, func.name, + func.arg_dtypes) # vim: fdm=marker diff --git a/test/test_loopy.py b/test/test_loopy.py index 21db62610..ad5fd72b6 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -2335,6 +2335,47 @@ def test_kernel_var_name_generator(): assert vng("b") != "b" +def test_complex_argmin(ctx_factory): + cl_ctx = ctx_factory() + knl = lp.make_kernel( + "{[ictr,itgt,idim]: " + "0<=itgt<ntargets " + "and 0<=ictr<ncenters " + "and 0<=idim<ambient_dim}", + + """ + for itgt + for ictr + <> dist_sq = sum(idim, + (tgt[idim,itgt] - center[idim,ictr])**2) + <> in_disk = dist_sq < (radius[ictr]*1.05)**2 + <> matches = ( + (in_disk + and qbx_forced_limit == 0) + or (in_disk + and qbx_forced_limit != 0 + and qbx_forced_limit * center_side[ictr] > 0) + ) + + <> post_dist_sq = if(matches, dist_sq, HUGE) + end + <> min_dist_sq, <> min_ictr = argmin(ictr, ictr, post_dist_sq) + + tgt_to_qbx_center[itgt] = if(min_dist_sq < HUGE, min_ictr, -1) + end + """) + + knl = lp.fix_parameters(knl, ambient_dim=2) + knl = lp.add_and_infer_dtypes(knl, { + "tgt,center,radius,HUGE": np.float32, + "center_side,qbx_forced_limit": np.int32, + }) + + lp.auto_test_vs_ref(knl, cl_ctx, knl, parameters={ + "HUGE": 1e20, "ncenters": 200, "ntargets": 300, + "qbx_forced_limit": 1}) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) -- GitLab