From f8cf6fcf8025e4412f2327c4d7ece9b055734ffe Mon Sep 17 00:00:00 2001
From: jdsteve2 <jdsteve2@illinois.edu>
Date: Tue, 20 Feb 2018 22:50:40 -0600
Subject: [PATCH] added CountGranularity.ALL to list all granularities

---
 loopy/statistics.py | 26 +++++++++-----------------
 1 file changed, 9 insertions(+), 17 deletions(-)

diff --git a/loopy/statistics.py b/loopy/statistics.py
index ed21dd045..5a5f85f65 100755
--- a/loopy/statistics.py
+++ b/loopy/statistics.py
@@ -478,9 +478,11 @@ class CountGranularity:
        once per *work-group*.
 
     """
+
     WORKITEM = "workitem"
     SUBGROUP = "subgroup"
     WORKGROUP = "workgroup"
+    ALL = [WORKITEM, SUBGROUP, WORKGROUP]
 
 
 # {{{ Op descriptor
@@ -511,16 +513,11 @@ class Op(Record):
 
     """
 
-    count_granularity_options = [CountGranularity.WORKITEM,
-                                 CountGranularity.SUBGROUP,
-                                 CountGranularity.WORKGROUP,
-                                 None]
-
     def __init__(self, dtype=None, name=None, count_granularity=None):
-        if count_granularity not in self.count_granularity_options:
-            raise ValueError("Op.__init__: count_granularity '%s' is"
+        if count_granularity not in CountGranularity.ALL+[None]:
+            raise ValueError("Op.__init__: count_granularity '%s' is "
                     "not allowed. count_granularity options: %s"
-                    % (count_granularity, self.count_granularity_options))
+                    % (count_granularity, CountGranularity.ALL+[None]))
         if dtype is None:
             Record.__init__(self, dtype=dtype, name=name,
                             count_granularity=count_granularity)
@@ -582,11 +579,6 @@ class MemAccess(Record):
 
     """
 
-    count_granularity_options = [CountGranularity.WORKITEM,
-                                 CountGranularity.SUBGROUP,
-                                 CountGranularity.WORKGROUP,
-                                 None]
-
     def __init__(self, mtype=None, dtype=None, stride=None, direction=None,
                  variable=None, count_granularity=None):
 
@@ -600,10 +592,10 @@ class MemAccess(Record):
             raise NotImplementedError("MemAccess: variable must be None when "
                                       "mtype is 'local'")
 
-        if count_granularity not in self.count_granularity_options:
-            raise ValueError("Op.__init__: count_granularity '%s' is"
+        if count_granularity not in CountGranularity.ALL+[None]:
+            raise ValueError("Op.__init__: count_granularity '%s' is "
                     "not allowed. count_granularity options: %s"
-                    % (count_granularity, self.count_granularity_options))
+                    % (count_granularity, CountGranularity.ALL+[None]))
 
         if dtype is None:
             Record.__init__(self, mtype=mtype, dtype=dtype, stride=stride,
@@ -1490,7 +1482,7 @@ def get_mem_access_map(knl, numpy_types=True, count_redundant_work=False,
             # this should not happen since this is enforced in MemAccess
             raise ValueError("get_insn_count: count_granularity '%s' is"
                     "not allowed. count_granularity options: %s"
-                    % (count_granularity, MemAccess.count_granularity_options))
+                    % (count_granularity, CountGranularity.ALL+[None]))
 
     knl = infer_unknown_types(knl, expect_completion=True)
     knl = preprocess_kernel(knl)
-- 
GitLab