diff --git a/loopy/buffer.py b/loopy/buffer.py index 814eabb7c3be203b147526b877085b5429ca080a..46ebeae7caba01c3244650e0a3867e0264e41f02 100644 --- a/loopy/buffer.py +++ b/loopy/buffer.py @@ -313,7 +313,10 @@ def buffer_array(kernel, var_name, buffer_inames, init_expression=None, init_instruction = ExpressionInstruction(id=init_insn_id, assignee=buf_var_init, expression=init_expression, - forced_iname_deps=frozenset(within_inames)) + forced_iname_deps=frozenset(within_inames), + insn_deps=frozenset(), + insn_deps_is_final=True, + ) # }}} diff --git a/loopy/kernel/__init__.py b/loopy/kernel/__init__.py index 68f7f131c4c40d2f647fb6c6095959b4884a3c41..a27158c2c3e62fc887d850fb71097f8aa41bb60e 100644 --- a/loopy/kernel/__init__.py +++ b/loopy/kernel/__init__.py @@ -922,12 +922,6 @@ class LoopKernel(RecordWithoutPickling): line = "%s: %s" % (iname, self.iname_to_tag.get(iname)) lines.append(line) - if self.substitutions: - lines.append(sep) - lines.append("SUBSTIUTION RULES:") - for rule_name in sorted(six.iterkeys(self.substitutions)): - lines.append(str(self.substitutions[rule_name])) - if self.temporary_variables: lines.append(sep) lines.append("TEMPORARIES:") @@ -935,6 +929,12 @@ class LoopKernel(RecordWithoutPickling): key=lambda tv: tv.name): lines.append(str(tv)) + if self.substitutions: + lines.append(sep) + lines.append("SUBSTIUTION RULES:") + for rule_name in sorted(six.iterkeys(self.substitutions)): + lines.append(str(self.substitutions[rule_name])) + lines.append(sep) lines.append("INSTRUCTIONS:") loop_list_width = 35 diff --git a/loopy/precompute.py b/loopy/precompute.py index 04461bce969d0b95d2939a01dd936d32bf66c0c9..d37b53d7b1d38c69fb700969f6589c3b9a812186 100644 --- a/loopy/precompute.py +++ b/loopy/precompute.py @@ -25,7 +25,7 @@ THE SOFTWARE. """ -from loopy.symbolic import (get_dependencies, SubstitutionMapper, +from loopy.symbolic import (get_dependencies, ExpandingSubstitutionMapper, ExpandingIdentityMapper) from pymbolic.mapper.substitutor import make_subst_func import numpy as np @@ -419,7 +419,7 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, list(extra_storage_axes) + list(range(len(subst.arguments)))) - expr_subst_dict = {} + prior_storage_axis_name_dict = {} storage_axis_names = [] storage_axis_sources = [] # number for arg#, or iname @@ -452,18 +452,12 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, new_iname_to_tag[name] = storage_axis_to_tag.get( tag_lookup_saxis, default_tag) - expr_subst_dict[old_name] = var(name) + prior_storage_axis_name_dict[name] = old_name del storage_axis_to_tag del storage_axes del new_storage_axis_names - compute_expr = ( - SubstitutionMapper(make_subst_func(expr_subst_dict)) - (subst.expression)) - - del expr_subst_dict - # }}} # {{{ fill out access_descriptors[...].storage_axis_exprs @@ -516,36 +510,24 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, # leave kernel domains unchanged new_kernel_domains = kernel.domains + non1_storage_axis_names = [] abm = NoOpArrayToBufferMap() # {{{ set up compute insn target_var_name = var_name_gen(based_on=c_subst_name) - assignee = var(target_var_name) if non1_storage_axis_names: assignee = assignee.index( tuple(var(iname) for iname in non1_storage_axis_names)) - def zero_length_1_arg(arg_name): - if arg_name in non1_storage_axis_names: - return var(arg_name) - else: - return 0 - - compute_expr = (SubstitutionMapper( - make_subst_func(dict( - (arg_name, zero_length_1_arg(arg_name)+bi) - for arg_name, bi in zip(storage_axis_names, abm.storage_base_indices) - ))) - (compute_expr)) - from loopy.kernel.data import ExpressionInstruction + compute_insn_id = kernel.make_unique_instruction_id(based_on=c_subst_name) compute_insn = ExpressionInstruction( - id=kernel.make_unique_instruction_id(based_on=c_subst_name), + id=compute_insn_id, assignee=assignee, - expression=compute_expr) + expression=subst.expression) # }}} @@ -591,6 +573,26 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, instructions=[compute_insn] + kernel.instructions, temporary_variables=new_temporary_variables) + # {{{ process substitutions on compute instruction + + storage_axis_subst_dict = {} + + for arg_name, bi in zip(storage_axis_names, abm.storage_base_indices): + if arg_name in non1_storage_axis_names: + arg = var(arg_name) + else: + arg = 0 + + storage_axis_subst_dict[prior_storage_axis_name_dict.get(arg_name, arg_name)] = arg+bi + + expr_subst_map = ExpandingSubstitutionMapper( + kernel.substitutions, kernel.get_var_name_generator(), + make_subst_func(storage_axis_subst_dict), + parse_stack_match("... < "+compute_insn_id)) + kernel = expr_subst_map.map_kernel(kernel) + + # }}} + from loopy import tag_inames return tag_inames(kernel, new_iname_to_tag)