From bdd6159ee01abeaa702b922f030f9e5aa509e112 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sat, 4 Jan 2020 21:34:29 -0600 Subject: [PATCH 1/2] extract_subst: removes unnecessary iname renaming --- loopy/transform/subst.py | 31 +------------------------------ 1 file changed, 1 insertion(+), 30 deletions(-) diff --git a/loopy/transform/subst.py b/loopy/transform/subst.py index b92698ffa..3eee3d8f3 100644 --- a/loopy/transform/subst.py +++ b/loopy/transform/subst.py @@ -25,10 +25,8 @@ THE SOFTWARE. import six from loopy.symbolic import ( - get_dependencies, SubstitutionMapper, RuleAwareIdentityMapper, SubstitutionRuleMappingContext) from loopy.diagnostic import LoopyError -from pymbolic.mapper.substitutor import make_subst_func from pytools import ImmutableRecord from pymbolic import var @@ -80,40 +78,13 @@ def extract_subst(kernel, subst_name, template, parameters=()): # }}} - # {{{ deal with iname deps of template that are not independent_inames - - # (We call these 'matching_vars', because they have to match exactly in - # every CSE. As above, they might need to be renamed to make them unique - # within the kernel.) - - matching_vars = [] - old_to_new = {} - - for iname in (get_dependencies(template) - - set(parameters) - - kernel.non_iname_variable_names()): - if iname in kernel.all_inames(): - # need to rename to be unique - new_iname = var_name_gen(iname) - old_to_new[iname] = var(new_iname) - matching_vars.append(new_iname) - else: - matching_vars.append(iname) - - if old_to_new: - template = ( - SubstitutionMapper(make_subst_func(old_to_new)) - (template)) - - # }}} - # {{{ gather up expressions expr_descriptors = [] from loopy.symbolic import UnidirectionalUnifier unif = UnidirectionalUnifier( - lhs_mapping_candidates=set(parameters) | set(matching_vars)) + lhs_mapping_candidates=set(parameters)) def gather_exprs(expr, mapper): urecs = unif(template, expr) -- GitLab From 239c3f86916273a70c193bd25dbe814aec6fde03 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sat, 4 Jan 2020 21:38:25 -0600 Subject: [PATCH 2/2] another test for extract_subst --- test/test_transform.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/test/test_transform.py b/test/test_transform.py index cdc0c14b8..6eb6697b5 100644 --- a/test/test_transform.py +++ b/test/test_transform.py @@ -570,6 +570,21 @@ def test_nested_substs_in_insns(ctx_factory): lp.auto_test_vs_ref(ref_knl, ctx, knl) +def test_extract_subst_with_iname_deps_in_templ(ctx_factory): + knl = lp.make_kernel( + "{[i, j, k]: 0<=i<100 and 0<=j,k<5}", + """ + y[i, j, k] = x[i, j, k] + """, + [lp.GlobalArg('x,y', shape=lp.auto, dtype=float)], + lang_version=(2018, 2)) + + knl = lp.extract_subst(knl, 'rule1', 'x[i, arg1, arg2]', + parameters=('arg1', 'arg2')) + + lp.auto_test_vs_ref(knl, ctx_factory(), knl) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) -- GitLab