diff --git a/loopy/__init__.py b/loopy/__init__.py index cecd0761c924d2dcf45f36017ffd0fa73c5392fe..3645408568727ae53fe661ac7f5241a9ba89b438 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -141,7 +141,9 @@ class _InameSplitter(ExpandingIdentityMapper): self.replacement_index = replacement_index def map_reduction(self, expr, expn_state): - if self.split_iname in expr.inames and self.within(expn_state.stack): + if (self.split_iname in expr.inames + and self.split_iname not in expn_state.arg_context + and self.within(expn_state.stack)): new_inames = list(expr.inames) new_inames.remove(self.split_iname) new_inames.extend([self.outer_iname, self.inner_iname]) @@ -153,7 +155,9 @@ class _InameSplitter(ExpandingIdentityMapper): return super(_InameSplitter, self).map_reduction(expr, expn_state) def map_variable(self, expr, expn_state): - if expr.name == self.split_iname and self.within(expn_state.stack): + if (expr.name == self.split_iname + and self.split_iname not in expn_state.arg_context + and self.within(expn_state.stack)): return self.replacement_index else: return super(_InameSplitter, self).map_variable(expr, expn_state) @@ -300,7 +304,8 @@ class _InameJoiner(ExpandingSubstitutionMapper): def map_reduction(self, expr, expn_state): expr_inames = set(expr.inames) - overlap = self.join_inames & expr_inames + overlap = (self.join_inames & expr_inames + - set(expn_state.arg_context)) if overlap and self.within(expn_state.stack): if overlap != expr_inames: raise LoopyError( @@ -500,22 +505,27 @@ class _InameDuplicator(ExpandingIdentityMapper): self.within = within def map_reduction(self, expr, expn_state): - if set(expr.inames) & self.old_inames_set and self.within(expn_state.stack): + if (set(expr.inames) & self.old_inames_set + and self.within(expn_state.stack)): new_inames = tuple( self.old_to_new.get(iname, iname) + if iname not in expn_state.arg_context + else iname for iname in expr.inames) from loopy.symbolic import Reduction return Reduction(expr.operation, new_inames, self.rec(expr.expr, expn_state)) else: - return ExpandingIdentityMapper.map_reduction(self, expr, expn_state) + return super(_InameDuplicator, self).map_reduction(expr, expn_state) def map_variable(self, expr, expn_state): new_name = self.old_to_new.get(expr.name) - if new_name is None or not self.within(expn_state.stack): - return ExpandingIdentityMapper.map_variable(self, expr, expn_state) + if (new_name is None + or expr.name in expn_state.arg_context + or not self.within(expn_state.stack)): + return super(_InameDuplicator, self).map_variable(expr, expn_state) else: from pymbolic import var return var(new_name) @@ -1187,6 +1197,10 @@ class _ReductionSplitter(ExpandingIdentityMapper): self.direction = direction def map_reduction(self, expr, expn_state): + if set(expr.inames) & set(expn_state.arg_context): + # FIXME + raise NotImplementedError() + if self.inames <= set(expr.inames) and self.within(expn_state.stack): leftover_inames = set(expr.inames) - self.inames diff --git a/loopy/buffer.py b/loopy/buffer.py index 2a7386af876ecf7159dd272387bb968d2720edd8..814eabb7c3be203b147526b877085b5429ca080a 100644 --- a/loopy/buffer.py +++ b/loopy/buffer.py @@ -81,6 +81,8 @@ class ArrayAccessReplacer(ExpandingIdentityMapper): abm = self.array_base_map + index = expn_state.apply_arg_context(index) + assert len(index) == len(abm.non1_storage_axis_flags) access_subscript = [] diff --git a/loopy/subst.py b/loopy/subst.py index d4dba4b5dfd9918926adaad2af9179813914e8be..d1a643f309b0a5394bc1115630c0fe6092d6b177 100644 --- a/loopy/subst.py +++ b/loopy/subst.py @@ -230,7 +230,8 @@ class TemporaryToSubstChanger(ExpandingIdentityMapper): return subst_name def map_variable(self, expr, expn_state): - if expr.name == self.temp_name: + if (expr.name == self.temp_name + and expr.name not in expn_state.arg_context): result = self.transform_access(None, expn_state) if result is not None: return result @@ -239,7 +240,8 @@ class TemporaryToSubstChanger(ExpandingIdentityMapper): expr, expn_state) def map_subscript(self, expr, expn_state): - if expr.aggregate.name == self.temp_name: + if (expr.aggregate.name == self.temp_name + and expr.aggregate.name not in expn_state.arg_context): result = self.transform_access(expr.index, expn_state) if result is not None: return result diff --git a/loopy/symbolic.py b/loopy/symbolic.py index 9309fd0418e453311240cce8e9f6c73a8d428f9c..a17c4cdd5ae00b3ae6152083196bdf2793a7453b 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -346,16 +346,26 @@ def parse_tagged_name(expr): class ExpansionState(Record): """ - :ivar stack: a tuple representing the current expansion stack, as a tuple + .. attribute:: stack + + a tuple representing the current expansion stack, as a tuple of (name, tag) pairs. At the top level, this should be initialized to a tuple with the id of the calling instruction. - :ivar arg_context: a dict representing current argument values + + .. attribute:: arg_context + + a dict representing current argument values """ @property def insn_id(self): return self.stack[0][0] + def apply_arg_context(self, expr): + from pymbolic.mapper.substitutor import make_subst_func + return SubstitutionMapper( + make_subst_func(self.arg_context))(expr) + class SubstitutionRuleRenamer(IdentityMapper): def __init__(self, renames): @@ -402,6 +412,9 @@ def rename_subst_rules_in_instructions(insns, renames): class ExpandingIdentityMapper(IdentityMapper): """Note: the third argument dragged around by this mapper is the current :class:`ExpansionState`. + + Subclasses of this must be careful to not touch identifiers that + are in :attr:`ExpansionState.arg_context`. """ def __init__(self, old_subst_rules, make_unique_var_name): @@ -585,8 +598,12 @@ class ExpandingSubstitutionMapper(ExpandingIdentityMapper): self.within = within def map_variable(self, expr, expn_state): + if (expr.name in expn_state.arg_context + or not self.within(expn_state.stack)): + return ExpandingIdentityMapper.map_variable(self, expr, expn_state) + result = self.subst_func(expr) - if result is not None or not self.within(expn_state.stack): + if result is not None: return result else: return ExpandingIdentityMapper.map_variable(self, expr, expn_state) diff --git a/test/test_loopy.py b/test/test_loopy.py index 6a59eb7388afedcc65bb6aa3d42f57b2e7152ac3..9498dab3573290ae130a0d6189dd47cf1710709b 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -1816,6 +1816,36 @@ def test_affine_map_inames(): print(knl) +def test_precompute_confusing_subst_arguments(ctx_factory): + ctx = ctx_factory() + + knl = lp.make_kernel( + "{[i,j]: 0<=i<n and 0<=j<5}", + """ + D(i):=a[i+1]-a[i] + b[i,j] = D(j) + """) + + knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32)) + + ref_knl = knl + + knl = lp.tag_inames(knl, dict(j="g.1")) + knl = lp.split_iname(knl, "i", 128, outer_tag="g.0", inner_tag="l.0") + + from loopy.symbolic import get_dependencies + assert "i_inner" not in get_dependencies(knl.substitutions["D"].expression) + knl = lp.precompute(knl, "D") + + lp.auto_test_vs_ref( + ref_knl, ctx, knl, + parameters=dict(n=12345)) + + +def test_precompute_nested_subst(ctx_factory): + pass + + def test_poisson(ctx_factory): # Stolen from Peter Coogan and Rob Kirby for FEM assembly ctx = ctx_factory()