Skip to content
Snippets Groups Projects
Commit 12452330 authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Cache result of dim_{max,min}.

parent 73b991f1
No related branches found
No related tags found
No related merge requests found
......@@ -41,8 +41,6 @@ To-do
- CSE should be more like variable assignment
- dim_max caching
- Fix all tests
- Deal with equality constraints.
......@@ -88,6 +86,8 @@ Future ideas
Dealt with
^^^^^^^^^^
- dim_max caching
- Exhaust the search for a no-boost solution first, before looking
for a schedule with boosts.
......
......@@ -168,7 +168,7 @@ def make_kernel(*args, **kwargs):
base_indices, shape = \
find_var_base_indices_and_shape_from_inames(
new_domain, assignee_indices)
new_domain, assignee_indices, knl.cache_manager)
new_temp_vars[assignee_name] = TemporaryVariable(
name=assignee_name,
......
......@@ -498,7 +498,8 @@ def realize_cse(kernel, cse_tag, dtype, independent_inames=[],
target_var_base_indices, target_var_shape = \
find_var_base_indices_and_shape_from_inames(
new_domain, independent_inames)
new_domain, independent_inames,
kernel.cache_manager)
new_temporary_variables = kernel.temporary_variables.copy()
new_temporary_variables[target_var_name] = TemporaryVariable(
......
......@@ -452,6 +452,8 @@ class LoopKernel(Record):
workgroup axes to their sizes, e.g. *{0: 16}* forces axis 0 to be
length 16.
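:ivar cache_manager: a :class:`SetOperationCacheManager` through which ``dim_min``/``dim_max`` queries are routed so their results are cached. Defaults to a fresh instance if *None* is passed.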
The following instance variables are only used until :func:`loopy.kernel.make_kernel` is
finished:
......@@ -466,7 +468,8 @@ class LoopKernel(Record):
iname_slab_increments={},
temporary_variables={},
local_sizes={},
iname_to_tag={}, iname_to_tag_requests=None, cses={}, substitutions={}):
iname_to_tag={}, iname_to_tag_requests=None, cses={}, substitutions={},
cache_manager=None):
"""
:arg domain: a :class:`islpy.BasicSet`, or a string parseable to a basic set by the isl.
Example: "{[i,j]: 0<=i < 10 and 0<= j < 9}"
......@@ -475,6 +478,9 @@ class LoopKernel(Record):
import re
if cache_manager is None:
cache_manager = SetOperationCacheManager()
if isinstance(domain, str):
ctx = isl.Context()
domain = isl.Set.read_from_str(ctx, domain)
......@@ -650,7 +656,8 @@ class LoopKernel(Record):
local_sizes=local_sizes,
iname_to_tag=iname_to_tag,
iname_to_tag_requests=iname_to_tag_requests,
cses=cses, substitutions=substitutions)
cses=cses, substitutions=substitutions,
cache_manager=cache_manager)
def make_unique_instruction_id(self, insns=None, based_on="insn", extra_used_ids=set()):
if insns is None:
......@@ -863,12 +870,14 @@ class LoopKernel(Record):
isl.align_spaces(self.assumptions, self.domain)
& self.domain)
lower_bound_pw_aff = (
dom_intersect_assumptions
.dim_min(self.iname_to_dim[iname][1])
self.cache_manager.dim_min(
dom_intersect_assumptions,
self.iname_to_dim[iname][1])
.coalesce())
upper_bound_pw_aff = (
dom_intersect_assumptions
.dim_max(self.iname_to_dim[iname][1])
self.cache_manager.dim_max(
dom_intersect_assumptions,
self.iname_to_dim[iname][1])
.coalesce())
class BoundsRecord(Record):
......@@ -1035,14 +1044,14 @@ class LoopKernel(Record):
def find_var_base_indices_and_shape_from_inames(domain, inames):
def find_var_base_indices_and_shape_from_inames(domain, inames, cache_manager):
base_indices = []
shape = []
iname_to_dim = domain.get_space().get_var_dict()
for iname in inames:
lower_bound_pw_aff = domain.dim_min(iname_to_dim[iname][1])
upper_bound_pw_aff = domain.dim_max(iname_to_dim[iname][1])
lower_bound_pw_aff = cache_manager.dim_min(domain, iname_to_dim[iname][1])
upper_bound_pw_aff = cache_manager.dim_max(domain, iname_to_dim[iname][1])
from loopy.isl_helpers import static_max_of_pw_aff
from loopy.symbolic import pw_aff_to_expr
......@@ -1078,4 +1087,30 @@ def get_dot_dependency_graph(kernel, iname_cluster=False, iname_edge=True):
class SetOperationCacheManager:
    """Memoize the results of method calls made on set objects.

    Sets are bucketed by their hash value; inside a bucket, a previously
    computed result is reused when the stored set compares equal via
    ``plain_is_equal`` and both the operation name and the argument tuple
    match.

    :ivar cache: dictionary mapping a set's hash to a list of
        ``(set, op, args, result)`` tuples.
    """

    def __init__(self):
        # Layout: {hash(set): [(set, op, args, result), ...]}
        self.cache = {}

    def op(self, set, op, args):
        """Return ``getattr(set, op)(*args)``, reusing a cached result if any.

        :arg set: a hashable object providing ``plain_is_equal`` and a
            method named *op* (presumably an isl set -- confirm with callers).
        :arg op: name of the method to call on *set*.
        :arg args: tuple of positional arguments for that method.
        """
        key = hash(set)

        try:
            entries = self.cache[key]
        except KeyError:
            entries = self.cache[key] = []

        for cached_set, cached_op, cached_args, cached_result in entries:
            # The set comparison is deliberately evaluated first, as in the
            # original lookup.
            if not set.plain_is_equal(cached_set):
                continue
            if not (op == cached_op and args == cached_args):
                continue
            return cached_result

        # Cache miss: run the operation and remember its outcome.
        method = getattr(set, op)
        computed = method(*args)
        entries.append((set, op, args, computed))
        return computed

    def dim_min(self, set, *args):
        """Cached equivalent of ``set.dim_min(*args)``."""
        return self.op(set, "dim_min", args)

    def dim_max(self, set, *args):
        """Cached equivalent of ``set.dim_max(*args)``."""
        return self.op(set, "dim_max", args)
# vim: foldmethod=marker
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment