From 79873af89ebac58c4c1fbecb32ce9cfb1a8e8925 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Sat, 11 Apr 2015 17:28:04 -0500 Subject: [PATCH] Add comparison operator for meshes --- meshmode/mesh/__init__.py | 45 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/meshmode/mesh/__init__.py b/meshmode/mesh/__init__.py index 6bd4b924..299f9ef8 100644 --- a/meshmode/mesh/__init__.py +++ b/meshmode/mesh/__init__.py @@ -78,6 +78,9 @@ class MeshElementGroup(Record): The number of dimensions spanned by the element. *Not* the ambient dimension, see :attr:`Mesh.ambient_dim` for that. + + .. automethod:: __eq__ + .. automethod:: __ne__ """ def __init__(self, order, vertex_indices, nodes, @@ -134,6 +137,19 @@ class MeshElementGroup(Record): def nunit_nodes(self): return self.unit_nodes.shape[-1] + def __eq__(self, other): + return ( + type(self) == type(other) + and self.order == other.order + and np.array_equal(self.vertex_indices, other.vertex_indices) + and np.array_equal(self.nodes, other.nodes) + and np.array_equal(self.unit_nodes, other.unit_nodes) + and self.element_nr_base == other.element_nr_base + and self.node_nr_base == other.node_nr_base) + + def __ne__(self, other): + return not self.__eq__(other) + class SimplexElementGroup(MeshElementGroup): def __init__(self, order, vertex_indices, nodes, @@ -236,8 +252,21 @@ class ElementConnectivity(Record): ``element_id_t []`` See :attr:`neighbors_starts`. + + .. automethod:: __eq__ + .. automethod:: __ne__ """ + def __eq__(self, other): + return ( + type(self) == type(other) + and np.array_equal(self.neighbors_starts, + other.neighbors_starts) + and np.array_equal(self.neighbors, other.neighbors)) + + def __ne__(self, other): + return not self.__eq__(other) + class Mesh(Record): """ @@ -260,6 +289,9 @@ class Mesh(Record): .. attribute:: vertex_id_dtype .. attribute:: element_id_dtype + + .. automethod:: __eq__ + .. automethod:: __ne__ """ def __init__(self, vertices, groups, skip_tests=False, @@ -344,6 +376,19 @@ class Mesh(Record): return self._element_connectivity + def __eq__(self, other): + return ( + type(self) == type(other) + and np.array_equal(self.vertices, other.vertices) + and self.groups == other.groups + and self.vertex_id_dtype == other.vertex_id_dtype + and self.element_id_dtype == other.element_id_dtype + and (self._element_connectivity + == other._element_connectivity)) + + def __ne__(self, other): + return not self.__eq__(other) + # Design experience: Try not to add too many global data structures to the # mesh. Let the element groups be responsible for that at the mesh level. # -- GitLab