From f8cf6fcf8025e4412f2327c4d7ece9b055734ffe Mon Sep 17 00:00:00 2001 From: jdsteve2 <jdsteve2@illinois.edu> Date: Tue, 20 Feb 2018 22:50:40 -0600 Subject: [PATCH] added CountGranularity.ALL to list all granularities --- loopy/statistics.py | 26 +++++++++----------------- 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/loopy/statistics.py b/loopy/statistics.py index ed21dd045..5a5f85f65 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -478,9 +478,11 @@ class CountGranularity: once per *work-group*. """ + WORKITEM = "workitem" SUBGROUP = "subgroup" WORKGROUP = "workgroup" + ALL = [WORKITEM, SUBGROUP, WORKGROUP] # {{{ Op descriptor @@ -511,16 +513,11 @@ class Op(Record): """ - count_granularity_options = [CountGranularity.WORKITEM, - CountGranularity.SUBGROUP, - CountGranularity.WORKGROUP, - None] - def __init__(self, dtype=None, name=None, count_granularity=None): - if count_granularity not in self.count_granularity_options: - raise ValueError("Op.__init__: count_granularity '%s' is" + if count_granularity not in CountGranularity.ALL+[None]: + raise ValueError("Op.__init__: count_granularity '%s' is " "not allowed. count_granularity options: %s" - % (count_granularity, self.count_granularity_options)) + % (count_granularity, CountGranularity.ALL+[None])) if dtype is None: Record.__init__(self, dtype=dtype, name=name, count_granularity=count_granularity) @@ -582,11 +579,6 @@ class MemAccess(Record): """ - count_granularity_options = [CountGranularity.WORKITEM, - CountGranularity.SUBGROUP, - CountGranularity.WORKGROUP, - None] - def __init__(self, mtype=None, dtype=None, stride=None, direction=None, variable=None, count_granularity=None): @@ -600,10 +592,10 @@ class MemAccess(Record): raise NotImplementedError("MemAccess: variable must be None when " "mtype is 'local'") - if count_granularity not in self.count_granularity_options: - raise ValueError("Op.__init__: count_granularity '%s' is" + if count_granularity not in CountGranularity.ALL+[None]: + raise ValueError("Op.__init__: count_granularity '%s' is " "not allowed. count_granularity options: %s" - % (count_granularity, self.count_granularity_options)) + % (count_granularity, CountGranularity.ALL+[None])) if dtype is None: Record.__init__(self, mtype=mtype, dtype=dtype, stride=stride, @@ -1490,7 +1482,7 @@ def get_mem_access_map(knl, numpy_types=True, count_redundant_work=False, # this should not happen since this is enforced in MemAccess raise ValueError("get_insn_count: count_granularity '%s' is" "not allowed. count_granularity options: %s" - % (count_granularity, MemAccess.count_granularity_options)) + % (count_granularity, CountGranularity.ALL+[None])) knl = infer_unknown_types(knl, expect_completion=True) knl = preprocess_kernel(knl) -- GitLab