From 06196fa5eec1a00e639a353865fd2d0815c51124 Mon Sep 17 00:00:00 2001
From: Machine Owner <owner@debian.lan>
Date: Thu, 30 Apr 2015 12:10:19 -0500
Subject: [PATCH] removed hardcoded datatypes from TypeToOpCountMap, updated
 tests

---
 loopy/statistics.py     | 113 ++++++++++++++++--------------------
 test/test_statistics.py | 123 ++++++++++++++++++++++++----------------
 2 files changed, 124 insertions(+), 112 deletions(-)

diff --git a/loopy/statistics.py b/loopy/statistics.py
index 0add67a52..d514e852f 100755
--- a/loopy/statistics.py
+++ b/loopy/statistics.py
@@ -24,7 +24,6 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-import numpy as np
 import loopy as lp
 import warnings
 from islpy import dim_type
@@ -32,46 +31,42 @@ import islpy._isl as isl
 from pymbolic.mapper import CombineMapper
 
 
-class TypedPolyDict:
+class TypeToOpCountMap:
 
-    def __init__(self, i32=0, f32=0, f64=0):
-        self.poly_dict = {
-                            np.dtype(np.int32): i32,
-                            np.dtype(np.float32): f32,
-                            np.dtype(np.float64): f64}
+    def __init__(self):
+        self.dict = {}
 
-    def __add__(self, TPD):
-        return TypedPolyDict(
-            self[np.dtype(np.int32)]+TPD[np.dtype(np.int32)],
-            self[np.dtype(np.float32)]+TPD[np.dtype(np.float32)],
-            self[np.dtype(np.float64)]+TPD[np.dtype(np.float64)])
+    def __add__(self, other):
+        result = TypeToOpCountMap()
+        result.dict = dict(self.dict.items() + other.dict.items()
+                           + [(k, self.dict[k] + other.dict[k])
+                           for k in set(self.dict) & set(other.dict)])
+        return result
 
     def __radd__(self, other):
         if (other != 0):
-            print "ERROR TRYING TO ADD TPD TO NON-ZERO NON-TPD"  # TODO
+            message = "TypeToOpCountMap: Attempted to add TypeToOpCountMap to " + \
+                      str(type(other)) + " " + str(other) + ". TypeToOpCountMap " + \
+                      "may only be added to 0 and other TypeToOpCountMap objects."
+            raise ValueError(message)
             return
         return self
 
     def __mul__(self, other):
         if isinstance(other, isl.PwQPolynomial):
-            return TypedPolyDict(
-                self[np.dtype(np.int32)]*other,
-                self[np.dtype(np.float32)]*other,
-                self[np.dtype(np.float64)]*other)
+            result = TypeToOpCountMap()
+            for index in self.dict.keys():
+                result.dict[index] = self.dict[index]*other
+            return result
         else:
-            # TODO
-            print "ERROR: Cannot multiply TypedPolyDict by type ", type(other)
+            message = "TypeToOpCountMap: Attempted to multiply TypeToOpCountMap by " + \
+                      str(type(other)) + " " + str(other) + "."
+            raise ValueError(message)
 
     __rmul__ = __mul__
 
-    def __getitem__(self, index):
-        return self.poly_dict[index]
-
-    def __setitem__(self, index, value):
-        self.poly_dict[index] = value
-
     def __str__(self):
-        return str(self.poly_dict)
+        return str(self.dict)
 
 
 class ExpressionOpCounter(CombineMapper):
@@ -85,13 +80,13 @@ class ExpressionOpCounter(CombineMapper):
         return sum(values)
 
     def map_constant(self, expr):
-        return TypedPolyDict(0, 0, 0)
+        return TypeToOpCountMap()
 
     def map_tagged_variable(self, expr):
-        return TypedPolyDict(0, 0, 0)
+        return TypeToOpCountMap()
 
-    def map_variable(self, expr):   # implemented in FlopCounter
-        return TypedPolyDict(0, 0, 0)
+    def map_variable(self, expr):
+        return TypeToOpCountMap()
 
     #def map_wildcard(self, expr):
     #    return 0,0
@@ -101,7 +96,7 @@ class ExpressionOpCounter(CombineMapper):
 
     def map_call(self, expr):
         # implemented in CombineMapper (functions in opencl spec)
-        return TypedPolyDict(0, 0, 0)
+        return TypeToOpCountMap()
 
     # def map_call_with_kwargs(self, expr):  # implemented in CombineMapper
 
@@ -110,32 +105,32 @@ class ExpressionOpCounter(CombineMapper):
 
     # def map_lookup(self, expr):  # implemented in CombineMapper
 
-    def map_sum(self, expr):  # implemented in FlopCounter
-        TPD = TypedPolyDict(0, 0, 0)
-        TPD[self.type_inf(expr)] = len(expr.children)-1
+    def map_sum(self, expr):
+        op_count_map = TypeToOpCountMap()
+        op_count_map.dict[self.type_inf(expr)] = len(expr.children)-1
         if expr.children:
-            return TPD + sum(self.rec(child) for child in expr.children)
+            return op_count_map + sum(self.rec(child) for child in expr.children)
         else:
-            return TypedPolyDict(0, 0, 0)
+            return TypeToOpCountMap()
 
     map_product = map_sum
 
     def map_quotient(self, expr, *args):
-        TPD = TypedPolyDict(0, 0, 0)
-        TPD[self.type_inf(expr)] = 1
-        return TPD + self.rec(expr.numerator) + self.rec(expr.denominator)
+        op_count_map = TypeToOpCountMap()
+        op_count_map.dict[self.type_inf(expr)] = 1
+        return op_count_map + self.rec(expr.numerator) + self.rec(expr.denominator)
 
     map_floor_div = map_quotient
 
     def map_remainder(self, expr):  # implemented in CombineMapper
-        TPD = TypedPolyDict(0, 0, 0)
-        TPD[self.type_inf(expr)] = 1
-        return TPD + self.rec(expr.numerator)+self.rec(expr.denominator)
+        op_count_map = TypeToOpCountMap()
+        op_count_map.dict[self.type_inf(expr)] = 1
+        return op_count_map + self.rec(expr.numerator)+self.rec(expr.denominator)
 
-    def map_power(self, expr):  # implemented in FlopCounter
-        TPD = TypedPolyDict(0, 0, 0)
-        TPD[self.type_inf(expr)] = 1
-        return TPD + self.rec(expr.base)+self.rec(expr.exponent)
+    def map_power(self, expr):
+        op_count_map = TypeToOpCountMap()
+        op_count_map.dict[self.type_inf(expr)] = 1
+        return op_count_map + self.rec(expr.base)+self.rec(expr.exponent)
 
     def map_left_shift(self, expr):  # implemented in CombineMapper
         return self.rec(expr.shiftee)+self.rec(expr.shift)  # TODO test
@@ -169,15 +164,10 @@ class ExpressionOpCounter(CombineMapper):
 
     def map_if(self, expr):  # implemented in CombineMapper, recurses
         warnings.warn("Counting operations as sum of if-statement branches.")
-        # return self.rec(expr.condition) + max(
-        #                    self.rec(expr.then), self.rec(expr.else_))
         return self.rec(expr.condition) + self.rec(expr.then) + self.rec(expr.else_)
 
     def map_if_positive(self, expr):  # implemented in FlopCounter
         warnings.warn("Counting operations as sum of if_pos-statement branches.")
-        # return self.rec(expr.criterion) + max(
-        #                                    self.rec(expr.then),
-        #                                    self.rec(expr.else_))
         return self.rec(expr.criterion) + self.rec(expr.then) + self.rec(expr.else_)
 
     def map_min(self, expr):
@@ -187,23 +177,23 @@ class ExpressionOpCounter(CombineMapper):
     map_max = map_min  # implemented in CombineMapper, maps to map_sum;  # TODO test
 
     def map_common_subexpression(self, expr):
-        raise NotImplementedError("OpCounter encountered common_subexpression, \
-                                   map_common_subexpression not implemented.")
+        raise NotImplementedError("OpCounter encountered common_subexpression, "
+                                  "map_common_subexpression not implemented.")
         return 0
 
     def map_substitution(self, expr):
-        raise NotImplementedError("OpCounter encountered substitution, \
-                                    map_substitution not implemented.")
+        raise NotImplementedError("OpCounter encountered substitution, "
+                                  "map_substitution not implemented.")
         return 0
 
     def map_derivative(self, expr):
-        raise NotImplementedError("OpCounter encountered derivative, \
-                                    map_derivative not implemented.")
+        raise NotImplementedError("OpCounter encountered derivative, "
+                                  "map_derivative not implemented.")
         return 0
 
     def map_slice(self, expr):
-        raise NotImplementedError("OpCounter encountered slice, \
-                                    map_slice not implemented.")
+        raise NotImplementedError("OpCounter encountered slice, "
+                                  "map_slice not implemented.")
         return 0
 
 
@@ -226,7 +216,6 @@ class SubscriptCounter(CombineMapper):
             if tv.is_local:
                 # It's shared memory
                 pass
-
         return 1 + self.rec(expr.index)
 
     def map_constant(self, expr):
@@ -238,12 +227,9 @@ class SubscriptCounter(CombineMapper):
 
 # to evaluate poly: poly.eval_with_dict(dictionary)
 def get_op_poly(knl):
-
     from loopy.preprocess import preprocess_kernel, infer_unknown_types
     knl = infer_unknown_types(knl, expect_completion=True)
-
     knl = preprocess_kernel(knl)
-    #print knl
 
     op_poly = 0
     op_counter = ExpressionOpCounter(knl)
@@ -259,6 +245,7 @@ def get_op_poly(knl):
 
 
 def get_DRAM_access_poly(knl):  # for now just counting subscripts
+    raise NotImplementedError("get_DRAM_access_poly not yet implemented.")
     poly = 0
     subscript_counter = SubscriptCounter(knl)
     for insn in knl.instructions:
diff --git a/test/test_statistics.py b/test/test_statistics.py
index a1d81735b..0d3a6f44c 100644
--- a/test/test_statistics.py
+++ b/test/test_statistics.py
@@ -26,7 +26,8 @@ import sys
 from pyopencl.tools import (
         pytest_generate_tests_for_pyopencl
         as pytest_generate_tests)
-from loopy.statistics import *
+from loopy.statistics import *  # noqa
+import numpy as np
 
 
 def test_op_counter_basic(ctx_factory):
@@ -34,95 +35,119 @@ def test_op_counter_basic(ctx_factory):
     knl = lp.make_kernel(
             "[n,m,l] -> {[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}",
             [
-            """
-            c[i, j, k] = a[i,j,k]*b[i,j,k]/3.0+a[i,j,k]
-            e[i, k] = g[i,k]*h[i,k]
-            """
+                """
+                c[i, j, k] = a[i,j,k]*b[i,j,k]/3.0+a[i,j,k]
+                e[i, k] = g[i,k]*h[i,k+1]
+                """
             ],
             name="weird", assumptions="n,m,l >= 1")
 
-    knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32, g=np.float32, h=np.float32))
+    knl = lp.add_and_infer_dtypes(knl,
+                        dict(a=np.float32, b=np.float32, g=np.float64, h=np.float64))
     poly = get_op_poly(knl)
-    n=512
-    m=256
-    l=128
-    flops = poly.eval_with_dict({'n':n, 'm':m, 'l':l})
-    assert flops == n*m+3*n*m*l
+    n = 512
+    m = 256
+    l = 128
+    f32 = poly.dict[np.dtype(np.float32)].eval_with_dict({'n': n, 'm': m, 'l': l})
+    f64 = poly.dict[np.dtype(np.float64)].eval_with_dict({'n': n, 'm': m, 'l': l})
+    i32 = poly.dict[np.dtype(np.int32)].eval_with_dict({'n': n, 'm': m, 'l': l})
+    assert f32 == 3*n*m*l
+    assert f64 == n*m
+    assert i32 == n*m
+
 
 def test_op_counter_reduction(ctx_factory):
 
     knl = lp.make_kernel(
             "{[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}",
             [
-            "c[i, j] = sum(k, a[i, k]*b[k, j])"
+                "c[i, j] = sum(k, a[i, k]*b[k, j])"
             ],
             name="matmul", assumptions="n,m,l >= 1")
 
     knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32))
     poly = get_op_poly(knl)
-    n=512
-    m=256
-    l=128
-    flops = poly.eval_with_dict({'n':n, 'm':m, 'l':l})
-    assert flops == 2*n*m*l
+    n = 512
+    m = 256
+    l = 128
+    f32 = poly.dict[np.dtype(np.float32)].eval_with_dict({'n': n, 'm': m, 'l': l})
+    assert f32 == 2*n*m*l
+
 
 def test_op_counter_logic(ctx_factory):
 
     knl = lp.make_kernel(
             "[n,m,l] -> {[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}",
             [
-            """
-            e[i,k] = if(not(k < l-2) and k > l+6 or k/2 == l, g[i,k]*h[i,k], g[i,k]+h[i,k]/2.0)
-            """
+                """
+                e[i,k] = if(not(k<l-2) and k>6 or k/2==l, g[i,k]*2, g[i,k]+h[i,k]/2)
+                """
             ],
             name="logic", assumptions="n,m,l >= 1")
 
