diff --git a/loopy/precompute.py b/loopy/precompute.py index 15b46ecd648ab9289c4d2a467dd5b5cbaa6a4be4..a584e5c2455b42853b47f6fa007786862957709b 100644 --- a/loopy/precompute.py +++ b/loopy/precompute.py @@ -25,7 +25,7 @@ THE SOFTWARE. """ -from loopy.symbolic import (get_dependencies, SubstitutionMapper, +from loopy.symbolic import (get_dependencies, ExpandingSubstitutionMapper, ExpandingIdentityMapper) from pymbolic.mapper.substitutor import make_subst_func import numpy as np @@ -435,7 +435,7 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, list(extra_storage_axes) + list(range(len(subst.arguments)))) - expr_subst_dict = {} + prior_storage_axis_name_dict = {} storage_axis_names = [] storage_axis_sources = [] # number for arg#, or iname @@ -468,18 +468,12 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, new_iname_to_tag[name] = storage_axis_to_tag.get( tag_lookup_saxis, default_tag) - expr_subst_dict[old_name] = var(name) + prior_storage_axis_name_dict[name] = old_name del storage_axis_to_tag del storage_axes del new_storage_axis_names - compute_expr = ( - SubstitutionMapper(make_subst_func(expr_subst_dict)) - (subst.expression)) - - del expr_subst_dict - # }}} # {{{ fill out access_descriptors[...].storage_axis_exprs @@ -532,36 +526,24 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, # leave kernel domains unchanged new_kernel_domains = kernel.domains + non1_storage_axis_names = [] abm = NoOpArrayToBufferMap() # {{{ set up compute insn target_var_name = var_name_gen(based_on=c_subst_name) - assignee = var(target_var_name) if non1_storage_axis_names: assignee = assignee.index( tuple(var(iname) for iname in non1_storage_axis_names)) - def zero_length_1_arg(arg_name): - if arg_name in non1_storage_axis_names: - return var(arg_name) - else: - return 0 - - compute_expr = (SubstitutionMapper( - make_subst_func(dict( - (arg_name, zero_length_1_arg(arg_name)+bi) - for arg_name, bi in zip(storage_axis_names, abm.storage_base_indices) - ))) - (compute_expr)) - from loopy.kernel.data import ExpressionInstruction + compute_insn_id = kernel.make_unique_instruction_id(based_on=c_subst_name) compute_insn = ExpressionInstruction( - id=kernel.make_unique_instruction_id(based_on=c_subst_name), + id=compute_insn_id, assignee=assignee, - expression=compute_expr) + expression=subst.expression) # }}} @@ -607,6 +589,26 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, instructions=[compute_insn] + kernel.instructions, temporary_variables=new_temporary_variables) + # {{{ process substitutions on compute instruction + + storage_axis_subst_dict = {} + + for arg_name, bi in zip(storage_axis_names, abm.storage_base_indices): + if arg_name in non1_storage_axis_names: + arg = var(arg_name) + else: + arg = 0 + + storage_axis_subst_dict[prior_storage_axis_name_dict.get(arg_name, arg_name)] = arg+bi + + expr_subst_map = ExpandingSubstitutionMapper( + kernel.substitutions, kernel.get_var_name_generator(), + make_subst_func(storage_axis_subst_dict), + parse_stack_match("... < "+compute_insn_id)) + kernel = expr_subst_map.map_kernel(kernel) + + # }}} + from loopy import tag_inames return tag_inames(kernel, new_iname_to_tag)