From ed2374fda1e00f74bc7c91943f9da1a0da2003f5 Mon Sep 17 00:00:00 2001 From: Matthias Diener <mdiener@illinois.edu> Date: Tue, 1 Jun 2021 16:32:28 -0500 Subject: [PATCH] refactor norm() --- arraycontext/fake_numpy.py | 28 +++++++++++++++++++- arraycontext/impl/pyopencl.py | 26 ++---------------- arraycontext/impl/pytato.py | 50 ++--------------------------------- 3 files changed, 31 insertions(+), 73 deletions(-) diff --git a/arraycontext/fake_numpy.py b/arraycontext/fake_numpy.py index 6b0163d..0657ecb 100644 --- a/arraycontext/fake_numpy.py +++ b/arraycontext/fake_numpy.py @@ -24,7 +24,7 @@ THE SOFTWARE. import numpy as np -from arraycontext.container import is_array_container +from arraycontext.container import is_array_container, serialize_container from arraycontext.container.traversal import ( rec_map_array_container, multimapped_over_array_containers) @@ -174,6 +174,32 @@ class BaseFakeNumpyLinalgNamespace: def __init__(self, array_context): self._array_context = array_context + def norm(self, ary, ord=None): + from numbers import Number + if isinstance(ary, Number): + return abs(ary) + + if is_array_container(ary): + import numpy.linalg as la + return la.norm( + [self.norm(subary, ord=ord) + for _, subary in serialize_container(ary)], + ord=ord) + + if len(ary.shape) != 1: + raise NotImplementedError("only vector norms are implemented") + + if ary.size == 0: + return 0 + + if ord == np.inf: + return self._array_context.np.max(abs(ary)) + elif isinstance(ord, Number) and ord > 0: + return self._array_context.np.sum(abs(ary)**ord)**(1/ord) + else: + raise NotImplementedError(f"unsupported value of 'ord': {ord}") + + # }}} diff --git a/arraycontext/impl/pyopencl.py b/arraycontext/impl/pyopencl.py index 1906509..e50a58d 100644 --- a/arraycontext/impl/pyopencl.py +++ b/arraycontext/impl/pyopencl.py @@ -39,7 +39,7 @@ from arraycontext.metadata import FirstAxisIsElementsTag from arraycontext.fake_numpy import \ BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace from arraycontext.container.traversal import rec_multimap_array_container -from arraycontext.container import serialize_container, is_array_container +from arraycontext.container import serialize_container from arraycontext.context import ArrayContext @@ -165,10 +165,6 @@ def _flatten_array(ary): class _PyOpenCLFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): def norm(self, ary, ord=None): - from numbers import Number - if isinstance(ary, Number): - return abs(ary) - if ord is None: ord = 2 @@ -192,25 +188,7 @@ class _PyOpenCLFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): for _, subary in serialize_container(ary)], ord=ord) - if is_array_container(ary): - import numpy.linalg as la - return la.norm( - [self.norm(subary, ord=ord) - for _, subary in serialize_container(ary)], - ord=ord) - - if len(ary.shape) != 1: - raise NotImplementedError("only vector norms are implemented") - - if ary.size == 0: - return 0 - - if ord == np.inf: - return self._array_context.np.max(abs(ary)) - elif isinstance(ord, Number) and ord > 0: - return self._array_context.np.sum(abs(ary)**ord)**(1/ord) - else: - raise NotImplementedError(f"unsupported value of 'ord': {ord}") + return super().norm(ary, ord) # }}} diff --git a/arraycontext/impl/pytato.py b/arraycontext/impl/pytato.py index 5ecf9cf..67fd554 100644 --- a/arraycontext/impl/pytato.py +++ b/arraycontext/impl/pytato.py @@ -36,57 +36,11 @@ from pytools.tag import Tag from numbers import Number import loopy as lp -from arraycontext.container import serialize_container, is_array_container - class _PytatoFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): def norm(self, ary, ord=None): - from numbers import Number - if isinstance(ary, Number): - return abs(ary) - - if ord is None: - ord = 2 - - try: - from meshmode.dof_array import DOFArray - except ImportError: - pass - else: - if isinstance(ary, DOFArray): - from warnings import warn - warn("Taking an actx.np.linalg.norm of a DOFArray is deprecated. " - "(DOFArrays use 2D arrays internally, and " - "actx.np.linalg.norm should compute matrix norms of those.) " - "This will stop working in 2022. " - "Use meshmode.dof_array.flat_norm instead.", - DeprecationWarning, stacklevel=2) - - import numpy.linalg as la - return la.norm( - [self.norm(_flatten_array(subary), ord=ord) - for _, subary in serialize_container(ary)], - ord=ord) - - if is_array_container(ary): - import numpy.linalg as la - return la.norm( - [self.norm(subary, ord=ord) - for _, subary in serialize_container(ary)], - ord=ord) - - if len(ary.shape) != 1: - raise NotImplementedError("only vector norms are implemented") - - if ary.size == 0: - return 0 - - if ord == np.inf: - return self._array_context.np.max(abs(ary)) - elif isinstance(ord, Number) and ord > 0: - return self._array_context.np.sum(abs(ary)**ord)**(1/ord) - else: - raise NotImplementedError(f"unsupported value of 'ord': {ord}") + # FIXME: handle isinstance(ary, DOFArray) case + return super().norm(ary, ord) class _PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace): -- GitLab