From 79873af89ebac58c4c1fbecb32ce9cfb1a8e8925 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sat, 11 Apr 2015 17:28:04 -0500
Subject: [PATCH] Add comparison operator for meshes

---
 meshmode/mesh/__init__.py | 45 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 45 insertions(+)

diff --git a/meshmode/mesh/__init__.py b/meshmode/mesh/__init__.py
index 6bd4b924..299f9ef8 100644
--- a/meshmode/mesh/__init__.py
+++ b/meshmode/mesh/__init__.py
@@ -78,6 +78,9 @@ class MeshElementGroup(Record):
         The number of dimensions spanned by the element.
         *Not* the ambient dimension, see :attr:`Mesh.ambient_dim`
         for that.
+
+    .. automethod:: __eq__
+    .. automethod:: __ne__
     """
 
     def __init__(self, order, vertex_indices, nodes,
@@ -134,6 +137,19 @@ class MeshElementGroup(Record):
     def nunit_nodes(self):
         return self.unit_nodes.shape[-1]
 
+    def __eq__(self, other):
+        return (
+                type(self) == type(other)
+                and self.order == other.order
+                and np.array_equal(self.vertex_indices, other.vertex_indices)
+                and np.array_equal(self.nodes, other.nodes)
+                and np.array_equal(self.unit_nodes, other.unit_nodes)
+                and self.element_nr_base == other.element_nr_base
+                and self.node_nr_base == other.node_nr_base)
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
 
 class SimplexElementGroup(MeshElementGroup):
     def __init__(self, order, vertex_indices, nodes,
@@ -236,8 +252,21 @@ class ElementConnectivity(Record):
         ``element_id_t []``
 
         See :attr:`neighbors_starts`.
+
+    .. automethod:: __eq__
+    .. automethod:: __ne__
     """
 
+    def __eq__(self, other):
+        return (
+                type(self) == type(other)
+                and np.array_equal(self.neighbors_starts,
+                    other.neighbors_starts)
+                and np.array_equal(self.neighbors, other.neighbors))
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
 
 class Mesh(Record):
     """
@@ -260,6 +289,9 @@ class Mesh(Record):
     .. attribute:: vertex_id_dtype
 
     .. attribute:: element_id_dtype
+
+    .. automethod:: __eq__
+    .. automethod:: __ne__
     """
 
     def __init__(self, vertices, groups, skip_tests=False,
@@ -344,6 +376,19 @@ class Mesh(Record):
 
         return self._element_connectivity
 
+    def __eq__(self, other):
+        return (
+                type(self) == type(other)
+                and np.array_equal(self.vertices, other.vertices)
+                and self.groups == other.groups
+                and self.vertex_id_dtype == other.vertex_id_dtype
+                and self.element_id_dtype == other.element_id_dtype
+                and (self._element_connectivity
+                        == other._element_connectivity))
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
     # Design experience: Try not to add too many global data structures to the
     # mesh. Let the element groups be responsible for that at the mesh level.
     #
-- 
GitLab