Skip to content
Snippets Groups Projects
Commit 7f07c08c authored by Ellis Hoag's avatar Ellis Hoag
Browse files

spelling

parent 9d674052
No related branches found
No related tags found
No related merge requests found
......@@ -54,7 +54,8 @@ def partition_mesh(mesh, part_per_element, part_nr):
*part_to_global* is a :class:`numpy.ndarray` mapping element
numbers on *part_mesh* to ones in *mesh*.
"""
assert len(part_per_element) == mesh.nelements, "part_per_element must have shape (mesh.nelements,)"
assert len(part_per_element) == mesh.nelements, (
"part_per_element must have shape (mesh.nelements,)")
# Contains the indices of the elements requested.
queried_elems = np.where(np.array(part_per_element) == part_nr)[0]
......
......@@ -50,7 +50,7 @@ logger = logging.getLogger(__name__)
# {{{ partition_mesh
@pytest.mark.parameterize("mesh_type", ["torus", "box"])
@pytest.mark.parametrize("mesh_type", ["torus", "box"])
def test_partition_mesh(mesh_type):
if mesh_type == "torus":
from meshmode.mesh.generation import generate_torus
......@@ -59,14 +59,14 @@ def test_partition_mesh(mesh_type):
part_per_element = np.array([0, 1, 2, 1, 1, 2, 1, 0])
from meshmode.mesh.processing import partition_mesh
(part_mesh, part_to_global) = partition_mesh(my_mesh, part_per_element, 0)
assert part_mesh.nelements == 2
(part_mesh0, _) = partition_mesh(my_mesh, part_per_element, 0)
(part_mesh1, _) = partition_mesh(my_mesh, part_per_element, 1)
(part_mesh2, _) = partition_mesh(my_mesh, part_per_element, 2)
(part_mesh, part_to_global) = partition_mesh(my_mesh, part_per_element, 1)
assert part_mesh.nelements == 4
assert part_mesh0.nelements == 2
assert part_mesh1.nelements == 4
assert part_mesh2.nelements == 2
(part_mesh, part_to_global) = partition_mesh(my_mesh, part_per_element, 2)
assert part_mesh.nelements == 2
elif mesh_type == "box":
from meshmode.mesh.generation import generate_box_mesh
seg = np.linspace(0, 1, 10)
......@@ -76,7 +76,8 @@ def test_partition_mesh(mesh_type):
adjacency_list = np.zeros((mesh.nelements,), dtype=set)
for elem in range(mesh.nelements):
adjacency_list[elem] = set()
for n in range(mesh.nodal_adjacency.neighbors_starts[elem], mesh.nodal_adjacency.neighbors_starts[elem + 1]):
starts = mesh.nodal_adjacency.neighbors_starts
for n in range(starts[elem], starts[elem + 1]):
adjacency_list[elem].add(mesh.nodal_adjacency.neighbors[n])
from pymetis import part_graph
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment