Skip to content
Snippets Groups Projects
Commit ed2374fd authored by Matthias Diener's avatar Matthias Diener
Browse files

refactor norm()

parent 295741d8
No related branches found
No related tags found
No related merge requests found
......@@ -24,7 +24,7 @@ THE SOFTWARE.
import numpy as np
from arraycontext.container import is_array_container
from arraycontext.container import is_array_container, serialize_container
from arraycontext.container.traversal import (
rec_map_array_container, multimapped_over_array_containers)
......@@ -174,6 +174,32 @@ class BaseFakeNumpyLinalgNamespace:
def __init__(self, array_context):
self._array_context = array_context
def norm(self, ary, ord=None):
from numbers import Number
if isinstance(ary, Number):
return abs(ary)
if is_array_container(ary):
import numpy.linalg as la
return la.norm(
[self.norm(subary, ord=ord)
for _, subary in serialize_container(ary)],
ord=ord)
if len(ary.shape) != 1:
raise NotImplementedError("only vector norms are implemented")
if ary.size == 0:
return 0
if ord == np.inf:
return self._array_context.np.max(abs(ary))
elif isinstance(ord, Number) and ord > 0:
return self._array_context.np.sum(abs(ary)**ord)**(1/ord)
else:
raise NotImplementedError(f"unsupported value of 'ord': {ord}")
# }}}
......
......@@ -39,7 +39,7 @@ from arraycontext.metadata import FirstAxisIsElementsTag
from arraycontext.fake_numpy import \
BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace
from arraycontext.container.traversal import rec_multimap_array_container
from arraycontext.container import serialize_container, is_array_container
from arraycontext.container import serialize_container
from arraycontext.context import ArrayContext
......@@ -165,10 +165,6 @@ def _flatten_array(ary):
class _PyOpenCLFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace):
def norm(self, ary, ord=None):
from numbers import Number
if isinstance(ary, Number):
return abs(ary)
if ord is None:
ord = 2
......@@ -192,25 +188,7 @@ class _PyOpenCLFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace):
for _, subary in serialize_container(ary)],
ord=ord)
if is_array_container(ary):
import numpy.linalg as la
return la.norm(
[self.norm(subary, ord=ord)
for _, subary in serialize_container(ary)],
ord=ord)
if len(ary.shape) != 1:
raise NotImplementedError("only vector norms are implemented")
if ary.size == 0:
return 0
if ord == np.inf:
return self._array_context.np.max(abs(ary))
elif isinstance(ord, Number) and ord > 0:
return self._array_context.np.sum(abs(ary)**ord)**(1/ord)
else:
raise NotImplementedError(f"unsupported value of 'ord': {ord}")
return super().norm(ary, ord)
# }}}
......
......@@ -36,57 +36,11 @@ from pytools.tag import Tag
from numbers import Number
import loopy as lp
from arraycontext.container import serialize_container, is_array_container
class _PytatoFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace):
def norm(self, ary, ord=None):
from numbers import Number
if isinstance(ary, Number):
return abs(ary)
if ord is None:
ord = 2
try:
from meshmode.dof_array import DOFArray
except ImportError:
pass
else:
if isinstance(ary, DOFArray):
from warnings import warn
warn("Taking an actx.np.linalg.norm of a DOFArray is deprecated. "
"(DOFArrays use 2D arrays internally, and "
"actx.np.linalg.norm should compute matrix norms of those.) "
"This will stop working in 2022. "
"Use meshmode.dof_array.flat_norm instead.",
DeprecationWarning, stacklevel=2)
import numpy.linalg as la
return la.norm(
[self.norm(_flatten_array(subary), ord=ord)
for _, subary in serialize_container(ary)],
ord=ord)
if is_array_container(ary):
import numpy.linalg as la
return la.norm(
[self.norm(subary, ord=ord)
for _, subary in serialize_container(ary)],
ord=ord)
if len(ary.shape) != 1:
raise NotImplementedError("only vector norms are implemented")
if ary.size == 0:
return 0
if ord == np.inf:
return self._array_context.np.max(abs(ary))
elif isinstance(ord, Number) and ord > 0:
return self._array_context.np.sum(abs(ary)**ord)**(1/ord)
else:
raise NotImplementedError(f"unsupported value of 'ord': {ord}")
# FIXME: handle isinstance(ary, DOFArray) case
return super().norm(ary, ord)
class _PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment