From a90e8ba148a21ad8b1913416b69c8d73beafdf71 Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Sun, 17 Oct 2021 10:37:32 -0500
Subject: [PATCH] do not force host transfers when computing norms

---
Reviewer notes (text between the "---" separator and the diff is not
part of the commit):

  * Line structure and indentation were reconstructed from a
    whitespace-collapsed copy of this patch. Hunk line counts match the
    "@@" headers, but verify the context-line indentation against
    arraycontext/fake_numpy.py before applying.
  * loopy_implemented_elwise_func: when every argument is a
    numbers.Number, the call is dispatched to the NumPy function of the
    same name and returns before a loopy program is requested.
  * _reduce_norm: combines the per-component norms of an array container
    with actx.np operations (sqrt, maximum, minimum); ord=None is
    treated as 2. functools.reduce is called without an initial value,
    so an empty list raises TypeError for ord = +/-inf.
  * norm: an empty array now yields ary.dtype.type(0) instead of the
    Python int 0, and ord == 2 gets a dedicated
    sqrt(sum(abs(ary)**2)) branch.

 arraycontext/fake_numpy.py | 36 +++++++++++++++++++++++++++++++-----
 1 file changed, 31 insertions(+), 5 deletions(-)

diff --git a/arraycontext/fake_numpy.py b/arraycontext/fake_numpy.py
index cdb9534..4f7c177 100644
--- a/arraycontext/fake_numpy.py
+++ b/arraycontext/fake_numpy.py
@@ -147,6 +147,10 @@ class BaseFakeNumpyNamespace:
 
     def __getattr__(self, name):
         def loopy_implemented_elwise_func(*args):
+            from numbers import Number
+            if all(isinstance(ary, Number) for ary in args):
+                return getattr(np, name)(*args)
+
             actx = self._array_context
             prg = _get_scalar_func_loopy_program(actx,
                     c_name, nargs=len(args), naxes=len(args[0].shape))
@@ -221,6 +225,26 @@ def _scalar_list_norm(ary, ord):
         raise NotImplementedError(f"unsupported value of 'ord': {ord}")
 
 
+def _reduce_norm(actx, arys, ord):
+    from numbers import Number
+    from functools import reduce
+
+    if ord is None:
+        ord = 2
+
+    # NOTE: these are ordered by an expected usage frequency
+    if ord == 2:
+        return actx.np.sqrt(sum(subary*subary for subary in arys))
+    elif ord == np.inf:
+        return reduce(actx.np.maximum, arys)
+    elif ord == -np.inf:
+        return reduce(actx.np.minimum, arys)
+    elif isinstance(ord, Number) and ord > 0:
+        return sum(subary**ord for subary in arys)**(1/ord)
+    else:
+        raise NotImplementedError(f"unsupported value of 'ord': {ord}")
+
+
 class BaseFakeNumpyLinalgNamespace:
     def __init__(self, array_context):
         self._array_context = array_context
@@ -250,7 +274,7 @@ class BaseFakeNumpyLinalgNamespace:
                 return flat_norm(ary, ord=ord)
 
         if is_array_container(ary):
-            return _scalar_list_norm([
+            return _reduce_norm(actx, [
                 self.norm(subary, ord=ord)
                 for _, subary in serialize_container(ary)
                 ], ord=ord)
@@ -262,14 +286,16 @@ class BaseFakeNumpyLinalgNamespace:
             raise NotImplementedError("only vector norms are implemented")
 
         if ary.size == 0:
-            return 0
+            return ary.dtype.type(0)
 
+        if ord == 2:
+            return actx.np.sqrt(actx.np.sum(abs(ary)**2))
         if ord == np.inf:
-            return self._array_context.np.max(abs(ary))
+            return actx.np.max(abs(ary))
         elif ord == -np.inf:
-            return self._array_context.np.min(abs(ary))
+            return actx.np.min(abs(ary))
         elif isinstance(ord, Number) and ord > 0:
-            return self._array_context.np.sum(abs(ary)**ord)**(1/ord)
+            return actx.np.sum(abs(ary)**ord)**(1/ord)
         else:
             raise NotImplementedError(f"unsupported value of 'ord': {ord}")
 # }}}
-- 
GitLab