diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py index 29c0a2f76bcb7fb871951b78a405f85c0d925e4a..468f4528a442a32d1e600f3036c432c37892d2e3 100644 --- a/loopy/kernel/creation.py +++ b/loopy/kernel/creation.py @@ -408,13 +408,16 @@ class IndexRankFinder(WalkMapper): else: self.index_ranks.append(len(expr.index)) -def guess_kernel_args_if_requested(domains, instructions, temporary_variables, kernel_args): +def guess_kernel_args_if_requested(domains, instructions, temporary_variables, subst_rules, kernel_args): if "..." not in kernel_args: return kernel_args kernel_args = kernel_args[:] kernel_args.remove("...") + from loopy.symbolic import SubstitutionRuleExpander + submap = SubstitutionRuleExpander(subst_rules) + # {{{ find names that are *not* arguments all_inames = set() @@ -442,8 +445,8 @@ def guess_kernel_args_if_requested(domains, instructions, temporary_variables, k from loopy.symbolic import get_dependencies for insn in instructions: all_written_names.add(insn.get_assignee_var_name()) - all_names.update(get_dependencies(insn.expression)) - all_names.update(get_dependencies(insn.assignee)) + all_names.update(get_dependencies(submap(insn.assignee, insn.id))) + all_names.update(get_dependencies(submap(insn.expression, insn.id))) all_params = set() for dom in domains: @@ -456,9 +459,10 @@ def guess_kernel_args_if_requested(domains, instructions, temporary_variables, k def find_index_rank(name): irf = IndexRankFinder(name) + for insn in instructions: - irf(insn.expression) - irf(insn.assignee) + irf(submap(insn.expression, insn.id)) + irf(submap(insn.assignee, insn.id)) if not irf.index_ranks: return 0 @@ -778,6 +782,10 @@ def guess_arg_shape_if_requested(kernel, default_order): new_args = [] from loopy.kernel.data import ShapedArg, auto_shape, auto_strides + from loopy.symbolic import SubstitutionRuleExpander + + submap = SubstitutionRuleExpander(kernel.substitutions, + kernel.get_var_name_generator()) for arg in kernel.args: if isinstance(arg, ShapedArg) and ( @@ -786,8 +794,9 @@ def guess_arg_shape_if_requested(kernel, default_order): for insn in kernel.instructions: domain = kernel.get_inames_domain(kernel.insn_inames(insn)) - armap(insn.assignee, domain) - armap(insn.expression, domain) + armap(submap(insn.assignee, insn.id), domain) + armap(submap(insn.expression, insn.id), domain) + if armap.access_range is None: # no subscripts found, let's call it a scalar @@ -895,7 +904,7 @@ def make_kernel(device, domains, instructions, kernel_args=["..."], **kwargs): domains = parse_domains(isl_context, domains, defines) kernel_args = guess_kernel_args_if_requested(domains, instructions, - kwargs.get("temporary_variables", {}), kernel_args) + kwargs.get("temporary_variables", {}), substitutions, kernel_args) from loopy.kernel import LoopKernel knl = LoopKernel(device, domains, instructions, kernel_args, **kwargs)