diff --git a/loopy/library/reduction.py b/loopy/library/reduction.py index f435820b23e8da909f0cff14ff5a1272874e865f..c61c052c09d39f0181144200d3fd8ecf3ebef877 100644 --- a/loopy/library/reduction.py +++ b/loopy/library/reduction.py @@ -135,13 +135,37 @@ def get_le_neutral(dtype): if dtype.numpy_dtype.kind == "f": # OpenCL 1.1, section 6.11.2 return var("INFINITY") + elif dtype.numpy_dtype.kind == "i": + if dtype.numpy_dtype.itemsize == 4: + #32 bit integer + return var("INT_MAX") + elif dtype.numpy_dtype.itemsize == 8: + #64 bit integer + return var('LONG_MAX') else: raise NotImplementedError("less") +def get_ge_neutral(dtype): + """Return a number y that satisfies (x >= y) for all y.""" + + if dtype.numpy_dtype.kind == "f": + # OpenCL 1.1, section 6.11.2 + return -var("INFINITY") + elif dtype.numpy_dtype.kind == "i": + if dtype.numpy_dtype.itemsize == 4: + #32 bit integer + return var("INT_MIN") + elif dtype.numpy_dtype.itemsize == 8: + #64 bit integer + return var('LONG_MIN') + else: + raise NotImplementedError("less") + + class MaxReductionOperation(ScalarReductionOperation): def neutral_element(self, dtype, inames): - return -get_le_neutral(dtype) + return get_ge_neutral(dtype) def __call__(self, dtype, operand1, operand2, inames): return var("max")(operand1, operand2) @@ -213,6 +237,8 @@ def get_argext_preamble(kernel, func_id): c_code_mapper = CCodeMapper() + neutral = get_ge_neutral if op.neutral_sign < 0 else get_le_neutral + return (prefix, """ inline %(scalar_t)s %(prefix)s_init(%(index_t)s *index_out) { diff --git a/loopy/target/opencl.py b/loopy/target/opencl.py index 31cf7c6b648ebf370a17d8beb2538b9748ddb30a..01e56405e30285705be7cb8eb6d75479c8658ef5 100644 --- a/loopy/target/opencl.py +++ b/loopy/target/opencl.py @@ -249,6 +249,10 @@ def opencl_symbol_mangler(kernel, name): return NumpyType(np.dtype(np.float64)), name elif name == "INFINITY": return NumpyType(np.dtype(np.float32)), name + elif name.startswith("INT_"): + return NumpyType(np.dtype(np.int32)), name + elif name.startswith("LONG_"): + return NumpyType(np.dtype(np.int64)), name else: return None