Skip to content
Snippets Groups Projects
Commit 92dc6fbe authored by Andreas Klöckner
Browse files

Allow reuse of temporaries in precompute

parent 323ca1ed
No related branches found
No related tags found
No related merge requests found
......@@ -132,7 +132,7 @@ class RuleInvocationReplacer(RuleAwareIdentityMapper):
access_descriptors, array_base_map,
storage_axis_names, storage_axis_sources,
non1_storage_axis_names,
target_var_name):
temporary_name):
super(RuleInvocationReplacer, self).__init__(rule_mapping_context)
self.subst_name = subst_name
......@@ -146,7 +146,7 @@ class RuleInvocationReplacer(RuleAwareIdentityMapper):
self.storage_axis_sources = storage_axis_sources
self.non1_storage_axis_names = non1_storage_axis_names
self.target_var_name = target_var_name
self.temporary_name = temporary_name
def map_substitution(self, name, tag, arguments, expn_state):
if not (
......@@ -196,7 +196,7 @@ class RuleInvocationReplacer(RuleAwareIdentityMapper):
ax_index = simplify_via_aff(ax_index - sax_base_idx)
stor_subscript.append(ax_index)
new_outer_expr = var(self.target_var_name)
new_outer_expr = var(self.temporary_name)
if stor_subscript:
new_outer_expr = new_outer_expr.index(tuple(stor_subscript))
......@@ -210,9 +210,9 @@ class RuleInvocationReplacer(RuleAwareIdentityMapper):
def precompute(kernel, subst_use, sweep_inames=[], within=None,
storage_axes=None, precompute_inames=None, storage_axis_to_tag={},
default_tag="l.auto", dtype=None, fetch_bounding_box=False,
temporary_is_local=None):
storage_axes=None, temporary_name=None, precompute_inames=None,
storage_axis_to_tag={}, default_tag="l.auto", dtype=None,
fetch_bounding_box=False, temporary_is_local=None):
"""Precompute the expression described in the substitution rule determined by
*subst_use* and store it in a temporary array. A precomputation needs two
things to operate, a list of *sweep_inames* (order irrelevant) and an
......@@ -263,6 +263,11 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
May also equivalently be a comma-separated string.
:arg within: a stack match as understood by
:func:`loopy.context_matching.parse_stack_match`.
:arg temporary_name:
The temporary variable name to use for storing the precomputed data.
If it does not exist, it will be created. If it does exist, its properties
(such as size, type) are checked (and updated, if possible) to match
its use.
:arg precompute_inames:
If the specified inames do not already exist, they will be
created. If they do already exist, their loop domain is verified
......@@ -584,8 +589,10 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
# {{{ set up compute insn
target_var_name = var_name_gen(based_on=c_subst_name)
assignee = var(target_var_name)
if temporary_name is None:
temporary_name = var_name_gen(based_on=c_subst_name)
assignee = var(temporary_name)
if non1_storage_axis_names:
assignee = assignee.index(
......@@ -633,7 +640,7 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
access_descriptors, abm,
storage_axis_names, storage_axis_sources,
non1_storage_axis_names,
target_var_name)
temporary_name)
kernel = invr.map_kernel(kernel)
kernel = kernel.copy(
......@@ -655,15 +662,69 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
if temporary_is_local is None:
temporary_is_local = lp.auto
new_temp_shape = tuple(abm.non1_storage_shape)
new_temporary_variables = kernel.temporary_variables.copy()
temp_var = lp.TemporaryVariable(
name=target_var_name,
dtype=dtype,
base_indices=(0,)*len(abm.non1_storage_shape),
shape=tuple(abm.non1_storage_shape),
is_local=temporary_is_local)
new_temporary_variables[target_var_name] = temp_var
if temporary_name not in new_temporary_variables:
temp_var = lp.TemporaryVariable(
name=temporary_name,
dtype=dtype,
base_indices=(0,)*len(new_temp_shape),
shape=tuple(abm.non1_storage_shape),
is_local=temporary_is_local)
else:
temp_var = new_temporary_variables[temporary_name]
# {{{ check and adapt existing temporary
if temp_var.dtype is lp.auto:
pass
elif temp_var.dtype is not lp.auto and dtype is lp.auto:
dtype = temp_var.dtype
elif temp_var.dtype is not lp.auto and dtype is not lp.auto:
if temp_var.dtype != dtype:
raise LoopyError("Existing and new dtype of temporary '%s' "
"do not match (existing: %s, new: %s)"
% (temporary_name, temp_var.dtype, dtype))
temp_var = temp_var.copy(dtype=dtype)
if len(temp_var.shape) != len(new_temp_shape):
raise LoopyError("Existing and new temporary '%s' do not "
"have matching number of dimensions "
% (temporary_name,
len(temp_var.shape), len(new_temp_shape)))
if temp_var.base_indices != (0,) * len(new_temp_shape):
raise LoopyError("Existing and new temporary '%s' do not "
"have matching number of dimensions "
% (temporary_name,
len(temp_var.shape), len(new_temp_shape)))
new_temp_shape = tuple(
max(i, ex_i)
for i, ex_i in zip(new_temp_shape, temp_var.shape))
temp_var = temp_var.copy(shape=new_temp_shape)
if temporary_is_local == temp_var.is_local:
pass
elif temporary_is_local is lp.auto:
temporary_is_local = temp_var.is_local
elif temp_var.is_local is lp.auto:
pass
else:
raise LoopyError("Existing and new temporary '%s' do not "
"have matching values of 'is_local'"
% (temporary_name,
temp_var.is_local, temporary_is_local))
temp_var = temp_var.copy(is_local=temporary_is_local)
# }}}
new_temporary_variables[temporary_name] = temp_var
kernel = kernel.copy(
temporary_variables=new_temporary_variables)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment