diff --git a/test/test_loopy.py b/test/test_loopy.py index d5d1a1f31ba5ad9ecaeedeb92b1188d5208e37c6..b92161ac7bb7145ad6bffd8615bff0cd84eabdcc 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -1987,19 +1987,28 @@ def test_integer_reduction(ctx_factory): dtype=to_loopy_type(vtype), shape=lp.auto) - reductions = [('max', lambda x: x == np.max(var_int)), - ('min', lambda x: x == np.min(var_int)), - ('sum', lambda x: x == np.sum(var_int)), - ('product', lambda x: x == np.prod(var_int)), - ('argmax', lambda x: (x[0] == np.max(var_int) and - var_int[out[1]] == np.max(var_int))), - ('argmin', lambda x: (x[0] == np.min(var_int) and - var_int[out[1]] == np.min(var_int)))] - - for reduction, function in reductions: + from collections import namedtuple + ReductionTest = namedtuple('ReductionTest', 'kind, check, args') + + reductions = [ + ReductionTest('max', lambda x: x == np.max(var_int), args='var[k]'), + ReductionTest('min', lambda x: x == np.min(var_int), args='var[k]'), + ReductionTest('sum', lambda x: x == np.sum(var_int), args='var[k]'), + ReductionTest('product', lambda x: x == np.prod(var_int), args='var[k]'), + ReductionTest('argmax', + lambda x: ( + x[0] == np.max(var_int) and var_int[out[1]] == np.max(var_int)), + args='var[k], k'), + ReductionTest('argmin', + lambda x: ( + x[0] == np.min(var_int) and var_int[out[1]] == np.min(var_int)), + args='var[k], k') + ] + + for reduction, function, args in reductions: kstr = ("out" if 'arg' not in reduction else "out[0], out[1]") - kstr += ' = {0}(k, var[k])'.format(reduction) + kstr += ' = {0}(k, {1})'.format(reduction, args) knl = lp.make_kernel('{[k]: 0<=k<n}', kstr, [var_lp, '...'])