Skip to content
Snippets Groups Projects
Commit f8cf6fcf authored by James Stevens's avatar James Stevens
Browse files

added CountGranularity.ALL to list all granularities

parent c44ebfd0
No related branches found
No related tags found
No related merge requests found
......@@ -478,9 +478,11 @@ class CountGranularity:
once per *work-group*.
"""
WORKITEM = "workitem"
SUBGROUP = "subgroup"
WORKGROUP = "workgroup"
ALL = [WORKITEM, SUBGROUP, WORKGROUP]
# {{{ Op descriptor
......@@ -511,16 +513,11 @@ class Op(Record):
"""
count_granularity_options = [CountGranularity.WORKITEM,
CountGranularity.SUBGROUP,
CountGranularity.WORKGROUP,
None]
def __init__(self, dtype=None, name=None, count_granularity=None):
if count_granularity not in self.count_granularity_options:
raise ValueError("Op.__init__: count_granularity '%s' is"
if count_granularity not in CountGranularity.ALL+[None]:
raise ValueError("Op.__init__: count_granularity '%s' is "
"not allowed. count_granularity options: %s"
% (count_granularity, self.count_granularity_options))
% (count_granularity, CountGranularity.ALL+[None]))
if dtype is None:
Record.__init__(self, dtype=dtype, name=name,
count_granularity=count_granularity)
......@@ -582,11 +579,6 @@ class MemAccess(Record):
"""
count_granularity_options = [CountGranularity.WORKITEM,
CountGranularity.SUBGROUP,
CountGranularity.WORKGROUP,
None]
def __init__(self, mtype=None, dtype=None, stride=None, direction=None,
variable=None, count_granularity=None):
......@@ -600,10 +592,10 @@ class MemAccess(Record):
raise NotImplementedError("MemAccess: variable must be None when "
"mtype is 'local'")
if count_granularity not in self.count_granularity_options:
raise ValueError("Op.__init__: count_granularity '%s' is"
if count_granularity not in CountGranularity.ALL+[None]:
raise ValueError("Op.__init__: count_granularity '%s' is "
"not allowed. count_granularity options: %s"
% (count_granularity, self.count_granularity_options))
% (count_granularity, CountGranularity.ALL+[None]))
if dtype is None:
Record.__init__(self, mtype=mtype, dtype=dtype, stride=stride,
......@@ -1490,7 +1482,7 @@ def get_mem_access_map(knl, numpy_types=True, count_redundant_work=False,
# this should not happen since this is enforced in MemAccess
raise ValueError("get_insn_count: count_granularity '%s' is"
"not allowed. count_granularity options: %s"
% (count_granularity, MemAccess.count_granularity_options))
% (count_granularity, CountGranularity.ALL+[None]))
knl = infer_unknown_types(knl, expect_completion=True)
knl = preprocess_kernel(knl)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment