diff --git a/arraycontext/fake_numpy.py b/arraycontext/fake_numpy.py index cdb95348c6ca912bb39b01428aa7a0a96ecbfdb2..4f7c177104ca32555dc6184125e384dc4b3d586e 100644 --- a/arraycontext/fake_numpy.py +++ b/arraycontext/fake_numpy.py @@ -147,6 +147,10 @@ class BaseFakeNumpyNamespace: def __getattr__(self, name): def loopy_implemented_elwise_func(*args): + from numbers import Number + if all(isinstance(ary, Number) for ary in args): + return getattr(np, name)(*args) + actx = self._array_context prg = _get_scalar_func_loopy_program(actx, c_name, nargs=len(args), naxes=len(args[0].shape)) @@ -221,6 +225,26 @@ def _scalar_list_norm(ary, ord): raise NotImplementedError(f"unsupported value of 'ord': {ord}") +def _reduce_norm(actx, arys, ord): + from numbers import Number + from functools import reduce + + if ord is None: + ord = 2 + + # NOTE: these are ordered by an expected usage frequency + if ord == 2: + return actx.np.sqrt(sum(subary*subary for subary in arys)) + elif ord == np.inf: + return reduce(actx.np.maximum, arys) + elif ord == -np.inf: + return reduce(actx.np.minimum, arys) + elif isinstance(ord, Number) and ord > 0: + return sum(subary**ord for subary in arys)**(1/ord) + else: + raise NotImplementedError(f"unsupported value of 'ord': {ord}") + + class BaseFakeNumpyLinalgNamespace: def __init__(self, array_context): self._array_context = array_context @@ -250,7 +274,7 @@ class BaseFakeNumpyLinalgNamespace: return flat_norm(ary, ord=ord) if is_array_container(ary): - return _scalar_list_norm([ + return _reduce_norm(actx, [ self.norm(subary, ord=ord) for _, subary in serialize_container(ary) ], ord=ord) @@ -262,14 +286,16 @@ class BaseFakeNumpyLinalgNamespace: raise NotImplementedError("only vector norms are implemented") if ary.size == 0: - return 0 + return ary.dtype.type(0) + if ord == 2: + return actx.np.sqrt(actx.np.sum(abs(ary)**2)) if ord == np.inf: - return self._array_context.np.max(abs(ary)) + return actx.np.max(abs(ary)) elif ord == -np.inf: - return self._array_context.np.min(abs(ary)) + return actx.np.min(abs(ary)) elif isinstance(ord, Number) and ord > 0: - return self._array_context.np.sum(abs(ary)**ord)**(1/ord) + return actx.np.sum(abs(ary)**ord)**(1/ord) else: raise NotImplementedError(f"unsupported value of 'ord': {ord}") # }}}