From 6dfc346bc735f8165bfdd81b0578042663b0292f Mon Sep 17 00:00:00 2001 From: jdsteve2 Date: Tue, 23 Jan 2018 04:16:58 -0600 Subject: [PATCH] ensuring count_granularity values are valid in Op.__init__ and MemAccess.__init__ --- loopy/statistics.py | 18 ++++++++++++++++-- test/test_statistics.py | 25 +++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/loopy/statistics.py b/loopy/statistics.py index a56be22a3..4dac09c0d 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -487,7 +487,13 @@ class Op(Record): """ + count_granularity_options = ["workitem", "subgroup", "group", None] + def __init__(self, dtype=None, name=None, count_granularity=None): + if not count_granularity in self.count_granularity_options: + raise ValueError("Op.__init__: count_granularity '%s' is" + "not allowed. count_granularity options: %s" + % (count_granularity, self.count_granularity_options)) if dtype is None: Record.__init__(self, dtype=dtype, name=name, count_granularity=count_granularity) @@ -542,6 +548,8 @@ class MemAccess(Record): """ + count_granularity_options = ["workitem", "subgroup", "group", None] + def __init__(self, mtype=None, dtype=None, stride=None, direction=None, variable=None, count_granularity=None): @@ -555,6 +563,11 @@ class MemAccess(Record): raise NotImplementedError("MemAccess: variable must be None when " "mtype is 'local'") + if not count_granularity in self.count_granularity_options: + raise ValueError("Op.__init__: count_granularity '%s' is" + "not allowed. count_granularity options: %s" + % (count_granularity, self.count_granularity_options)) + if dtype is None: Record.__init__(self, mtype=mtype, dtype=dtype, stride=stride, direction=direction, variable=variable, @@ -1371,9 +1384,10 @@ def get_mem_access_map(knl, numpy_types=True, count_redundant_work=False, group_workitems *= s return ct/group_workitems else: + # this should not happen since this is enforced in MemAccess raise ValueError("get_insn_count: count_granularity '%s' is" - "not allowed. count_granularity must be 'group', " - "'subgroup', or 'workitem'." % (count_granularity)) + "not allowed. count_granularity options: %s" + % (count_granularity, MemAccess.count_granularity_options)) knl = infer_unknown_types(knl, expect_completion=True) knl = preprocess_kernel(knl) diff --git a/test/test_statistics.py b/test/test_statistics.py index f8735553f..82f9f0886 100644 --- a/test/test_statistics.py +++ b/test/test_statistics.py @@ -672,6 +672,31 @@ def test_mem_access_counter_consec(): assert f32consec == n*m*ell +def test_count_granularity_val_checks(): + + try: + lp.MemAccess(count_granularity='workitem') + lp.MemAccess(count_granularity='subgroup') + lp.MemAccess(count_granularity='group') + lp.MemAccess(count_granularity=None) + assert True + lp.MemAccess(count_granularity='bushel') + assert False + except ValueError: + assert True + + try: + lp.Op(count_granularity='workitem') + lp.Op(count_granularity='subgroup') + lp.Op(count_granularity='group') + lp.Op(count_granularity=None) + assert True + lp.Op(count_granularity='bushel') + assert False + except ValueError: + assert True + + def test_barrier_counter_nobarriers(): knl = lp.make_kernel( -- GitLab