From 811702ef330bae758d70a89cee5d19aa7a1e4d55 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Wed, 11 Nov 2020 20:28:28 -0600 Subject: [PATCH] Improve awkward DOFArray special-casing in eager, now that it's no longer an ndarray subclass --- grudge/eager.py | 26 +++++++------------------- 1 file changed, 7 insertions(+), 19 deletions(-) diff --git a/grudge/eager.py b/grudge/eager.py index 60fc7bd3..482a37ca 100644 --- a/grudge/eager.py +++ b/grudge/eager.py @@ -28,7 +28,7 @@ import pyopencl.array as cla # noqa from grudge import sym, bind from meshmode.mesh import BTAG_ALL, BTAG_NONE, BTAG_PARTITION # noqa -from meshmode.dof_array import freeze, DOFArray, flatten, unflatten +from meshmode.dof_array import freeze, flatten, unflatten from grudge.discretization import DGDiscretizationWithBoundaries from grudge.symbolic.primitives import TracePair @@ -83,9 +83,7 @@ class EagerDGDiscretization(DGDiscretizationWithBoundaries): :arg tgt: a :class:`~grudge.sym.DOFDesc`, or a value convertible to one :arg vec: a :class:`~meshmode.dof_array.DOFArray` """ - if (isinstance(vec, np.ndarray) - and vec.dtype.char == "O" - and not isinstance(vec, DOFArray)): + if isinstance(vec, np.ndarray): return obj_array_vectorize( lambda el: self.project(src, tgt, el), vec) @@ -270,9 +268,7 @@ class EagerDGDiscretization(DGDiscretizationWithBoundaries): local_only=True) def inverse_mass(self, vec): - if (isinstance(vec, np.ndarray) - and vec.dtype.char == "O" - and not isinstance(vec, DOFArray)): + if isinstance(vec, np.ndarray): return obj_array_vectorize( lambda el: self.inverse_mass(el), vec) @@ -292,9 +288,7 @@ class EagerDGDiscretization(DGDiscretizationWithBoundaries): else: raise TypeError("invalid number of arguments") - if (isinstance(vec, np.ndarray) - and vec.dtype.char == "O" - and not isinstance(vec, DOFArray)): + if isinstance(vec, np.ndarray): return obj_array_vectorize( lambda el: self.face_mass(dd, el), vec) @@ -312,9 +306,7 @@ class EagerDGDiscretization(DGDiscretizationWithBoundaries): dd = sym.as_dofdesc(dd) - if (isinstance(vec, np.ndarray) - and vec.dtype.char == "O" - and not isinstance(vec, DOFArray)): + if isinstance(vec, np.ndarray): if p == 2: return sum( self.norm(vec[idx], dd=dd)**2 @@ -350,9 +342,7 @@ class EagerDGDiscretization(DGDiscretizationWithBoundaries): def interior_trace_pair(discrwb, vec): i = discrwb.project("vol", "int_faces", vec) - if (isinstance(vec, np.ndarray) - and vec.dtype.char == "O" - and not isinstance(vec, DOFArray)): + if isinstance(vec, np.ndarray): e = obj_array_vectorize( lambda el: discrwb.opposite_face_connection()(el), i) @@ -412,9 +402,7 @@ def _cross_rank_trace_pairs_scalar_field(discrwb, vec, tag=None): def cross_rank_trace_pairs(discrwb, vec, tag=None): - if (isinstance(vec, np.ndarray) - and vec.dtype.char == "O" - and not isinstance(vec, DOFArray)): + if isinstance(vec, np.ndarray): n, = vec.shape result = {} -- GitLab