From dbb086e4fdbf687dd340b8ba4dcffa8ee574d631 Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Mon, 1 Feb 2021 17:08:41 -0600
Subject: [PATCH] tests statistics for callable kernels

---
 loopy/statistics.py     | 10 ++++----
 test/test_statistics.py | 51 +++++++++++++++++++++++++++++++++++++++++
 2 files changed, 57 insertions(+), 4 deletions(-)

diff --git a/loopy/statistics.py b/loopy/statistics.py
index 9257cafc1..34027a5a0 100755
--- a/loopy/statistics.py
+++ b/loopy/statistics.py
@@ -98,6 +98,7 @@ def _get_param_tuple(obj):
 
 class GuardedPwQPolynomial:
     def __init__(self, pwqpolynomial, valid_domain):
+        assert isinstance(pwqpolynomial, isl.PwQPolynomial)
         self.pwqpolynomial = pwqpolynomial
         self.valid_domain = valid_domain
 
@@ -664,10 +665,10 @@ class Op(ImmutableRecord):
     def __repr__(self):
         # Record.__repr__ overridden for consistent ordering and conciseness
         if self.kernel_name is not None:
-            return (f"Op({self.dtype}, {self.name}, {self.count_granularity},"
-                    f" {self.kernel_name})")
+            return (f'Op("{self.dtype}", "{self.name}", "{self.count_granularity}",'
+                    f' "{self.kernel_name}")')
         else:
-            return f"Op({self.dtype}, {self.name}, {self.count_granularity})"
+            return f'Op("{self.dtype}", "{self.name}", "{self.count_granularity}")'
 
 # }}}
 
@@ -1548,7 +1549,8 @@ def get_unused_hw_axes_factor(knl, callables_table, insn,
 def count_inames_domain(knl, inames):
     space = get_kernel_parameter_space(knl)
     if not inames:
-        return get_kernel_zero_pwqpolynomial(knl) + 1
+        return add_assumptions_guard(knl,
+                get_kernel_zero_pwqpolynomial(knl) + 1)
 
     inames_domain = knl.get_inames_domain(inames)
     domain = inames_domain.project_out_except(inames, [dim_type.set])
diff --git a/test/test_statistics.py b/test/test_statistics.py
index 4136f8d06..ca38b9af6 100644
--- a/test/test_statistics.py
+++ b/test/test_statistics.py
@@ -1400,6 +1400,57 @@ def test_strided_footprint():
     assert 2*num < denom
 
 
+def test_stats_on_callable_kernel():
+    callee = lp.make_function(
+            "{[i, j]: 0<=i, j< 20}",
+            """
+            y[i] = sum(j, A[i,j]*x[j])
+            """, name="matvec20x20")
+
+    caller = lp.make_kernel(
+            "{:}",
+            """
+            y[:]  = matvec20x20(A[:,:], x[:])
+            """,
+            [
+                lp.GlobalArg("x,y", shape=(20,), dtype=np.float),
+                lp.GlobalArg("A", shape=(20, 20), dtype=np.float),
+                ],
+            name="matvec")
+    caller = lp.merge([caller, callee])
+
+    op_map = lp.get_op_map(caller, subgroup_size=SGS, count_redundant_work=True,
+                           count_within_subscripts=True)
+    f64_add = op_map.filter_by(name="add").eval_and_sum({})
+    assert f64_add == 400
+
+
+def test_stats_on_callable_kernel_within_loop():
+    callee = lp.make_function(
+            "{[i, j]: 0<=i, j< 20}",
+            """
+            y[i] = sum(j, A[i,j]*x[j])
+            """, name="matvec20x20")
+
+    caller = lp.make_kernel(
+            "{[i]: 0<=i< 20}",
+            """
+            y[i, :]  = matvec20x20(A[:,:], x[i, :])
+            """,
+            [
+                lp.GlobalArg("x,y", shape=(20, 20), dtype=np.float),
+                lp.GlobalArg("A", shape=(20, 20), dtype=np.float),
+                ],
+            name="matmat")
+    caller = lp.merge([caller, callee])
+
+    op_map = lp.get_op_map(caller, subgroup_size=SGS, count_redundant_work=True,
+                           count_within_subscripts=True)
+
+    f64_add = op_map.filter_by(name="add").eval_and_sum({})
+    assert f64_add == 8000
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])
-- 
GitLab