From 7906d90e31d27e7d2150a1618875cccbd5997b66 Mon Sep 17 00:00:00 2001
From: Machine Owner <owner@debian.lan>
Date: Fri, 24 Apr 2015 09:10:54 -0500
Subject: [PATCH] code cleanup, very small changes

---
 loopy/statistics.py     | 230 +++++++++++++++++++++++++---------------
 test/test_statistics.py | 108 ++++++++++++++++---
 2 files changed, 234 insertions(+), 104 deletions(-)

diff --git a/loopy/statistics.py b/loopy/statistics.py
index b45e62efd..fa446c7fb 100755
--- a/loopy/statistics.py
+++ b/loopy/statistics.py
@@ -24,128 +24,184 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-import numpy as np
-from islpy import dim_type
+import numpy as np  # noqa
 import loopy as lp
 import pyopencl as cl
 import pyopencl.array
+import warnings
+from islpy import dim_type
 from pymbolic.mapper.flop_counter import FlopCounter
+from pymbolic.mapper import CombineMapper
 
-class ExpressionFlopCounter(FlopCounter):
 
-	# ExpressionFlopCounter extends FlopCounter extends CombineMapper extends RecursiveMapper
-	
-	def map_reduction(self, expr, knl):
-		inames_domain = knl.get_inames_domain(frozenset([expr.inames[0]]))
-		domain = (inames_domain.project_out_except(frozenset([expr.inames[0]]), [dim_type.set]))
-		if str(expr.operation) == 'sum' or str(expr.operation) == 'product' :
-			return domain.card()*(1+self.rec(expr.expr))
-		else:
-			from warnings import warn
-			warn("ExpressionFlopCounter counting reduction operation as 0 flops.", stacklevel=2)
-			return domain.card()*(0+self.rec(expr.expr))
+class ExpressionOpCounter(FlopCounter):
 
-	# from pymbolic:
+    # ExpressionOpCounter extends FlopCounter extends CombineMapper extends RecursiveMapper
+    
+    def __init__(self, knl):
+        self.knl = knl
+        from loopy.codegen.expression import TypeInferenceMapper
+        self.type_inf = TypeInferenceMapper(knl)
 
-	def map_tagged_variable(self, expr):
-		return 0
+    def map_tagged_variable(self, expr):
+        return 0
 
-	# def map_variable(self, expr):   # implemented in FlopCounter
+    #def map_variable(self, expr):   # implemented in FlopCounter
+    #    return 0
 
-	def map_wildcard(self, expr):
-		return 0
+    #def map_wildcard(self, expr):
+    #    return 0,0
 
-	def map_function_symbol(self, expr):
-		return 0
+    #def map_function_symbol(self, expr):
+    #    return 0,0
 
-	# def map_call(self, expr):  # implemented in CombineMapper, recurses
-	# def map_call_with_kwargs(self, expr):  # implemented in CombineMapper, recurses
+    def map_call(self, expr):  # implemented in CombineMapper (functions in opencl spec)
+        return 0
 
-	def map_subscript(self, expr):  # implemented in CombineMapper
-		return self.rec(expr.index)
+    # def map_call_with_kwargs(self, expr):  # implemented in CombineMapper
 
-	# def map_lookup(self, expr):  # implemented in CombineMapper, recurses
-	# def map_sum(self, expr)  # implemented in FlopCounter
-	# def map_product(self, expr):  # implemented in FlopCounter
-	# def map_quotient(self, expr):  # implemented in FlopCounter
-	# def map_floor_div(self, expr):  # implemented in FlopCounter
+    def map_subscript(self, expr):  # implemented in CombineMapper
+        return self.rec(expr.index)
 
-	def map_remainder(self, expr):  # implemented in CombineMapper
-		return 0
+    # def map_lookup(self, expr):  # implemented in CombineMapper
 
-	# def map_power(self, expr):  # implemented in FlopCounter, recurses; coming soon
+    # need to worry about data type in these (and others):
+    '''
+    def map_sum(self, expr):  # implemented in FlopCounter
+        return 0
+    def map_product(self, expr):  # implemented in FlopCounter
+        return 0
+    def map_quotient(self, expr):  # implemented in FlopCounter
+        return 0
+    def map_floor_div(self, expr):  # implemented in FlopCounter
+        return 0
+    '''
+    def map_remainder(self, expr):  # implemented in CombineMapper
+        return 1+self.rec(expr.numerator)+self.rec(expr.denominator)
 
-	def map_left_shift(self, expr):  # implemented in CombineMapper, recurses; coming soon
-		return 0
+    def map_power(self, expr):  # implemented in FlopCounter
+        return 1+self.rec(expr.base)+self.rec(expr.exponent)
 
-	def map_right_shift(self, expr):  # implemented in CombineMapper, maps to left_shift; coming soon
-		return 0
+    def map_left_shift(self, expr):  # implemented in CombineMapper
+        return 0+self.rec(expr.shiftee)+self.rec(expr.shift)  #TODO test
 
-	def map_bitwise_not(self, expr):  # implemented in CombineMapper, recurses; coming soon
-		return 0
+    map_right_shift = map_left_shift  #TODO test
 
-	def map_bitwise_or(self, expr):  # implemented in CombineMapper, maps to map_sum; coming soon
-		return 0
+    def map_bitwise_not(self, expr):  # implemented in CombineMapper #TODO test
+        return 0+self.rec(expr.child)  
 
-	def map_bitwise_xor(self, expr):  # implemented in CombineMapper, maps to map_sum; coming soon
-		return 0
+    def map_bitwise_or(self, expr):  # implemented in CombineMapper, maps to map_sum; #TODO test
+        return 0+sum(self.rec(child) for child in expr.children)
 
-	def map_bitwise_and(self, expr):  # implemented in CombineMapper, maps to map_sum; coming soon
-		return 0
+    map_bitwise_xor = map_bitwise_or  # implemented in CombineMapper, maps to map_sum; #TODO test
+    map_bitwise_and = map_bitwise_or  # implemented in CombineMapper, maps to map_sum; #TODO test
 
-	def map_comparison(self, expr):  # implemented in CombineMapper, recurses; coming soon
-		return 0
+    def map_comparison(self, expr):  # implemented in CombineMapper
+        print expr
+        my_type = self.type_inf(expr)
+        print my_type
+        return 0+self.rec(expr.left)+self.rec(expr.right)
 
-	def map_logical_not(self, expr):  # implemented in CombineMapper, maps to bitwise_not; coming soon
-		return 0
+    def map_logical_not(self, expr):  # implemented in CombineMapper, maps to bitwise_not
+        return 0+self.rec(expr.child)
 
-	def map_logical_or(self, expr):  # implemented in CombineMapper, maps to map_sum; coming soon
-		return 0
+    def map_logical_or(self, expr):  # implemented in CombineMapper, maps to map_sum
+        return 0+sum(self.rec(child) for child in expr.children) 
 
-	def map_logical_and(self, expr):  # implemented in CombineMapper, maps to map_sum; coming soon
-		return 0
+    map_logical_and = map_logical_or
 
-	def map_if(self, expr):  # implemented in CombineMapper, recurses; coming soon
-		return 0
+    def map_if(self, expr):  # implemented in CombineMapper, recurses
+        warnings.warn("Counting operations as max of if-statement branches.")
+        return self.rec(expr.condition)+max(self.rec(expr.then), self.rec(expr.else_))
 
-	# def map_if_positive(self, expr):  # implemented in FlopCounter
+    # def map_if_positive(self, expr):  # implemented in FlopCounter
 
-	def map_min(self, expr):  # implemented in CombineMapper, maps to map_sum; coming soon
-		return 0
+    def map_min(self, expr):  # implemented in CombineMapper, maps to map_sum;  #TODO test
+        return 0+sum(self.rec(child) for child in expr.children)
 
-	def map_max(self, expr):  # implemented in CombineMapper, maps to map_sum 
-		return 0
+    map_max = map_min  # implemented in CombineMapper, maps to map_sum;  #TODO test
 
-	def map_common_subexpression(self, expr):
-		print "TESTING-map_common_subexpression: ", expr
-		return 0
 
-	def map_substitution(self, expr):
-		print "TESTING-map_substitution: ", expr
-		return 0
+    def map_common_subexpression(self, expr):
+        raise NotImplementedError("OpCounter encountered common_subexpression, \
+                                   map_common_subexpression not implemented.")
+        return 0
 
-	def map_derivative(self, expr):
-		print "TESTING-map_derivative: ", expr
-		return 0
+    def map_substitution(self, expr):
+        raise NotImplementedError("OpCounter encountered substitution, \
+                                    map_substitution not implemented.")
+        return 0
 
-	def map_slice(self, expr):
-		print "TESTING-map_slice: ", expr
-		return 0
+    def map_derivative(self, expr):
+        raise NotImplementedError("OpCounter encountered derivative, \
+                                    map_derivative not implemented.")
+        return 0
 
+    def map_slice(self, expr):
+        raise NotImplementedError("OpCounter encountered slice, \
+                                    map_slice not implemented.")
+        return 0
 
-# to evaluate poly: poly.eval_with_dict(dictionary)
-def get_flop_poly(knl):
-	poly = 0
-	flopCounter = ExpressionFlopCounter()
-	for insn in knl.instructions:
-		# how many times is this instruction executed?
-		# check domain size:
-		insn_inames = knl.insn_inames(insn) 
-		inames_domain = knl.get_inames_domain(insn_inames)
-		domain = (inames_domain.project_out_except(insn_inames, [dim_type.set]))
-		#flops = flopCounter(insn.expression())
-		flops = flopCounter(insn.expression(),knl)
-		poly += flops*domain.card()
-	return poly
 
+class SubscriptCounter(CombineMapper):
+    def __init__(self, kernel):
+        self.kernel = kernel
+
+    def combine(self, values):
+        return sum(values)
+
+    def map_subscript(self, expr):
+        name = expr.aggregate.name
+        arg = self.kernel.arg_dict.get(name)
+        tv = self.kernel.temporary_variables.get(name)
+        if arg is not None:
+            if isinstance(arg, lp.GlobalArg):
+                # It's global memory
+                pass
+        elif tv is not None:
+            if tv.is_local:
+                # It's shared memory
+                pass
+
+        return 1 + self.rec(expr.index)
+
+    def map_constant(self, expr):
+        return 0
+
+    def map_variable(self, expr):
+        return 0
+
+# to evaluate poly: poly.eval_with_dict(dictionary)
+def get_op_poly(knl):
+
+    from loopy.preprocess import preprocess_kernel, infer_unknown_types
+    knl = infer_unknown_types(knl, expect_completion=True)
+
+    knl = preprocess_kernel(knl)
+    #print knl
+
+    fpoly = 0
+    dpoly = 0
+    op_counter = ExpressionOpCounter(knl)
+    for insn in knl.instructions:
+        # how many times is this instruction executed?
+        # check domain size:
+        insn_inames = knl.insn_inames(insn) 
+        inames_domain = knl.get_inames_domain(insn_inames)
+        domain = (inames_domain.project_out_except(insn_inames, [dim_type.set]))
+        #flops, dops = op_counter(insn.expression)
+        flops = op_counter(insn.expression)
+        fpoly += flops*domain.card()
+        #dpoly += dops*domain.card()
+    return fpoly
+
+def get_DRAM_access_poly(knl): # for now just counting subscripts
+    poly = 0
+    subscript_counter = subscript_counter(knl)
+    for insn in knl.instructions:
+        insn_inames = knl.insn_inames(insn) 
+        inames_domain = knl.get_inames_domain(insn_inames)
+        domain = (inames_domain.project_out_except(insn_inames, [dim_type.set]))
+        poly += subscript_counter(insn.expression) * domain.card()
+    return poly
 
diff --git a/test/test_statistics.py b/test/test_statistics.py
index 0df186502..a1d81735b 100644
--- a/test/test_statistics.py
+++ b/test/test_statistics.py
@@ -26,29 +26,103 @@ import sys
 from pyopencl.tools import (
         pytest_generate_tests_for_pyopencl
         as pytest_generate_tests)
