diff --git a/loopy/symbolic.py b/loopy/symbolic.py index 1f15e3fc065fc7a295e2dc56df4d7eb685207268..24ab2e93442ac3b03bd59ba32fcf5e4f6041c407 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -24,7 +24,7 @@ from islpy import dim_type # {{{ loopy-specific primitives class Reduction(AlgebraicLeaf): - def __init__(self, operation, inames, expr, tag=None): + def __init__(self, operation, inames, expr): assert isinstance(inames, tuple) if isinstance(operation, str): @@ -34,21 +34,19 @@ class Reduction(AlgebraicLeaf): self.operation = operation self.inames = inames self.expr = expr - self.tag = tag def __getinitargs__(self): - return (self.operation, self.inames, self.expr, self.tag) + return (self.operation, self.inames, self.expr) def get_hash(self): return hash((self.__class__, self.operation, self.inames, - self.expr, self.tag)) + self.expr)) def is_equal(self, other): return (other.__class__ == self.__class__ and other.operation == self.operation and other.inames == self.inames - and other.expr == self.expr - and other.tag == self.tag) + and other.expr == self.expr) def stringifier(self): return StringifyMapper @@ -61,8 +59,7 @@ class Reduction(AlgebraicLeaf): class IdentityMapperMixin(object): def map_reduction(self, expr): - return Reduction(expr.operation, expr.inames, - self.rec(expr.expr), expr.tag) + return Reduction(expr.operation, expr.inames, self.rec(expr.expr)) class IdentityMapper(IdentityMapperBase, IdentityMapperMixin): pass @@ -104,14 +101,20 @@ class FunctionToPrimitiveMapper(IdentityMapper): else: raise TypeError("cse takes two arguments") - elif isinstance(expr.function, Variable) and expr.function.name == "reduce": - if len(expr.parameters) == 3: - operation, inames, red_expr = expr.parameters - tag = None - elif len(expr.parameters) == 4: - operation, inames, red_expr, tag = expr.parameters + elif isinstance(expr.function, Variable): + if expr.function.name == "reduce": + if len(expr.parameters) == 3: + operation, inames, red_expr = expr.parameters + else: + raise TypeError("invalid 'reduce' calling sequence") else: - raise TypeError("reduce takes three or four arguments") + from loopy.kernel import parse_reduction_op + if (parse_reduction_op(expr.function.name) + and len(expr.parameters) == 2): + operation = expr.function + inames, red_expr = expr.parameters + else: + return IdentityMapper.map_call(self, expr) red_expr = self.rec(red_expr) @@ -133,12 +136,7 @@ class FunctionToPrimitiveMapper(IdentityMapper): processed_inames.append(iname.name) - if tag is not None: - if not isinstance(tag, Variable): - raise TypeError("tag argument to reduce() must be a symbol") - tag = tag.name - - return Reduction(operation, tuple(processed_inames), red_expr, tag) + return Reduction(operation, tuple(processed_inames), red_expr) else: return IdentityMapper.map_call(self, expr) @@ -158,7 +156,7 @@ class ReductionLoopSplitter(IdentityMapper): new_inames.remove(self.old_iname) new_inames.extend([self.outer_iname, self.inner_iname]) return Reduction(expr.operation, tuple(new_inames), - expr.expr, expr.tag) + expr.expr) else: return IdentityMapper.map_reduction(self, expr) diff --git a/test/test_matmul.py b/test/test_matmul.py index b928b57c820524f4d93e4e68eade2dee930861ab..2ecdfdb2754b22b2f0c1e1ece05b049a151ffb86 100644 --- a/test/test_matmul.py +++ b/test/test_matmul.py @@ -203,7 +203,7 @@ def test_plain_matrix_mul_new_ui(ctx_factory): knl = lp.LoopKernel(ctx.devices[0], "[n] -> {[i,j,k]: 0<=i,j,k<n}", [ - "{yo} c[i, j] = reduce(sum_float32, k, cse(a[i, k], lhsmat)*cse(b[k, j], rhsmat))" + "{label} c[i, j] = sum_float32(k, cse(a[i, k], lhsmat)*cse(b[k, j], rhsmat))" ], [ lp.ArrayArg("a", dtype, shape=(n, n), order=order),