From 61595cb3a877980c4827b0b7a355d3d69f9a02df Mon Sep 17 00:00:00 2001 From: jdsteve2 <jdsteve2@illinois.edu> Date: Mon, 22 Jan 2018 07:45:40 -0600 Subject: [PATCH] inheriting from record in Op and MemAccess --- loopy/statistics.py | 73 ++++++++++----------------------------------- 1 file changed, 16 insertions(+), 57 deletions(-) diff --git a/loopy/statistics.py b/loopy/statistics.py index 765c75a8f..4987b27df 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -32,6 +32,7 @@ from functools import reduce from loopy.kernel.data import ( MultiAssignmentBase, TemporaryVariable, temp_var_scope) from loopy.diagnostic import warn_with_kernel, LoopyError +from pytools import Record __doc__ = """ @@ -466,7 +467,7 @@ def stringify_stats_mapping(m): # {{{ Op descriptor -class Op(object): +class Op(Record): """A descriptor for a type of arithmetic operation. .. attribute:: dtype @@ -481,26 +482,14 @@ class Op(object): """ - # FIXME: This could be done much more briefly by inheriting from Record. - def __init__(self, dtype=None, name=None, count_granularity=None): - self.name = name - self.count_granularity = count_granularity if dtype is None: - self.dtype = dtype + Record.__init__(self, dtype=dtype, name=name, + count_granularity=count_granularity) else: from loopy.types import to_loopy_type - self.dtype = to_loopy_type(dtype) - - def __eq__(self, other): - return isinstance(other, Op) and ( - (self.dtype is None or other.dtype is None or - self.dtype == other.dtype) and - (self.name is None or other.name is None or - self.name == other.name) and - (self.count_granularity is None or - other.count_granularity is None or - self.count_granularity == other.count_granularity)) + Record.__init__(self, dtype=to_loopy_type(dtype), name=name, + count_granularity=count_granularity) def __hash__(self): return hash(str(self)) @@ -513,7 +502,7 @@ class Op(object): # {{{ MemAccess descriptor -class MemAccess(object): +class MemAccess(Record): """A descriptor for a type of memory access. .. attribute:: mtype @@ -547,17 +536,6 @@ class MemAccess(object): def __init__(self, mtype=None, dtype=None, stride=None, direction=None, variable=None, count_granularity=None): - self.mtype = mtype - self.stride = stride - self.direction = direction - self.variable = variable - self.count_granularity = count_granularity - - if dtype is None: - self.dtype = dtype - else: - from loopy.types import to_loopy_type - self.dtype = to_loopy_type(dtype) #TODO currently giving all lmem access stride=None if (mtype == 'local') and (stride is not None): @@ -569,34 +547,15 @@ class MemAccess(object): raise NotImplementedError("MemAccess: variable must be None when " "mtype is 'local'") - def copy(self, mtype=None, dtype=None, stride=None, direction=None, - variable=None, count_granularity=None): - return MemAccess( - mtype=mtype if mtype is not None else self.mtype, - dtype=dtype if dtype is not None else self.dtype, - stride=stride if stride is not None else self.stride, - direction=direction if direction is not None else self.direction, - variable=variable if variable is not None else self.variable, - count_granularity=count_granularity - if count_granularity is not None - else self.count_granularity) - - def __eq__(self, other): - return isinstance(other, MemAccess) and ( - (self.mtype is None or other.mtype is None or - self.mtype == other.mtype) and - (self.dtype is None or other.dtype is None or - self.dtype == other.dtype) and - (self.stride is None or other.stride is None or - self.stride == other.stride) and - (self.direction is None or other.direction is None or - self.direction == other.direction) and - (self.variable is None or other.variable is None or - self.variable == other.variable) and - (self.count_granularity is None or - other.count_granularity is None or - self.count_granularity == other.count_granularity) - ) + if dtype is None: + Record.__init__(self, mtype=mtype, dtype=dtype, stride=stride, + direction=direction, variable=variable, + count_granularity=count_granularity) + else: + from loopy.types import to_loopy_type + Record.__init__(self, mtype=mtype, dtype=to_loopy_type(dtype), stride=stride, + direction=direction, variable=variable, + count_granularity=count_granularity) def __hash__(self): return hash(str(self)) -- GitLab