-    knl = lp.add_and_infer_dtypes(knl, dict(g=np.float32, h=np.float32))
+    knl = lp.add_and_infer_dtypes(knl, dict(g=np.float32, h=np.float64))
     poly = get_op_poly(knl)
-    n=512
-    m=256
-    l=128
-    flops = poly.eval_with_dict({'n':n, 'm':m, 'l':l})
-    assert flops == 5*n*m
+    n = 512
+    m = 256
+    l = 128
+    f32 = poly.dict[np.dtype(np.float32)].eval_with_dict({'n': n, 'm': m, 'l': l})
+    f64 = poly.dict[np.dtype(np.float64)].eval_with_dict({'n': n, 'm': m, 'l': l})
+    i32 = poly.dict[np.dtype(np.int32)].eval_with_dict({'n': n, 'm': m, 'l': l})
+    assert f32 == n*m
+    assert f64 == 3*n*m
+    assert i32 == n*m
 
-def test_op_counter_remainder(ctx_factory):
+
+def test_op_counter_specialops(ctx_factory):
 
     knl = lp.make_kernel(
             "[n,m,l] -> {[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}",
             [
-            """
-            c[i, j, k] = (2*a[i,j,k])%(2+b[i,j,k]/3.0)
-            """
+                """
+                c[i, j, k] = (2*a[i,j,k])%(2+b[i,j,k]/3.0)
+                e[i, k] = (1+g[i,k])**(1+h[i,k+1])
+                """
             ],
-            name="logic", assumptions="n,m,l >= 1")
+            name="specialops", assumptions="n,m,l >= 1")
 
-    knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32))
+    knl = lp.add_and_infer_dtypes(knl,
+                        dict(a=np.float32, b=np.float32, g=np.float64, h=np.float64))
     poly = get_op_poly(knl)
-    n=512
-    m=256
-    l=128
-    flops = poly.eval_with_dict({'n':n, 'm':m, 'l':l})
-    assert flops == 4*n*m*l
+    n = 512
+    m = 256
+    l = 128
+    f32 = poly.dict[np.dtype(np.float32)].eval_with_dict({'n': n, 'm': m, 'l': l})
+    f64 = poly.dict[np.dtype(np.float64)].eval_with_dict({'n': n, 'm': m, 'l': l})
+    i32 = poly.dict[np.dtype(np.int32)].eval_with_dict({'n': n, 'm': m, 'l': l})
+    assert f32 == 4*n*m*l
+    assert f64 == 3*n*m
+    assert i32 == n*m
+
 
-def test_op_counter_power(ctx_factory):
+def test_op_counter_bitwise(ctx_factory):
 
     knl = lp.make_kernel(
             "[n,m,l] -> {[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}",
             [
-            """
-            c[i, j, k] = a[i,j,k]**3.0
-            e[i, k] = (1+g[i,k])**(1+h[i,k+1])
-            """
+                """
+                c[i, j, k] = (a[i,j,k] | 1) + (b[i,j,k] & 1)
+                e[i, k] = (g[i,k] ^ k)*(~h[i,k+1])
+                """
             ],
-            name="weird", assumptions="n,m,l >= 1")
+            name="bitwise", assumptions="n,m,l >= 1")
 
-    knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, g=np.float32, h=np.float32))
+    knl = lp.add_and_infer_dtypes(knl,
+                        dict(a=np.float32, b=np.float32, g=np.float64, h=np.float64))
     poly = get_op_poly(knl)
-    n=512
-    m=256
-    l=128
-    flops = poly.eval_with_dict({'n':n, 'm':m, 'l':l})
-    assert flops == 4*n*m+n*m*l
+    n = 512
+    m = 256
+    l = 128
+    '''
+    f32 = poly.dict[np.dtype(np.float32)].eval_with_dict({'n': n, 'm': m, 'l': l})
+    f64 = poly.dict[np.dtype(np.float64)].eval_with_dict({'n': n, 'm': m, 'l': l})
+    i32 = poly.dict[np.dtype(np.int32)].eval_with_dict({'n': n, 'm': m, 'l': l})
+    '''
+    # TODO figure out how these operations should be counted
 
 if __name__ == "__main__":
     if len(sys.argv) > 1:
-- 
GitLab