diff --git a/loopy/statistics.py b/loopy/statistics.py
index af9d6d47dc53e2e1ef6d585d368365f5ffa41c57..a0b3a0165b234610f0e78246786d6ef7bb39747c 100755
--- a/loopy/statistics.py
+++ b/loopy/statistics.py
@@ -38,6 +38,7 @@ __doc__ = """
 
 .. currentmodule:: loopy
 
+.. autoclass:: GuardedPwQPolynomial
 .. autoclass:: ToCountMap
 .. autoclass:: Op
 .. autoclass:: MemAccess
@@ -52,6 +53,60 @@ __doc__ = """
 """
 
 
+# {{{ GuardedPwQPolynomial
+
+class GuardedPwQPolynomial(object):
+    def __init__(self, pwqpolynomial, valid_domain):
+        self.pwqpolynomial = pwqpolynomial
+        self.valid_domain = valid_domain
+
+    def __add__(self, other):
+        if isinstance(other, GuardedPwQPolynomial):
+            return GuardedPwQPolynomial(
+                    self.pwqpolynomial + other.pwqpolynomial,
+                    self.valid_domain & other.valid_domain)
+        else:
+            return GuardedPwQPolynomial(
+                    self.pwqpolynomial + other,
+                    self.valid_domain)
+
+    __radd__ = __add__
+
+    def __mul__(self, other):
+        if isinstance(other, GuardedPwQPolynomial):
+            return GuardedPwQPolynomial(
+                    self.pwqpolynomial * other.pwqpolynomial,
+                    self.valid_domain & other.valid_domain)
+        else:
+            return GuardedPwQPolynomial(
+                    self.pwqpolynomial * other,
+                    self.valid_domain)
+
+    __rmul__ = __mul__
+
+    def eval_with_dict(self, value_dict):
+        space = self.pwqpolynomial.space
+        pt = isl.Point.zero(space.params())
+
+        for i in range(space.dim(dim_type.param)):
+            par_name = space.get_dim_name(dim_type.param, i)
+            pt = pt.set_coordinate_val(
+                dim_type.param, i, value_dict[par_name])
+
+        if not (isl.Set.from_point(pt) <= self.valid_domain):
+            raise ValueError("evaluation point outside of domain of "
+                    "definition of piecewise quasipolynomial")
+
+        return self.pwqpolynomial.eval(pt).to_python()
+
+    @staticmethod
+    def zero():
+        p = isl.PwQPolynomial('{ 0 }')
+        return GuardedPwQPolynomial(p, isl.Set.universe(p.domain().space))
+
+# }}}
+
+
 # {{{ ToCountMap
 
 class ToCountMap(object):
@@ -66,7 +121,7 @@ class ToCountMap(object):
 
     """
 
-    def __init__(self, init_dict=None, val_type=isl.PwQPolynomial):
+    def __init__(self, init_dict=None, val_type=GuardedPwQPolynomial):
         if init_dict is None:
             init_dict = {}
         self.count_map = init_dict
@@ -87,7 +142,7 @@ class ToCountMap(object):
         return self
 
     def __mul__(self, other):
-        if isinstance(other, isl.PwQPolynomial):
+        if isinstance(other, GuardedPwQPolynomial):
             return ToCountMap(dict(
                 (index, self.count_map[index]*other)
                 for index in self.keys()))
@@ -103,8 +158,8 @@ class ToCountMap(object):
             return self.count_map[index]
         except KeyError:
             #TODO what is the best way to handle this?
-            if self.val_type is isl.PwQPolynomial:
-                return isl.PwQPolynomial('{ 0 }')
+            if self.val_type is GuardedPwQPolynomial:
+                return GuardedPwQPolynomial.zero()
             else:
                 return 0
 
@@ -342,8 +397,8 @@ class ToCountMap(object):
 
         """
 
-        if self.val_type is isl.PwQPolynomial:
-            total = isl.PwQPolynomial('{ 0 }')
+        if self.val_type is GuardedPwQPolynomial:
+            total = GuardedPwQPolynomial.zero()
         else:
             total = 0
 
@@ -929,9 +984,13 @@ class AccessFootprintGatherer(CombineMapper):
 
 # {{{ count
 
+def add_assumptions_guard(kernel, pwqpolynomial):
+    return GuardedPwQPolynomial(pwqpolynomial, kernel.assumptions)
+
+
 def count(kernel, set, space=None):
     try:
-        return set.card()
+        return add_assumptions_guard(kernel, set.card())
     except AttributeError:
         pass
 
@@ -1022,7 +1081,7 @@ def count(kernel, set, space=None):
                         "number of integer points in your loop "
                         "domain.")
 
-    return count
+    return add_assumptions_guard(kernel, count)
 
 
 def get_unused_hw_axes_factor(knl, insn, disregard_local_axes, space=None):
@@ -1056,9 +1115,11 @@ def get_unused_hw_axes_factor(knl, insn, disregard_local_axes, space=None):
         return result
 
     if disregard_local_axes:
-        return mult_grid_factor(g_used, gsize)
+        result = mult_grid_factor(g_used, gsize)
     else:
-        return mult_grid_factor(g_used, gsize) * mult_grid_factor(l_used, lsize)
+        result = mult_grid_factor(g_used, gsize) * mult_grid_factor(l_used, lsize)
+
+    return add_assumptions_guard(knl, result)
 
 
 def count_insn_runs(knl, insn, count_redundant_work, disregard_local_axes=False):