From 92dc6fbec92997d0b77f31fe0a6a472c83c31e8b Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Mon, 22 Jun 2015 02:17:53 -0500 Subject: [PATCH] Allow reuse of temporaries in precompute --- loopy/precompute.py | 95 +++++++++++++++++++++++++++++++++++++-------- 1 file changed, 78 insertions(+), 17 deletions(-) diff --git a/loopy/precompute.py b/loopy/precompute.py index 935d6d440..726cc0786 100644 --- a/loopy/precompute.py +++ b/loopy/precompute.py @@ -132,7 +132,7 @@ class RuleInvocationReplacer(RuleAwareIdentityMapper): access_descriptors, array_base_map, storage_axis_names, storage_axis_sources, non1_storage_axis_names, - target_var_name): + temporary_name): super(RuleInvocationReplacer, self).__init__(rule_mapping_context) self.subst_name = subst_name @@ -146,7 +146,7 @@ class RuleInvocationReplacer(RuleAwareIdentityMapper): self.storage_axis_sources = storage_axis_sources self.non1_storage_axis_names = non1_storage_axis_names - self.target_var_name = target_var_name + self.temporary_name = temporary_name def map_substitution(self, name, tag, arguments, expn_state): if not ( @@ -196,7 +196,7 @@ class RuleInvocationReplacer(RuleAwareIdentityMapper): ax_index = simplify_via_aff(ax_index - sax_base_idx) stor_subscript.append(ax_index) - new_outer_expr = var(self.target_var_name) + new_outer_expr = var(self.temporary_name) if stor_subscript: new_outer_expr = new_outer_expr.index(tuple(stor_subscript)) @@ -210,9 +210,9 @@ class RuleInvocationReplacer(RuleAwareIdentityMapper): def precompute(kernel, subst_use, sweep_inames=[], within=None, - storage_axes=None, precompute_inames=None, storage_axis_to_tag={}, - default_tag="l.auto", dtype=None, fetch_bounding_box=False, - temporary_is_local=None): + storage_axes=None, temporary_name=None, precompute_inames=None, + storage_axis_to_tag={}, default_tag="l.auto", dtype=None, + fetch_bounding_box=False, temporary_is_local=None): """Precompute the expression described in the substitution rule 
determined by *subst_use* and store it in a temporary array. A precomputation needs two things to operate, a list of *sweep_inames* (order irrelevant) and an @@ -263,6 +263,11 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, May also equivalently be a comma-separated string. :arg within: a stack match as understood by :func:`loopy.context_matching.parse_stack_match`. + :arg temporary_name: + The temporary variable name to use for storing the precomputed data. + If it does not exist, it will be created. If it does exist, its properties + (such as size, type) are checked (and updated, if possible) to match + its use. :arg precompute_inames: If the specified inames do not already exist, they will be created. If they do already exist, their loop domain is verified @@ -584,8 +589,10 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, # {{{ set up compute insn - target_var_name = var_name_gen(based_on=c_subst_name) - assignee = var(target_var_name) + if temporary_name is None: + temporary_name = var_name_gen(based_on=c_subst_name) + + assignee = var(temporary_name) if non1_storage_axis_names: assignee = assignee.index( @@ -633,7 +640,7 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, access_descriptors, abm, storage_axis_names, storage_axis_sources, non1_storage_axis_names, - target_var_name) + temporary_name) kernel = invr.map_kernel(kernel) kernel = kernel.copy( @@ -655,15 +662,69 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, if temporary_is_local is None: temporary_is_local = lp.auto + new_temp_shape = tuple(abm.non1_storage_shape) + new_temporary_variables = kernel.temporary_variables.copy() - temp_var = lp.TemporaryVariable( - name=target_var_name, - dtype=dtype, - base_indices=(0,)*len(abm.non1_storage_shape), - shape=tuple(abm.non1_storage_shape), - is_local=temporary_is_local) - - new_temporary_variables[target_var_name] = temp_var + if temporary_name not in new_temporary_variables: + 
temp_var = lp.TemporaryVariable( + name=temporary_name, + dtype=dtype, + base_indices=(0,)*len(new_temp_shape), + shape=tuple(abm.non1_storage_shape), + is_local=temporary_is_local) + + else: + temp_var = new_temporary_variables[temporary_name] + + # {{{ check and adapt existing temporary + + if temp_var.dtype is lp.auto: + pass + elif temp_var.dtype is not lp.auto and dtype is lp.auto: + dtype = temp_var.dtype + elif temp_var.dtype is not lp.auto and dtype is not lp.auto: + if temp_var.dtype != dtype: + raise LoopyError("Existing and new dtype of temporary '%s' " + "do not match (existing: %s, new: %s)" + % (temporary_name, temp_var.dtype, dtype)) + + temp_var = temp_var.copy(dtype=dtype) + + if len(temp_var.shape) != len(new_temp_shape): + raise LoopyError("Existing and new temporary '%s' do not " + "have matching number of dimensions (existing: %d, new: %d)" + % (temporary_name, + len(temp_var.shape), len(new_temp_shape))) + + if temp_var.base_indices != (0,) * len(new_temp_shape): + raise LoopyError("Existing and new temporary '%s' do not " + "have matching base indices (existing: %s, new: %s)" + % (temporary_name, + temp_var.base_indices, (0,) * len(new_temp_shape))) + + new_temp_shape = tuple( + max(i, ex_i) + for i, ex_i in zip(new_temp_shape, temp_var.shape)) + + temp_var = temp_var.copy(shape=new_temp_shape) + + if temporary_is_local == temp_var.is_local: + pass + elif temporary_is_local is lp.auto: + temporary_is_local = temp_var.is_local + elif temp_var.is_local is lp.auto: + pass + else: + raise LoopyError("Existing and new temporary '%s' do not " + "have matching values of 'is_local' (existing: %s, new: %s)" + % (temporary_name, + temp_var.is_local, temporary_is_local)) + + temp_var = temp_var.copy(is_local=temporary_is_local) + + # }}} + + new_temporary_variables[temporary_name] = temp_var kernel = kernel.copy( temporary_variables=new_temporary_variables) -- GitLab