diff --git a/loopy/symbolic.py b/loopy/symbolic.py index a462c8ca8919a12334482dc608363f1de8494d9f..af59655935a000fdd487114d74b13e6e34d6c6d1 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -1969,7 +1969,8 @@ def simplify_using_aff(kernel, expr): :arg expr: An instance of :class:`pymbolic.primitives.Expression`. """ - inames = get_dependencies(expr) & kernel.all_inames() + deps = get_dependencies(expr) + inames = deps & kernel.all_inames() # FIXME: Ideally, we should find out what inames are usable and allow # the simplification to use all of those. For now, fall back to making @@ -1979,6 +1980,16 @@ def simplify_using_aff(kernel, expr): .get_inames_domain(inames) .project_out_except(inames, [dim_type.set])) + non_inames = deps - set(domain.get_var_dict().keys()) + non_inames = {name for name in non_inames if name.isidentifier()} + if non_inames: + cur_dim = domain.dim(isl.dim_type.set) + domain = domain.insert_dims(isl.dim_type.set, cur_dim, len(non_inames)) + for non_iname in non_inames: + # isl objects are immutable: set_dim_name returns a new set, so rebind. + domain = domain.set_dim_name(isl.dim_type.set, cur_dim, non_iname) + cur_dim += 1 + try: aff = guarded_aff_from_expr(domain.space, expr) except ExpressionToAffineConversionError: diff --git a/loopy/transform/callable.py b/loopy/transform/callable.py index 1f6c9b2b3cdc69b3830a9840479df14e695a26d4..df887689ada095331cb445c258a17bb87130d8a3 100644 --- a/loopy/transform/callable.py +++ b/loopy/transform/callable.py @@ -29,6 +29,7 @@ from loopy.diagnostic import LoopyError from loopy.kernel.instruction import (CallInstruction, MultiAssignmentBase, Assignment, CInstruction, _DataObliviousInstruction) from loopy.symbolic import ( + simplify_using_aff, RuleAwareIdentityMapper, RuleAwareSubstitutionMapper, SubstitutionRuleMappingContext) from loopy.kernel.function_interface import ( @@ -121,13 +122,16 @@ class KernelArgumentSubstitutor(RuleAwareIdentityMapper): callee_knl, callee_arg_to_call_param): super().__init__(rule_mapping_context) self.caller_knl = caller_knl + + # 
CAUTION: This kernel has post-substitution domains! self.callee_knl = callee_knl + self.callee_arg_to_call_param = callee_arg_to_call_param def map_subscript(self, expr, expn_state): if expr.aggregate.name in self.callee_knl.arg_dict: from loopy.symbolic import get_start_subscript_from_sar - from loopy.isl_helpers import simplify_via_aff + from loopy.symbolic import simplify_via_aff from pymbolic.primitives import Subscript, Variable sar = self.callee_arg_to_call_param[expr.aggregate.name] # SubArrayRef @@ -157,7 +161,8 @@ class KernelArgumentSubstitutor(RuleAwareIdentityMapper): flatten_index -= (dim_tag.stride * ind) new_indices.append(ind) - new_indices = tuple(simplify_via_aff(i) for i in new_indices) + new_indices = tuple(simplify_using_aff( + self.callee_knl, i) for i in new_indices) return Subscript(Variable(sar.subscript.aggregate.name), new_indices) else: @@ -364,8 +369,18 @@ def _inline_call_instruction(caller_knl, callee_knl, call_insn): rule_mapping_context = SubstitutionRuleMappingContext(callee_knl.substitutions, vng) - smap = KernelArgumentSubstitutor(rule_mapping_context, caller_knl, - callee_knl, arg_map) + smap = KernelArgumentSubstitutor( + rule_mapping_context, caller_knl, + + # HACK: The kernel returned by this copy doesn't make sense: + # It uses caller inames in its domain. The domains are/should be + # only used for expression simplification. Ideally, we'd pass + # the domains for this separately. + # Other than that, the kernel is used for looking up argument + # definitions, which is OK. 
+ callee_knl.copy(domains=new_domains), + + arg_map) callee_knl = rule_mapping_context.finish_kernel(smap.map_kernel(callee_knl)) diff --git a/test/test_transform.py b/test/test_transform.py index fc1d0efec440380f35fe11e76545e4349acbc0a1..ee18c81141071bd6f0999085a170e48ca7551b7b 100644 --- a/test/test_transform.py +++ b/test/test_transform.py @@ -1262,18 +1262,12 @@ def test_privatize_with_nonzero_lbound(ctx_factory): def test_simplify_indices(ctx_factory): ctx = ctx_factory() - twice = lp.make_function( - "{[i, j]: 0<=i<10 and 0<=j<4}", - """ - y[i,j] = 2*x[i,j] - """, name="zerozerozeroonezeroify") - knl = lp.make_kernel( - "{:}", + "{[i]: 0<=i<10}", """ - Y[:,:] = zerozerozeroonezeroify(X[:,:]) + Y[i] = X[10*(i//10) + i] """, [lp.GlobalArg("X,Y", - shape=(10, 4), + shape=(10,), dtype=np.float64)]) class ContainsFloorDiv(lp.symbolic.CombineMapper): @@ -1289,8 +1283,6 @@ def test_simplify_indices(ctx_factory): def map_constant(self, expr): return False - knl = lp.merge([knl, twice]) - knl = lp.inline_callable_kernel(knl, "zerozerozeroonezeroify") simplified_knl = lp.simplify_indices(knl) contains_floordiv = ContainsFloorDiv()