From 61595cb3a877980c4827b0b7a355d3d69f9a02df Mon Sep 17 00:00:00 2001
From: jdsteve2 <jdsteve2@illinois.edu>
Date: Mon, 22 Jan 2018 07:45:40 -0600
Subject: [PATCH] inheriting from record in Op and MemAccess

---
 loopy/statistics.py | 73 ++++++++++-----------------------------------
 1 file changed, 16 insertions(+), 57 deletions(-)

diff --git a/loopy/statistics.py b/loopy/statistics.py
index 765c75a8f..4987b27df 100755
--- a/loopy/statistics.py
+++ b/loopy/statistics.py
@@ -32,6 +32,7 @@ from functools import reduce
 from loopy.kernel.data import (
         MultiAssignmentBase, TemporaryVariable, temp_var_scope)
 from loopy.diagnostic import warn_with_kernel, LoopyError
+from pytools import Record
 
 
 __doc__ = """
@@ -466,7 +467,7 @@ def stringify_stats_mapping(m):
 
 # {{{ Op descriptor
 
-class Op(object):
+class Op(Record):
     """A descriptor for a type of arithmetic operation.
 
     .. attribute:: dtype
@@ -481,26 +482,14 @@ class Op(object):
 
     """
 
-    # FIXME: This could be done much more briefly by inheriting from Record.
-
     def __init__(self, dtype=None, name=None, count_granularity=None):
-        self.name = name
-        self.count_granularity = count_granularity
         if dtype is None:
-            self.dtype = dtype
+            Record.__init__(self, dtype=dtype, name=name,
+                            count_granularity=count_granularity)
         else:
             from loopy.types import to_loopy_type
-            self.dtype = to_loopy_type(dtype)
-
-    def __eq__(self, other):
-        return isinstance(other, Op) and (
-                (self.dtype is None or other.dtype is None or
-                 self.dtype == other.dtype) and
-                (self.name is None or other.name is None or
-                 self.name == other.name) and
-                (self.count_granularity is None or
-                 other.count_granularity is None or
-                 self.count_granularity == other.count_granularity))
+            Record.__init__(self, dtype=to_loopy_type(dtype), name=name,
+                            count_granularity=count_granularity)
 
     def __hash__(self):
         return hash(str(self))
@@ -513,7 +502,7 @@ class Op(object):
 
 # {{{ MemAccess descriptor
 
-class MemAccess(object):
+class MemAccess(Record):
     """A descriptor for a type of memory access.
 
     .. attribute:: mtype
@@ -547,17 +536,6 @@ class MemAccess(object):
 
     def __init__(self, mtype=None, dtype=None, stride=None, direction=None,
                  variable=None, count_granularity=None):
-        self.mtype = mtype
-        self.stride = stride
-        self.direction = direction
-        self.variable = variable
-        self.count_granularity = count_granularity
-
-        if dtype is None:
-            self.dtype = dtype
-        else:
-            from loopy.types import to_loopy_type
-            self.dtype = to_loopy_type(dtype)
 
         #TODO currently giving all lmem access stride=None
         if (mtype == 'local') and (stride is not None):
@@ -569,34 +547,15 @@ class MemAccess(object):
             raise NotImplementedError("MemAccess: variable must be None when "
                                       "mtype is 'local'")
 
-    def copy(self, mtype=None, dtype=None, stride=None, direction=None,
-             variable=None, count_granularity=None):
-        return MemAccess(
-                mtype=mtype if mtype is not None else self.mtype,
-                dtype=dtype if dtype is not None else self.dtype,
-                stride=stride if stride is not None else self.stride,
-                direction=direction if direction is not None else self.direction,
-                variable=variable if variable is not None else self.variable,
-                count_granularity=count_granularity
-                if count_granularity is not None
-                else self.count_granularity)
-
-    def __eq__(self, other):
-        return isinstance(other, MemAccess) and (
-                (self.mtype is None or other.mtype is None or
-                 self.mtype == other.mtype) and
-                (self.dtype is None or other.dtype is None or
-                 self.dtype == other.dtype) and
-                (self.stride is None or other.stride is None or
-                 self.stride == other.stride) and
-                (self.direction is None or other.direction is None or
-                 self.direction == other.direction) and
-                (self.variable is None or other.variable is None or
-                 self.variable == other.variable) and
-                (self.count_granularity is None or
-                 other.count_granularity is None or
-                 self.count_granularity == other.count_granularity)
-                )
+        if dtype is None:
+            Record.__init__(self, mtype=mtype, dtype=dtype, stride=stride,
+                            direction=direction, variable=variable,
+                            count_granularity=count_granularity)
+        else:
+            from loopy.types import to_loopy_type
+            Record.__init__(self, mtype=mtype, dtype=to_loopy_type(dtype), stride=stride,
+                            direction=direction, variable=variable,
+                            count_granularity=count_granularity)
 
     def __hash__(self):
         return hash(str(self))
-- 
GitLab