diff --git a/loopy/statistics.py b/loopy/statistics.py index a56be22a37d638b5af5f290e90f7ec9f899f94d7..4dac09c0d1a2d88de94e0886488f26771add2dda 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -487,7 +487,13 @@ class Op(Record): """ + count_granularity_options = ["workitem", "subgroup", "group", None] + def __init__(self, dtype=None, name=None, count_granularity=None): + if not count_granularity in self.count_granularity_options: + raise ValueError("Op.__init__: count_granularity '%s' is" + "not allowed. count_granularity options: %s" + % (count_granularity, self.count_granularity_options)) if dtype is None: Record.__init__(self, dtype=dtype, name=name, count_granularity=count_granularity) @@ -542,6 +548,8 @@ class MemAccess(Record): """ + count_granularity_options = ["workitem", "subgroup", "group", None] + def __init__(self, mtype=None, dtype=None, stride=None, direction=None, variable=None, count_granularity=None): @@ -555,6 +563,11 @@ class MemAccess(Record): raise NotImplementedError("MemAccess: variable must be None when " "mtype is 'local'") + if not count_granularity in self.count_granularity_options: + raise ValueError("Op.__init__: count_granularity '%s' is" + "not allowed. count_granularity options: %s" + % (count_granularity, self.count_granularity_options)) + if dtype is None: Record.__init__(self, mtype=mtype, dtype=dtype, stride=stride, direction=direction, variable=variable, @@ -1371,9 +1384,10 @@ def get_mem_access_map(knl, numpy_types=True, count_redundant_work=False, group_workitems *= s return ct/group_workitems else: + # this should not happen since this is enforced in MemAccess raise ValueError("get_insn_count: count_granularity '%s' is" - "not allowed. count_granularity must be 'group', " - "'subgroup', or 'workitem'." % (count_granularity)) + "not allowed. count_granularity options: %s" + % (count_granularity, MemAccess.count_granularity_options)) knl = infer_unknown_types(knl, expect_completion=True) knl = preprocess_kernel(knl) diff --git a/test/test_statistics.py b/test/test_statistics.py index f8735553fc2e6f54717d3abd37871935e1ec4214..82f9f088682c3dc12cfa4b43380311c3bfe44519 100644 --- a/test/test_statistics.py +++ b/test/test_statistics.py @@ -672,6 +672,31 @@ def test_mem_access_counter_consec(): assert f32consec == n*m*ell +def test_count_granularity_val_checks(): + + try: + lp.MemAccess(count_granularity='workitem') + lp.MemAccess(count_granularity='subgroup') + lp.MemAccess(count_granularity='group') + lp.MemAccess(count_granularity=None) + assert True + lp.MemAccess(count_granularity='bushel') + assert False + except ValueError: + assert True + + try: + lp.Op(count_granularity='workitem') + lp.Op(count_granularity='subgroup') + lp.Op(count_granularity='group') + lp.Op(count_granularity=None) + assert True + lp.Op(count_granularity='bushel') + assert False + except ValueError: + assert True + + def test_barrier_counter_nobarriers(): knl = lp.make_kernel(