From dc953acd0473d091d01e1c8559a8df606426834c Mon Sep 17 00:00:00 2001
From: James Stevens <jdsteve2@illinois.edu>
Date: Wed, 13 May 2015 02:11:41 -0500
Subject: [PATCH] op counter fixes, tests updated

---
 loopy/statistics.py     | 106 +++++++++++++++-------------------------
 test/test_statistics.py |   8 +--
 2 files changed, 42 insertions(+), 72 deletions(-)

diff --git a/loopy/statistics.py b/loopy/statistics.py
index 6c6a3a8a4..d1108dd2a 100755
--- a/loopy/statistics.py
+++ b/loopy/statistics.py
@@ -33,35 +33,34 @@ from pymbolic.mapper import CombineMapper
 
 class TypeToOpCountMap:
 
-    def __init__(self):
-        self.dict = {}
+    def __init__(self, init_dict=None):
+        if init_dict is None:
+            self.dict = {}
+        else:
+            self.dict = init_dict
 
     def __add__(self, other):
-        result = TypeToOpCountMap()
-        result.dict = dict(self.dict.items() + other.dict.items()
-                           + [(k, self.dict[k] + other.dict[k])
-                           for k in set(self.dict) & set(other.dict)])
-        return result
+        return TypeToOpCountMap(dict(self.dict.items() + other.dict.items()
+                                     + [(k, self.dict[k] + other.dict[k])
+                                     for k in set(self.dict) & set(other.dict)]))
 
     def __radd__(self, other):
         if (other != 0):
-            message = "TypeToOpCountMap: Attempted to add TypeToOpCountMap to " + \
-                      str(type(other)) + " " + str(other) + ". TypeToOpCountMap " + \
-                      "may only be added to 0 and other TypeToOpCountMap objects."
-            raise ValueError(message)
+            raise ValueError("TypeToOpCountMap: Attempted to add TypeToOpCountMap "
+                                "to {} {}. TypeToOpCountMap may only be added to "
+                                "0 and other TypeToOpCountMap objects."
+                                .format(type(other), other))
             return
         return self
 
     def __mul__(self, other):
         if isinstance(other, isl.PwQPolynomial):
-            result = TypeToOpCountMap()
-            for index in self.dict.keys():
-                result.dict[index] = self.dict[index]*other
-            return result
+            return TypeToOpCountMap({index: self.dict[index]*other
+                                     for index in self.dict.keys()})
         else:
-            message = "TypeToOpCountMap: Attempted to multiply TypeToOpCountMap by " + \
-                      str(type(other)) + " " + str(other) + "."
-            raise ValueError(message)
+            raise ValueError("TypeToOpCountMap: Attempted to multiply "
+                                "TypeToOpCountMap by {} {}."
+                                .format(type(other), other))
 
     __rmul__ = __mul__
 
@@ -82,11 +81,8 @@ class ExpressionOpCounter(CombineMapper):
     def map_constant(self, expr):
         return TypeToOpCountMap()
 
-    def map_tagged_variable(self, expr):
-        return TypeToOpCountMap()
-
-    def map_variable(self, expr):
-        return TypeToOpCountMap()
+    map_tagged_variable = map_constant
+    map_variable = map_constant
 
     #def map_wildcard(self, expr):
     #    return 0,0
@@ -94,9 +90,7 @@ class ExpressionOpCounter(CombineMapper):
     #def map_function_symbol(self, expr):
     #    return 0,0
 
-    def map_call(self, expr):
-        # implemented in CombineMapper (functions in opencl spec)
-        return TypeToOpCountMap()
+    map_call = map_constant
 
     # def map_call_with_kwargs(self, expr):  # implemented in CombineMapper
 
@@ -106,60 +100,51 @@ class ExpressionOpCounter(CombineMapper):
     # def map_lookup(self, expr):  # implemented in CombineMapper
 
     def map_sum(self, expr):
-        op_count_map = TypeToOpCountMap()
-        op_count_map.dict[self.type_inf(expr)] = len(expr.children)-1
         if expr.children:
-            return op_count_map + sum(self.rec(child) for child in expr.children)
+            return TypeToOpCountMap(
+                        {self.type_inf(expr): len(expr.children)-1}
+                        ) + sum(self.rec(child) for child in expr.children)
         else:
             return TypeToOpCountMap()
 
     map_product = map_sum
 
     def map_quotient(self, expr, *args):
-        op_count_map = TypeToOpCountMap()
-        op_count_map.dict[self.type_inf(expr)] = 1
-        return op_count_map + self.rec(expr.numerator) + self.rec(expr.denominator)
+        return TypeToOpCountMap({self.type_inf(expr): 1}) \
+                                + self.rec(expr.numerator) \
+                                + self.rec(expr.denominator)
 
     map_floor_div = map_quotient