-from pymbolic.mapper.flop_counter import FlopCounter
 from loopy.statistics import *
 
 
-def test_flop_counter_basic(ctx_factory):
+def test_op_counter_basic(ctx_factory):
 
-	knl = lp.make_kernel(
-			"[n,m,l] -> {[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}",
-			[
-			"""
-			c[i, j, k] = a[i,j,k]*b[i,j,k]/3.0+a[i,j,k]
-			e[i, k] = g[i,k]*h[i,k]
-			"""
-			],
-			name="weird", assumptions="n,m,l >= 1")
+    knl = lp.make_kernel(
+            "[n,m,l] -> {[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}",
+            [
+            """
+            c[i, j, k] = a[i,j,k]*b[i,j,k]/3.0+a[i,j,k]
+            e[i, k] = g[i,k]*h[i,k]
+            """
+            ],
+            name="weird", assumptions="n,m,l >= 1")
 
-	poly = get_flop_poly(knl)
-	n=512
-	m=256
-	l=128
-	flops = poly.eval_with_dict({'n':n, 'm':m, 'l':l})
-	assert flops == n*m+3*n*m*l
+    knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32, g=np.float32, h=np.float32))
+    poly = get_op_poly(knl)
+    n=512
+    m=256
+    l=128
+    flops = poly.eval_with_dict({'n':n, 'm':m, 'l':l})
+    assert flops == n*m+3*n*m*l
 
+def test_op_counter_reduction(ctx_factory):
+
+    knl = lp.make_kernel(
+            "{[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}",
+            [
+            "c[i, j] = sum(k, a[i, k]*b[k, j])"
+            ],
+            name="matmul", assumptions="n,m,l >= 1")
+
+    knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32))
+    poly = get_op_poly(knl)
+    n=512
+    m=256
+    l=128
+    flops = poly.eval_with_dict({'n':n, 'm':m, 'l':l})
+    assert flops == 2*n*m*l
+
+def test_op_counter_logic(ctx_factory):
+
+    knl = lp.make_kernel(
+            "[n,m,l] -> {[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}",
+            [
+            """
+            e[i,k] = if(not(k < l-2) and k > l+6 or k/2 == l, g[i,k]*h[i,k], g[i,k]+h[i,k]/2.0)
+            """
+            ],
+            name="logic", assumptions="n,m,l >= 1")
+
+    knl = lp.add_and_infer_dtypes(knl, dict(g=np.float32, h=np.float32))
+    poly = get_op_poly(knl)
+    n=512
+    m=256
+    l=128
+    flops = poly.eval_with_dict({'n':n, 'm':m, 'l':l})
+    assert flops == 5*n*m
+
+def test_op_counter_remainder(ctx_factory):
+
+    knl = lp.make_kernel(
+            "[n,m,l] -> {[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}",
+            [
+            """
+            c[i, j, k] = (2*a[i,j,k])%(2+b[i,j,k]/3.0)
+            """
+            ],
+            name="logic", assumptions="n,m,l >= 1")
+
+    knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32))
+    poly = get_op_poly(knl)
+    n=512
+    m=256
+    l=128
+    flops = poly.eval_with_dict({'n':n, 'm':m, 'l':l})
+    assert flops == 4*n*m*l
+
+def test_op_counter_power(ctx_factory):
+
+    knl = lp.make_kernel(
+            "[n,m,l] -> {[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}",
+            [
+            """
+            c[i, j, k] = a[i,j,k]**3.0
+            e[i, k] = (1+g[i,k])**(1+h[i,k+1])
+            """
+            ],
+            name="weird", assumptions="n,m,l >= 1")
+
+    knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, g=np.float32, h=np.float32))
+    poly = get_op_poly(knl)
+    n=512
+    m=256
+    l=128
+    flops = poly.eval_with_dict({'n':n, 'm':m, 'l':l})
+    assert flops == 4*n*m+n*m*l
 
 if __name__ == "__main__":
     if len(sys.argv) > 1:
-- 
GitLab