From 4b2287878a88e454e030a2423d4087d9fd5e3434 Mon Sep 17 00:00:00 2001 From: Matthias Diener <mdiener@illinois.edu> Date: Tue, 1 Jun 2021 11:27:38 -0500 Subject: [PATCH] copy norm() from PyOpenCLArrayContext --- arraycontext/impl/pytato.py | 52 ++++++++++++++++++++++++++++++++++--- test/test_arraycontext.py | 2 +- 2 files changed, 50 insertions(+), 4 deletions(-) diff --git a/arraycontext/impl/pytato.py b/arraycontext/impl/pytato.py index c021829..25220e6 100644 --- a/arraycontext/impl/pytato.py +++ b/arraycontext/impl/pytato.py @@ -36,11 +36,57 @@ from pytools.tag import Tag from numbers import Number import loopy as lp +from arraycontext.container import serialize_container, is_array_container -class _PytatoFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): - def norm(self, array, ord=None): - raise NotImplementedError +class _PytatoFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): + def norm(self, ary, ord=None): + from numbers import Number + if isinstance(ary, Number): + return abs(ary) + + if ord is None: + ord = 2 + + try: + from meshmode.dof_array import DOFArray + except ImportError: + pass + else: + if isinstance(ary, DOFArray): + from warnings import warn + warn("Taking an actx.np.linalg.norm of a DOFArray is deprecated. " + "(DOFArrays use 2D arrays internally, and " + "actx.np.linalg.norm should compute matrix norms of those.) " + "This will stop working in 2022. " + "Use meshmode.dof_array.flat_norm instead.", + DeprecationWarning, stacklevel=2) + + import numpy.linalg as la + return la.norm( + [self.norm(_flatten_array(subary), ord=ord) + for _, subary in serialize_container(ary)], + ord=ord) + + if is_array_container(ary): + import numpy.linalg as la + return la.norm( + [self.norm(subary, ord=ord) + for _, subary in serialize_container(ary)], + ord=ord) + + if len(ary.shape) != 1: + raise NotImplementedError("only vector norms are implemented") + + if ary.size == 0: + return 0 + + if ord == np.inf: + return self._array_context.np.max(abs(ary)) + elif isinstance(ord, Number) and ord > 0: + return self._array_context.np.sum(abs(ary)**ord)**(1/ord) + else: + raise NotImplementedError(f"unsupported value of 'ord': {ord}") class _PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace): def _get_fake_numpy_linalg_namespace(self): diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index abd9843..119a9fc 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -212,7 +212,7 @@ def test_actx_stack(actx_factory): ndofs = 5000 args = [np.random.randn(ndofs) for i in range(10)] - assert_close_to_numpy_in_containers( + assert_close_to_numpy( actx, lambda _np, *_args: _np.stack(_args), args) -- GitLab