-
-    def map_remainder(self, expr):  # implemented in CombineMapper
-        op_count_map = TypeToOpCountMap()
-        op_count_map.dict[self.type_inf(expr)] = 1
-        return op_count_map + self.rec(expr.numerator)+self.rec(expr.denominator)
+    map_remainder = map_quotient  # implemented in CombineMapper
 
     def map_power(self, expr):
-        op_count_map = TypeToOpCountMap()
-        op_count_map.dict[self.type_inf(expr)] = 1
-        return op_count_map + self.rec(expr.base)+self.rec(expr.exponent)
+        return TypeToOpCountMap({self.type_inf(expr): 1}) \
+                                + self.rec(expr.base) \
+                                + self.rec(expr.exponent)
 
     def map_left_shift(self, expr):  # implemented in CombineMapper
-        return self.rec(expr.shiftee)+self.rec(expr.shift)  # TODO test
+        return self.rec(expr.shiftee)+self.rec(expr.shift)
 
-    map_right_shift = map_left_shift  # TODO test
+    map_right_shift = map_left_shift
 
-    def map_bitwise_not(self, expr):  # implemented in CombineMapper # TODO test
+    def map_bitwise_not(self, expr):  # implemented in CombineMapper
         return self.rec(expr.child)
 
     def map_bitwise_or(self, expr):
-        # implemented in CombineMapper, maps to map_sum; # TODO test
+        # implemented in CombineMapper, maps to map_sum;
         return sum(self.rec(child) for child in expr.children)
 
     map_bitwise_xor = map_bitwise_or
-    # implemented in CombineMapper, maps to map_sum; # TODO test
+    # implemented in CombineMapper, maps to map_sum;
 
     map_bitwise_and = map_bitwise_or
-    # implemented in CombineMapper, maps to map_sum; # TODO test
+    # implemented in CombineMapper, maps to map_sum;
 
     def map_comparison(self, expr):  # implemented in CombineMapper
         return self.rec(expr.left)+self.rec(expr.right)
 
-    def map_logical_not(self, expr):
-        # implemented in CombineMapper, maps to bitwise_not
-        return self.rec(expr.child)
-
-    def map_logical_or(self, expr):  # implemented in CombineMapper, maps to map_sum
-        return sum(self.rec(child) for child in expr.children)
-
+    map_logical_not = map_bitwise_not
+    map_logical_or = map_bitwise_or  # implemented in CombineMapper, maps to map_sum
     map_logical_and = map_logical_or
 
     def map_if(self, expr):  # implemented in CombineMapper, recurses
@@ -170,9 +155,8 @@ class ExpressionOpCounter(CombineMapper):
         warnings.warn("Counting operations as sum of if_pos-statement branches.")
         return self.rec(expr.criterion) + self.rec(expr.then) + self.rec(expr.else_)
 
-    def map_min(self, expr):
-        # implemented in CombineMapper, maps to map_sum;  # TODO test
-        return sum(self.rec(child) for child in expr.children)
+    map_min = map_bitwise_or
+    # implemented in CombineMapper, maps to map_sum;  # TODO test
 
     map_max = map_min  # implemented in CombineMapper, maps to map_sum;  # TODO test
 
@@ -224,16 +208,6 @@ class SubscriptCounter(CombineMapper):
     def map_variable(self, expr):
         return 0
 
-#TODO find stride looking in ArrayBase.dim tag
-'''
-for each instruction, find which iname is associated with local id0 (iname_to_tag)
-then for each array axis in that instruction, run through all axes and see if local id0 iname occurs
-for each axis where this occurs, see if stride=1 (using coefficient collecter)
-
-variable has dimTags (one for each axis), 
-localid 0 is threadidx.x
-'''
-
 
 # to evaluate poly: poly.eval_with_dict(dictionary)
 def get_op_poly(knl):
diff --git a/test/test_statistics.py b/test/test_statistics.py
index 84f6671b4..54d08a498 100644
--- a/test/test_statistics.py
+++ b/test/test_statistics.py
@@ -137,17 +137,13 @@ def test_op_counter_bitwise(ctx_factory):
             name="bitwise", assumptions="n,m,l >= 1")
 
     knl = lp.add_and_infer_dtypes(knl,
-                        dict(a=np.float32, b=np.float32, g=np.float64, h=np.float64))
+                        dict(a=np.int32, b=np.int32, g=np.int64, h=np.int64))
     poly = get_op_poly(knl)
     n = 512
     m = 256
     l = 128
-    '''
-    f32 = poly.dict[np.dtype(np.float32)].eval_with_dict({'n': n, 'm': m, 'l': l})
-    f64 = poly.dict[np.dtype(np.float64)].eval_with_dict({'n': n, 'm': m, 'l': l})
     i32 = poly.dict[np.dtype(np.int32)].eval_with_dict({'n': n, 'm': m, 'l': l})
-    '''
-    # TODO figure out how these operations should be counted
+    assert i32 == 3*n*m+n*m*l
 
 if __name__ == "__main__":
     if len(sys.argv) > 1:
-- 
GitLab