diff --git a/loopy/symbolic.py b/loopy/symbolic.py index 8897f35938ee29e50d70edfdc52945bc2ad856ac..bad7f840f648bc72d7f505fb72a200f3df2fb0d6 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -305,7 +305,27 @@ class Reduction(AlgebraicLeaf): init_arg_names = ("operation", "inames", "expr") def __init__(self, operation, inames, expr): + if isinstance(inames, str): + inames = tuple(iname.strip() for iname in inames.split(",")) + + elif isinstance(inames, Variable): + inames = (inames,) + assert isinstance(inames, tuple) + + def strip_var(iname): + if isinstance(iname, Variable): + iname = iname.name + + assert isinstance(iname, str) + return iname + + inames = tuple(strip_var(iname) for iname in inames) + + if isinstance(operation, str): + from loopy.library.reduction import parse_reduction_op + operation = parse_reduction_op(operation) + from loopy.library.reduction import ReductionOperation assert isinstance(operation, ReductionOperation)