diff --git a/loopy/transform/subst.py b/loopy/transform/subst.py index b92698ffa1e84455be3f79bed7dbf884f36be490..3eee3d8f3093ce68670ab2c119f41bc385afde01 100644 --- a/loopy/transform/subst.py +++ b/loopy/transform/subst.py @@ -25,10 +25,8 @@ THE SOFTWARE. import six from loopy.symbolic import ( - get_dependencies, SubstitutionMapper, RuleAwareIdentityMapper, SubstitutionRuleMappingContext) from loopy.diagnostic import LoopyError -from pymbolic.mapper.substitutor import make_subst_func from pytools import ImmutableRecord from pymbolic import var @@ -80,40 +78,13 @@ def extract_subst(kernel, subst_name, template, parameters=()): # }}} - # {{{ deal with iname deps of template that are not independent_inames - - # (We call these 'matching_vars', because they have to match exactly in - # every CSE. As above, they might need to be renamed to make them unique - # within the kernel.) - - matching_vars = [] - old_to_new = {} - - for iname in (get_dependencies(template) - - set(parameters) - - kernel.non_iname_variable_names()): - if iname in kernel.all_inames(): - # need to rename to be unique - new_iname = var_name_gen(iname) - old_to_new[iname] = var(new_iname) - matching_vars.append(new_iname) - else: - matching_vars.append(iname) - - if old_to_new: - template = ( - SubstitutionMapper(make_subst_func(old_to_new)) - (template)) - - # }}} - # {{{ gather up expressions expr_descriptors = [] from loopy.symbolic import UnidirectionalUnifier unif = UnidirectionalUnifier( - lhs_mapping_candidates=set(parameters) | set(matching_vars)) + lhs_mapping_candidates=set(parameters)) def gather_exprs(expr, mapper): urecs = unif(template, expr) diff --git a/test/test_transform.py b/test/test_transform.py index cdc0c14b8bacc4fe5279d000461c0ea2244af021..6eb6697b5c192911864000781381244dfcbef631 100644 --- a/test/test_transform.py +++ b/test/test_transform.py @@ -570,6 +570,21 @@ def test_nested_substs_in_insns(ctx_factory): lp.auto_test_vs_ref(ref_knl, ctx, knl) +def test_extract_subst_with_iname_deps_in_templ(ctx_factory): + knl = lp.make_kernel( + "{[i, j, k]: 0<=i<100 and 0<=j,k<5}", + """ + y[i, j, k] = x[i, j, k] + """, + [lp.GlobalArg('x,y', shape=lp.auto, dtype=float)], + lang_version=(2018, 2)) + + knl = lp.extract_subst(knl, 'rule1', 'x[i, arg1, arg2]', + parameters=('arg1', 'arg2')) + + lp.auto_test_vs_ref(knl, ctx_factory(), knl) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])