diff --git a/loopy/statistics.py b/loopy/statistics.py index ed21dd0450b6797f9e3c4ded419233ef36b29967..5a5f85f6503a04495dfcf47784ff1af3105c2c23 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -478,9 +478,11 @@ class CountGranularity: once per *work-group*. """ + WORKITEM = "workitem" SUBGROUP = "subgroup" WORKGROUP = "workgroup" + ALL = [WORKITEM, SUBGROUP, WORKGROUP] # {{{ Op descriptor @@ -511,16 +513,11 @@ class Op(Record): """ - count_granularity_options = [CountGranularity.WORKITEM, - CountGranularity.SUBGROUP, - CountGranularity.WORKGROUP, - None] - def __init__(self, dtype=None, name=None, count_granularity=None): - if count_granularity not in self.count_granularity_options: - raise ValueError("Op.__init__: count_granularity '%s' is" + if count_granularity not in CountGranularity.ALL+[None]: + raise ValueError("Op.__init__: count_granularity '%s' is " "not allowed. count_granularity options: %s" - % (count_granularity, self.count_granularity_options)) + % (count_granularity, CountGranularity.ALL+[None])) if dtype is None: Record.__init__(self, dtype=dtype, name=name, count_granularity=count_granularity) @@ -582,11 +579,6 @@ class MemAccess(Record): """ - count_granularity_options = [CountGranularity.WORKITEM, - CountGranularity.SUBGROUP, - CountGranularity.WORKGROUP, - None] - def __init__(self, mtype=None, dtype=None, stride=None, direction=None, variable=None, count_granularity=None): @@ -600,10 +592,10 @@ class MemAccess(Record): raise NotImplementedError("MemAccess: variable must be None when " "mtype is 'local'") - if count_granularity not in self.count_granularity_options: - raise ValueError("Op.__init__: count_granularity '%s' is" + if count_granularity not in CountGranularity.ALL+[None]: + raise ValueError("Op.__init__: count_granularity '%s' is " "not allowed. count_granularity options: %s" - % (count_granularity, self.count_granularity_options)) + % (count_granularity, CountGranularity.ALL+[None])) if dtype is None: Record.__init__(self, mtype=mtype, dtype=dtype, stride=stride, @@ -1490,7 +1482,7 @@ def get_mem_access_map(knl, numpy_types=True, count_redundant_work=False, # this should not happen since this is enforced in MemAccess raise ValueError("get_insn_count: count_granularity '%s' is" "not allowed. count_granularity options: %s" - % (count_granularity, MemAccess.count_granularity_options)) + % (count_granularity, CountGranularity.ALL+[None])) knl = infer_unknown_types(knl, expect_completion=True) knl = preprocess_kernel(knl)