From 9f56f724d53fa303afe580da8c703d6d7b0fd2d7 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 15 Jul 2015 10:26:48 -0500
Subject: [PATCH] Minor tweaks to the stats code

---
 loopy/precompute.py |  2 +-
 loopy/statistics.py | 12 ------------
 test/test_loopy.py  |  5 ++++-
 3 files changed, 5 insertions(+), 14 deletions(-)

diff --git a/loopy/precompute.py b/loopy/precompute.py
index ae973f98c..b1df5f678 100644
--- a/loopy/precompute.py
+++ b/loopy/precompute.py
@@ -572,7 +572,7 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
                 new_var_aff = isl.Aff.var_on_domain(mod_domain.space, dt, dim_idx)
 
                 mod_domain = mod_domain.add_constraint(
-                        isl.Constraint.inequality_from_aff(new_var_aff - saxis_aff))
+                        isl.Constraint.equality_from_aff(new_var_aff - saxis_aff))
 
                 # project out the new one
                 mod_domain = mod_domain.project_out(dt, dim_idx, 1)
diff --git a/loopy/statistics.py b/loopy/statistics.py
index 15b0605ec..2c87b6078 100755
--- a/loopy/statistics.py
+++ b/loopy/statistics.py
@@ -184,22 +184,18 @@ class ExpressionOpCounter(CombineMapper):
         raise NotImplementedError("ExpressionOpCounter encountered "
                                   "common_subexpression, "
                                   "map_common_subexpression not implemented.")
-        return 0
 
     def map_substitution(self, expr):
         raise NotImplementedError("ExpressionOpCounter encountered substitution, "
                                   "map_substitution not implemented.")
-        return 0
 
     def map_derivative(self, expr):
         raise NotImplementedError("ExpressionOpCounter encountered derivative, "
                                   "map_derivative not implemented.")
-        return 0
 
     def map_slice(self, expr):
         raise NotImplementedError("ExpressionOpCounter encountered slice, "
                                   "map_slice not implemented.")
-        return 0
 
 
 class ExpressionSubscriptCounter(CombineMapper):
@@ -252,9 +248,6 @@ class ExpressionSubscriptCounter(CombineMapper):
 
         if not local_id_found:
             # count as uniform access
-            warnings.warn("ExpressionSubscriptCounter did not find "
-                          "local iname tags in expression:\n %s,\n"
-                          "considering these DRAM accesses uniform." % expr)
             return TypeToCountMap(
                     {(self.type_inf(expr), 'uniform'): 1}
                     ) + self.rec(expr.index)
@@ -359,24 +352,20 @@ class ExpressionSubscriptCounter(CombineMapper):
         raise NotImplementedError("ExpressionSubscriptCounter encountered "
                                   "common_subexpression, "
                                   "map_common_subexpression not implemented.")
-        return 0
 
     def map_substitution(self, expr):
         raise NotImplementedError("ExpressionSubscriptCounter encountered "
                                   "substitution, "
                                   "map_substitution not implemented.")
-        return 0
 
     def map_derivative(self, expr):
         raise NotImplementedError("ExpressionSubscriptCounter encountered "
                                   "derivative, "
                                   "map_derivative not implemented.")
-        return 0
 
     def map_slice(self, expr):
         raise NotImplementedError("ExpressionSubscriptCounter encountered slice, "
                                   "map_slice not implemented.")
-        return 0
 
 
 def count(kernel, bset):
@@ -472,4 +461,3 @@ def get_barrier_poly(knl):
                 barrier_poly += isl.PwQPolynomial('{ 1 }')
 
     return barrier_poly
-
diff --git a/test/test_loopy.py b/test/test_loopy.py
index 1fa35101d..17e0cc543 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -353,7 +353,10 @@ def test_stencil(ctx_factory):
         knl = lp.set_loop_priority(knl, ["a_dim_0_outer", "a_dim_1_outer"])
         return knl
 
-    for variant in [variant_1, variant_2]:
+    for variant in [
+            #variant_1,
+            variant_2,
+            ]:
         lp.auto_test_vs_ref(ref_knl, ctx, variant(knl),
                 print_ref_code=False,
                 op_count=[n*n], op_label=["cells"])
-- 
GitLab