diff --git a/loopy/subst.py b/loopy/subst.py index b5b2145edadb72b648860ad7c1231088d14c7627..db94bb652b6862f5f35d38d4f1fc0ce29c8b395d 100644 --- a/loopy/subst.py +++ b/loopy/subst.py @@ -43,10 +43,12 @@ class ExprDescriptor(Record): __slots__ = ["insn", "expr", "unif_var_dict"] -def extract_subst(kernel, subst_name, template, parameters): +def extract_subst(kernel, subst_name, template, parameters=()): """ :arg subst_name: The name of the substitution rule to be created. :arg template: Unification template expression. + :arg parameters: An iterable of parameters used in + *template*, or a comma-separated string of the same. All targeted subexpressions must match ('unify with') *template* The template may contain '*' wildcards that will have to match exactly across all @@ -57,6 +59,10 @@ def extract_subst(kernel, subst_name, template, parameters): from pymbolic import parse template = parse(template) + if isinstance(parameters, str): + parameters = tuple( + s.strip() for s in parameters.split()) + var_name_gen = kernel.get_var_name_generator() # {{{ replace any wildcards in template with new variables diff --git a/test/test_loopy.py b/test/test_loopy.py index 17e28af6a05ef94abb5ba061271d9f7f5d4ec617..564291316f2f3ef8c43ded75219327d5098afc61 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -84,6 +84,23 @@ def test_complicated_subst(ctx_factory): assert substs_with_letter == how_many +def test_extract_subst(ctx_factory): + knl = lp.make_kernel( + "{[i]: 0<=i