Skip to content
Snippets Groups Projects
Commit a3e18078 authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Use proper attribute to override get_grid_sizes_for_insn_ids

parent 6a8191ae
No related branches found
No related tags found
No related merge requests found
...@@ -199,7 +199,7 @@ class LoopKernel(ImmutableRecordWithoutPickling): ...@@ -199,7 +199,7 @@ class LoopKernel(ImmutableRecordWithoutPickling):
# When kernels get intersected in slab decomposition, # When kernels get intersected in slab decomposition,
# their grid sizes shouldn't change. This provides # their grid sizes shouldn't change. This provides
# a way to forward sub-kernel grid size requests. # a way to forward sub-kernel grid size requests.
get_grid_sizes_for_insn_ids=None): overridden_get_grid_sizes_for_insn_ids=None):
if cache_manager is None: if cache_manager is None:
from loopy.kernel.tools import SetOperationCacheManager from loopy.kernel.tools import SetOperationCacheManager
...@@ -265,10 +265,6 @@ class LoopKernel(ImmutableRecordWithoutPickling): ...@@ -265,10 +265,6 @@ class LoopKernel(ImmutableRecordWithoutPickling):
if np.iinfo(index_dtype.numpy_dtype).min >= 0: if np.iinfo(index_dtype.numpy_dtype).min >= 0:
raise TypeError("index_dtype must be signed") raise TypeError("index_dtype must be signed")
if get_grid_sizes_for_insn_ids is not None:
# overwrites method down below
self.get_grid_sizes_for_insn_ids = get_grid_sizes_for_insn_ids
if state not in [ if state not in [
kernel_state.INITIAL, kernel_state.INITIAL,
kernel_state.PREPROCESSED, kernel_state.PREPROCESSED,
...@@ -302,7 +298,9 @@ class LoopKernel(ImmutableRecordWithoutPickling): ...@@ -302,7 +298,9 @@ class LoopKernel(ImmutableRecordWithoutPickling):
index_dtype=index_dtype, index_dtype=index_dtype,
options=options, options=options,
state=state, state=state,
target=target) target=target,
overridden_get_grid_sizes_for_insn_ids=(
overridden_get_grid_sizes_for_insn_ids))
self._kernel_executor_cache = {} self._kernel_executor_cache = {}
...@@ -923,6 +921,11 @@ class LoopKernel(ImmutableRecordWithoutPickling): ...@@ -923,6 +921,11 @@ class LoopKernel(ImmutableRecordWithoutPickling):
*global_size* and *local_size* are :class:`islpy.PwAff` objects. *global_size* and *local_size* are :class:`islpy.PwAff` objects.
""" """
if self.overridden_get_grid_sizes_for_insn_ids:
return self.overridden_get_grid_sizes_for_insn_ids(
insn_ids,
ignore_auto=ignore_auto)
all_inames_by_insns = set() all_inames_by_insns = set()
for insn_id in insn_ids: for insn_id in insn_ids:
all_inames_by_insns |= self.insn_inames(insn_id) all_inames_by_insns |= self.insn_inames(insn_id)
......
...@@ -439,7 +439,8 @@ class DomainChanger: ...@@ -439,7 +439,8 @@ class DomainChanger:
# Changing the domain might look like it wants to change grid # Changing the domain might look like it wants to change grid
# sizes. Not true. # sizes. Not true.
# (Relevant for 'slab decomposition') # (Relevant for 'slab decomposition')
get_grid_sizes_for_insn_ids=self.kernel.get_grid_sizes_for_insn_ids) overridden_get_grid_sizes_for_insn_ids=(
self.kernel.get_grid_sizes_for_insn_ids))
# }}} # }}}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment