Skip to content
Snippets Groups Projects
Commit ef786c92 authored by Matthias Diener's avatar Matthias Diener
Browse files

_is_meshmode_dofarray

parent 2f137694
No related branches found
No related tags found
No related merge requests found
......@@ -21,3 +21,12 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
def _is_meshmode_dofarray(x):
try:
from meshmode.dof_array import DOFArray
except ImportError:
return False
else:
return isinstance(x, DOFArray)
......@@ -168,25 +168,22 @@ class _PyOpenCLFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace):
if ord is None:
ord = 2
try:
from meshmode.dof_array import DOFArray
except ImportError:
pass
else:
if isinstance(ary, DOFArray):
from warnings import warn
warn("Taking an actx.np.linalg.norm of a DOFArray is deprecated. "
"(DOFArrays use 2D arrays internally, and "
"actx.np.linalg.norm should compute matrix norms of those.) "
"This will stop working in 2022. "
"Use meshmode.dof_array.flat_norm instead.",
DeprecationWarning, stacklevel=2)
import numpy.linalg as la
return la.norm(
[self.norm(_flatten_array(subary), ord=ord)
for _, subary in serialize_container(ary)],
ord=ord)
from arraycontext.impl import _is_meshmode_dofarray
if _is_meshmode_dofarray(ary):
from warnings import warn
warn("Taking an actx.np.linalg.norm of a DOFArray is deprecated. "
"(DOFArrays use 2D arrays internally, and "
"actx.np.linalg.norm should compute matrix norms of those.) "
"This will stop working in 2022. "
"Use meshmode.dof_array.flat_norm instead.",
DeprecationWarning, stacklevel=2)
import numpy.linalg as la
return la.norm(
[self.norm(_flatten_array(subary), ord=ord)
for _, subary in serialize_container(ary)],
ord=ord)
return super().norm(ary, ord)
......
......@@ -121,7 +121,7 @@ class PytatoCompiledOperator:
def __call__(self, *args):
import pytato as pt
import pyopencl.array as cla
from meshmode.dof_array import DOFArray
from arraycontext.impl import _is_meshmode_dofarray
from pytools.obj_array import flat_obj_array
updated_kwargs = {}
......@@ -149,6 +149,7 @@ class PytatoCompiledOperator:
return input_dict
def from_return_dict_to_obj_array(return_dict):
from meshmode.dof_array import DOFArray
return flat_obj_array([DOFArray.from_list(self.actx,
[self.actx.thaw(return_dict[f"_msh_out_{i}_{j}"])
for j in range(self.output_spec[i])])
......@@ -163,7 +164,7 @@ class PytatoCompiledOperator:
updated_kwargs[arg_name] = cla.to_device(self.actx.queue,
np.array(arg))
elif isinstance(arg, np.ndarray) and all(isinstance(el, DOFArray)
elif isinstance(arg, np.ndarray) and all(_is_meshmode_dofarray(el)
for el in arg):
updated_kwargs.update(from_obj_array_to_input_dict(arg, iarg))
else:
......@@ -270,6 +271,7 @@ class PytatoArrayContext(ArrayContext):
def compile(self, f: Callable[[Any], Any],
inputs_like: Tuple[Union[Number, np.array], ...]) -> Callable[..., Any]:
from pytools.obj_array import flat_obj_array
from arraycontext.impl import _is_meshmode_dofarray
from meshmode.dof_array import DOFArray
import pytato as pt
......@@ -277,7 +279,7 @@ class PytatoArrayContext(ArrayContext):
if isinstance(input_like, np.number):
return pt.make_placeholder(input_like.dtype,
f"_msh_inp_{pos}")
elif isinstance(input_like, np.ndarray) and all(isinstance(e, DOFArray)
elif isinstance(input_like, np.ndarray) and all(_is_meshmode_dofarray(e)
for e in input_like):
return flat_obj_array([DOFArray.from_list(self,
[pt.make_placeholder(grp_ary.shape,
......@@ -303,7 +305,7 @@ class PytatoArrayContext(ArrayContext):
for iel, el in enumerate(inputs_like)])
if not (isinstance(outputs, np.ndarray)
and all(isinstance(e, DOFArray)
and all(_is_meshmode_dofarray(e)
for e in outputs)):
raise TypeError("Can only pass in functions that return numpy"
" array of DOFArrays.")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment