Skip to content
Snippets Groups Projects
Commit 9bde2387 authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Guard assumptions on PwQPolynomials returned from loopy.statistics

(rather than rely on the PwQPolynomial's domain, which only captures
where the corresponding poly is non-zero)
parent 40001b02
No related branches found
No related tags found
1 merge request!121Stats: take into account unused hw axes in run count, refactor code
......@@ -38,6 +38,7 @@ __doc__ = """
.. currentmodule:: loopy
.. autoclass:: GuardedPwQPolynomial
.. autoclass:: ToCountMap
.. autoclass:: Op
.. autoclass:: MemAccess
......@@ -52,6 +53,60 @@ __doc__ = """
"""
# {{{ GuardedPwQPolynomial
class GuardedPwQPolynomial(object):
def __init__(self, pwqpolynomial, valid_domain):
self.pwqpolynomial = pwqpolynomial
self.valid_domain = valid_domain
def __add__(self, other):
if isinstance(other, GuardedPwQPolynomial):
return GuardedPwQPolynomial(
self.pwqpolynomial + other.pwqpolynomial,
self.valid_domain & other.valid_domain)
else:
return GuardedPwQPolynomial(
self.pwqpolynomial + other,
self.valid_domain)
__radd__ = __add__
def __mul__(self, other):
if isinstance(other, GuardedPwQPolynomial):
return GuardedPwQPolynomial(
self.pwqpolynomial * other.pwqpolynomial,
self.valid_domain & other.valid_domain)
else:
return GuardedPwQPolynomial(
self.pwqpolynomial * other,
self.valid_domain)
__rmul__ = __mul__
def eval_with_dict(self, value_dict):
space = self.pwqpolynomial.space
pt = isl.Point.zero(space.params())
for i in range(space.dim(dim_type.param)):
par_name = space.get_dim_name(dim_type.param, i)
pt = pt.set_coordinate_val(
dim_type.param, i, value_dict[par_name])
if not (isl.Set.from_point(pt) <= self.valid_domain):
raise ValueError("evaluation point outside of domain of "
"definition of piecewise quasipolynomial")
return self.pwqpolynomial.eval(pt).to_python()
@staticmethod
def zero():
p = isl.PwQPolynomial('{ 0 }')
return GuardedPwQPolynomial(p, isl.Set.universe(p.domain().space))
# }}}
# {{{ ToCountMap
class ToCountMap(object):
......@@ -66,7 +121,7 @@ class ToCountMap(object):
"""
def __init__(self, init_dict=None, val_type=isl.PwQPolynomial):
def __init__(self, init_dict=None, val_type=GuardedPwQPolynomial):
if init_dict is None:
init_dict = {}
self.count_map = init_dict
......@@ -87,7 +142,7 @@ class ToCountMap(object):
return self
def __mul__(self, other):
if isinstance(other, isl.PwQPolynomial):
if isinstance(other, GuardedPwQPolynomial):
return ToCountMap(dict(
(index, self.count_map[index]*other)
for index in self.keys()))
......@@ -103,8 +158,8 @@ class ToCountMap(object):
return self.count_map[index]
except KeyError:
#TODO what is the best way to handle this?
if self.val_type is isl.PwQPolynomial:
return isl.PwQPolynomial('{ 0 }')
if self.val_type is GuardedPwQPolynomial:
return GuardedPwQPolynomial.zero()
else:
return 0
......@@ -342,8 +397,8 @@ class ToCountMap(object):
"""
if self.val_type is isl.PwQPolynomial:
total = isl.PwQPolynomial('{ 0 }')
if self.val_type is GuardedPwQPolynomial:
total = GuardedPwQPolynomial.zero()
else:
total = 0
......@@ -929,9 +984,13 @@ class AccessFootprintGatherer(CombineMapper):
# {{{ count
def add_assumptions_guard(kernel, pwqpolynomial):
return GuardedPwQPolynomial(pwqpolynomial, kernel.assumptions)
def count(kernel, set, space=None):
try:
return set.card()
return add_assumptions_guard(kernel, set.card())
except AttributeError:
pass
......@@ -1022,7 +1081,7 @@ def count(kernel, set, space=None):
"number of integer points in your loop "
"domain.")
return count
return add_assumptions_guard(kernel, count)
def get_unused_hw_axes_factor(knl, insn, disregard_local_axes, space=None):
......@@ -1056,9 +1115,11 @@ def get_unused_hw_axes_factor(knl, insn, disregard_local_axes, space=None):
return result
if disregard_local_axes:
return mult_grid_factor(g_used, gsize)
result = mult_grid_factor(g_used, gsize)
else:
return mult_grid_factor(g_used, gsize) * mult_grid_factor(l_used, lsize)
result = mult_grid_factor(g_used, gsize) * mult_grid_factor(l_used, lsize)
return add_assumptions_guard(knl, result)
def count_insn_runs(knl, insn, count_redundant_work, disregard_local_axes=False):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment