diff --git a/loopy/statistics.py b/loopy/statistics.py index af9d6d47dc53e2e1ef6d585d368365f5ffa41c57..a0b3a0165b234610f0e78246786d6ef7bb39747c 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -38,6 +38,7 @@ __doc__ = """ .. currentmodule:: loopy +.. autoclass:: GuardedPwQPolynomial .. autoclass:: ToCountMap .. autoclass:: Op .. autoclass:: MemAccess @@ -52,6 +53,60 @@ __doc__ = """ """ +# {{{ GuardedPwQPolynomial + +class GuardedPwQPolynomial(object): + def __init__(self, pwqpolynomial, valid_domain): + self.pwqpolynomial = pwqpolynomial + self.valid_domain = valid_domain + + def __add__(self, other): + if isinstance(other, GuardedPwQPolynomial): + return GuardedPwQPolynomial( + self.pwqpolynomial + other.pwqpolynomial, + self.valid_domain & other.valid_domain) + else: + return GuardedPwQPolynomial( + self.pwqpolynomial + other, + self.valid_domain) + + __radd__ = __add__ + + def __mul__(self, other): + if isinstance(other, GuardedPwQPolynomial): + return GuardedPwQPolynomial( + self.pwqpolynomial * other.pwqpolynomial, + self.valid_domain & other.valid_domain) + else: + return GuardedPwQPolynomial( + self.pwqpolynomial * other, + self.valid_domain) + + __rmul__ = __mul__ + + def eval_with_dict(self, value_dict): + space = self.pwqpolynomial.space + pt = isl.Point.zero(space.params()) + + for i in range(space.dim(dim_type.param)): + par_name = space.get_dim_name(dim_type.param, i) + pt = pt.set_coordinate_val( + dim_type.param, i, value_dict[par_name]) + + if not (isl.Set.from_point(pt) <= self.valid_domain): + raise ValueError("evaluation point outside of domain of " + "definition of piecewise quasipolynomial") + + return self.pwqpolynomial.eval(pt).to_python() + + @staticmethod + def zero(): + p = isl.PwQPolynomial('{ 0 }') + return GuardedPwQPolynomial(p, isl.Set.universe(p.domain().space)) + +# }}} + + # {{{ ToCountMap class ToCountMap(object): @@ -66,7 +121,7 @@ class ToCountMap(object): """ - def __init__(self, init_dict=None, val_type=isl.PwQPolynomial): + def __init__(self, init_dict=None, val_type=GuardedPwQPolynomial): if init_dict is None: init_dict = {} self.count_map = init_dict @@ -87,7 +142,7 @@ class ToCountMap(object): return self def __mul__(self, other): - if isinstance(other, isl.PwQPolynomial): + if isinstance(other, GuardedPwQPolynomial): return ToCountMap(dict( (index, self.count_map[index]*other) for index in self.keys())) @@ -103,8 +158,8 @@ class ToCountMap(object): return self.count_map[index] except KeyError: #TODO what is the best way to handle this? - if self.val_type is isl.PwQPolynomial: - return isl.PwQPolynomial('{ 0 }') + if self.val_type is GuardedPwQPolynomial: + return GuardedPwQPolynomial.zero() else: return 0 @@ -342,8 +397,8 @@ class ToCountMap(object): """ - if self.val_type is isl.PwQPolynomial: - total = isl.PwQPolynomial('{ 0 }') + if self.val_type is GuardedPwQPolynomial: + total = GuardedPwQPolynomial.zero() else: total = 0 @@ -929,9 +984,13 @@ class AccessFootprintGatherer(CombineMapper): # {{{ count +def add_assumptions_guard(kernel, pwqpolynomial): + return GuardedPwQPolynomial(pwqpolynomial, kernel.assumptions) + + def count(kernel, set, space=None): try: - return set.card() + return add_assumptions_guard(kernel, set.card()) except AttributeError: pass @@ -1022,7 +1081,7 @@ def count(kernel, set, space=None): "number of integer points in your loop " "domain.") - return count + return add_assumptions_guard(kernel, count) def get_unused_hw_axes_factor(knl, insn, disregard_local_axes, space=None): @@ -1056,9 +1115,11 @@ def get_unused_hw_axes_factor(knl, insn, disregard_local_axes, space=None): return result if disregard_local_axes: - return mult_grid_factor(g_used, gsize) + result = mult_grid_factor(g_used, gsize) else: - return mult_grid_factor(g_used, gsize) * mult_grid_factor(l_used, lsize) + result = mult_grid_factor(g_used, gsize) * mult_grid_factor(l_used, lsize) + + return add_assumptions_guard(knl, result) def count_insn_runs(knl, insn, count_redundant_work, disregard_local_axes=False):