From 447ab9453044e1cbe5898aee778315dcc8c84dd0 Mon Sep 17 00:00:00 2001
From: James Stevens <jdsteve2@illinois.edu>
Date: Thu, 17 Sep 2015 14:20:44 -0500
Subject: [PATCH] added initial naive reg counter

---
 loopy/statistics.py     | 125 ++++++++++++++++++++++++++++++++++++++++
 test/test_statistics.py |  90 +++++++++++++++++++++++++++++
 2 files changed, 215 insertions(+)

diff --git a/loopy/statistics.py b/loopy/statistics.py
index 429a6a211..b9c0d701f 100755
--- a/loopy/statistics.py
+++ b/loopy/statistics.py
@@ -377,6 +377,108 @@ class GlobalSubscriptCounter(CombineMapper):
         raise NotImplementedError("GlobalSubscriptCounter encountered slice, "
                                   "map_slice not implemented.")
 
+class RegisterUsageEstimator(CombineMapper):
+    """Naively estimate register usage of an expression, one per leaf.
+
+    Constants, variables, calls and subscripts of temporaries or global
+    arrays each count as 1; subscripts of non-global arguments count as 0.
+    Every operator node returns the sum of the counts of its operands.
+    """
+
+    def __init__(self, knl):
+        self.knl = knl  # its arg_dict is used to classify subscripted arrays
+        from loopy.expression import TypeInferenceMapper
+        self.type_inf = TypeInferenceMapper(knl)  # unused by the mappers below
+
+    def combine(self, values):
+        return sum(values)
+
+    def map_constant(self, expr):
+        return 1
+
+    map_tagged_variable = map_constant
+    map_variable = map_constant
+    map_call = map_constant  # TODO what is this?
+
+    def map_subscript(self, expr):
+        """Return 1 or 0 for an array access; the index is not visited."""
+        name = expr.aggregate.name  # name of array
+
+        if name not in self.knl.arg_dict:
+            # this is a temporary variable
+            return 1  # TODO +self.rec(expr.index)?
+
+        if isinstance(self.knl.arg_dict[name], lp.GlobalArg):
+            # this is a global mem access
+            return 1  # TODO +self.rec(expr.index)?
+
+        # this array is not in global memory
+        return 0  # TODO is this right? recurse on index?
+
+    def map_sum(self, expr):
+        # sum() of an empty sequence is 0, so childless nodes need no branch
+        return sum(self.rec(child) for child in expr.children)
+
+    map_product = map_sum
+
+    def map_quotient(self, expr, *args):
+        return self.rec(expr.numerator) + self.rec(expr.denominator)
+
+    map_floor_div = map_quotient
+    map_remainder = map_quotient
+
+    def map_power(self, expr):
+        return self.rec(expr.base) + self.rec(expr.exponent)
+
+    def map_left_shift(self, expr):
+        return self.rec(expr.shiftee) + self.rec(expr.shift)
+
+    map_right_shift = map_left_shift
+
+    def map_bitwise_not(self, expr):
+        return self.rec(expr.child)
+
+    map_logical_not = map_bitwise_not
+
+    map_bitwise_or = map_sum  # n-ary nodes share the children-sum rule
+    map_bitwise_xor = map_sum
+    map_bitwise_and = map_sum
+    map_logical_or = map_sum
+    map_logical_and = map_sum
+    map_min = map_sum
+    map_max = map_sum
+
+    def map_comparison(self, expr):
+        return self.rec(expr.left) + self.rec(expr.right)
+
+    def map_if(self, expr):
+        warnings.warn("RegisterUsageEstimator counting register usage as "
+                      "sum of if-statement branches.")
+        return self.rec(expr.condition) + self.rec(expr.then) + self.rec(expr.else_)
+
+    def map_if_positive(self, expr):
+        warnings.warn("RegisterUsageEstimator counting register usage as "
+                      "sum of if_pos-statement branches.")
+        return self.rec(expr.criterion) + self.rec(expr.then) + self.rec(expr.else_)
+
+    def map_common_subexpression(self, expr):
+        raise NotImplementedError("RegisterUsageEstimator encountered "
+                                  "common_subexpression, "
+                                  "map_common_subexpression not implemented.")
+
+    def map_substitution(self, expr):
+        raise NotImplementedError("RegisterUsageEstimator encountered "
+                                  "substitution, "
+                                  "map_substitution not implemented.")
+
+    def map_derivative(self, expr):
+        raise NotImplementedError("RegisterUsageEstimator encountered "
+                                  "derivative, "
+                                  "map_derivative not implemented.")
+
+    def map_slice(self, expr):
+        raise NotImplementedError("RegisterUsageEstimator encountered slice, "
+                                  "map_slice not implemented.")
 
 def count(kernel, bset):
     try:
@@ -596,3 +698,26 @@ def get_barrier_poly(knl):
                 barrier_poly += isl.PwQPolynomial('{ 1 }')
 
     return barrier_poly
+
+
+def get_regs_per_thread(knl):
+    """Naively estimate the number of registers per thread used by *knl*.
+
+    :parameter knl: A :class:`loopy.LoopKernel` whose reg usage will be estimated.
+    :return: the sum, over each ``RunInstruction`` of one schedule of *knl*,
+        of the :class:`RegisterUsageEstimator` counts of the instruction's
+        expression and assignee.
+    """
+    from loopy.preprocess import preprocess_kernel, infer_unknown_types
+    from loopy.schedule import RunInstruction
+    knl = infer_unknown_types(knl, expect_completion=True)
+    knl = preprocess_kernel(knl)
+    knl = lp.get_one_scheduled_kernel(knl)
+    reg_counter = RegisterUsageEstimator(knl)
+    regs = 0
+    for sched_item in knl.schedule:
+        if isinstance(sched_item, RunInstruction):
+            insn = knl.id_to_insn[sched_item.insn_id]
+            # both sides of the instruction contribute to the estimate
+            regs += reg_counter(insn.expression) + reg_counter(insn.assignee)
+    return regs
diff --git a/test/test_statistics.py b/test/test_statistics.py
index a50476119..5426d6312 100644
--- a/test/test_statistics.py
+++ b/test/test_statistics.py
@@ -28,6 +28,7 @@ from pyopencl.tools import (  # noqa
         as pytest_generate_tests)
 import loopy as lp
 from loopy.statistics import get_op_poly, get_gmem_access_poly, get_barrier_poly
+from loopy.statistics import get_regs_per_thread
 import numpy as np
 
 
@@ -510,6 +511,92 @@ def test_barrier_counter_barriers():
     assert barrier_count == 50*10*2
 
 
+def test_reg_counter_basic():
+    """Two plain arithmetic assignments should estimate 8 regs."""
+    knl = lp.make_kernel(
+            "[n,m,l] -> {[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}",
+            [
+                """
+                c[i, j, k] = a[i,j,k]*b[i,j,k]/3.0+a[i,j,k]
+                e[i, k+1] = g[i,k]*h[i,k+1]
+                """
+            ],
+            name="basic", assumptions="n,m,l >= 1")
+
+    dtypes = {"a": np.float32, "b": np.float32, "g": np.float64, "h": np.float64}
+    typed_knl = lp.add_and_infer_dtypes(knl, dtypes)
+    reg_count = get_regs_per_thread(typed_knl)
+    assert reg_count == 8
+
+
+def test_reg_counter_reduction():
+    """A serial matmul written as a sum-reduction should estimate 8 regs."""
+    knl = lp.make_kernel(
+            "{[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}",
+            [
+                "c[i, j] = sum(k, a[i, k]*b[k, j])"
+            ],
+            name="matmul_serial", assumptions="n,m,l >= 1")
+
+    dtypes = {"a": np.float32, "b": np.float32}
+    typed_knl = lp.add_and_infer_dtypes(knl, dtypes)
+    assert get_regs_per_thread(typed_knl) == 8
+
+
+def test_reg_counter_logic():
+    """An if() with a not/and/or condition should estimate 14 regs."""
+    knl = lp.make_kernel(
+            "{[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}",
+            [
+                """
+                e[i,k] = if(not(k<l-2) and k>6 or k/2==l, g[i,k]*2, g[i,k]+h[i,k]/2)
+                """
+            ],
+            name="logic", assumptions="n,m,l >= 1")
+
+    dtypes = {"g": np.float32, "h": np.float64}
+    typed_knl = lp.add_and_infer_dtypes(knl, dtypes)
+    assert get_regs_per_thread(typed_knl) == 14
+
+
+def test_reg_counter_specialops():
+    """Remainder (%) and power (**) expressions should estimate 11 regs."""
+    knl = lp.make_kernel(
+            "{[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}",
+            [
+                """
+                c[i, j, k] = (2*a[i,j,k])%(2+b[i,j,k]/3.0)
+                e[i, k] = (1+g[i,k])**(1+h[i,k+1])
+                """
+            ],
+            name="specialops", assumptions="n,m,l >= 1")
+
+    dtypes = {"a": np.float32, "b": np.float32, "g": np.float64, "h": np.float64}
+    typed_knl = lp.add_and_infer_dtypes(knl, dtypes)
+    reg_count = get_regs_per_thread(typed_knl)
+    assert reg_count == 11
+
+
+def test_reg_counter_bitwise():
+    """Bitwise and shift operators on int arrays should estimate 12 regs."""
+    knl = lp.make_kernel(
+            "{[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}",
+            [
+                """
+                c[i, j, k] = (a[i,j,k] | 1) + (b[i,j,k] & 1)
+                e[i, k] = (g[i,k] ^ k)*(~h[i,k+1]) + (g[i, k] << (h[i,k] >> k))
+                """
+            ],
+            name="bitwise", assumptions="n,m,l >= 1")
+
+    int_dtypes = {
+        "a": np.int32, "b": np.int32,
+        "g": np.int64, "h": np.int64,
+    }
+    typed_knl = lp.add_and_infer_dtypes(knl, int_dtypes)
+    assert get_regs_per_thread(typed_knl) == 12
+
+
 def test_all_counters_parallel_matmul():
 
     knl = lp.make_kernel(
@@ -557,6 +644,9 @@ def test_all_counters_parallel_matmul():
 
     assert f32coal == n*l
 
+    regs = get_regs_per_thread(knl)
+    assert regs == 8
+
 
 if __name__ == "__main__":
     if len(sys.argv) > 1:
-- 
GitLab