diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py index 01054bac6b90d2960f3ddc6ee25cd13fc1d91d4d..a984ef3c4543636269889ee3a305d5f45666d0a9 100644 --- a/arraycontext/impl/pyopencl/fake_numpy.py +++ b/arraycontext/impl/pyopencl/fake_numpy.py @@ -105,32 +105,57 @@ class PyOpenCLFakeNumpyNamespace(BaseFakeNumpyNamespace): return rec_multimap_array_container(where_inner, criterion, then, else_) - def sum(self, a, dtype=None): - result = rec_map_reduce_array_container( - sum, - partial(cl_array.sum, dtype=dtype, queue=self._array_context.queue), - a) + def sum(self, a, axis=None, dtype=None): + + if isinstance(axis, int): + axis = axis, + + def _rec_sum(ary): + if axis not in [None, tuple(range(ary.ndim))]: + raise NotImplementedError(f"Sum over '{axis}' axes not supported.") + + return cl_array.sum(ary, dtype=dtype, queue=self._array_context.queue) + + result = rec_map_reduce_array_container(sum, _rec_sum, a) if not self._array_context._force_device_scalars: result = result.get()[()] return result - def min(self, a): + def min(self, a, axis=None): queue = self._array_context.queue + + if isinstance(axis, int): + axis = axis, + + def _rec_min(ary): + if axis not in [None, tuple(range(ary.ndim))]: + raise NotImplementedError(f"Min. over '{axis}' axes not supported.") + return cl_array.min(ary, queue=queue) + result = rec_map_reduce_array_container( partial(reduce, partial(cl_array.minimum, queue=queue)), - partial(cl_array.min, queue=queue), + _rec_min, a) if not self._array_context._force_device_scalars: result = result.get()[()] return result - def max(self, a): + def max(self, a, axis=None): queue = self._array_context.queue + + if isinstance(axis, int): + axis = axis, + + def _rec_max(ary): + if axis not in [None, tuple(range(ary.ndim))]: + raise NotImplementedError(f"Max. over '{axis}' axes not supported.") + return cl_array.max(ary, queue=queue) + result = rec_map_reduce_array_container( partial(reduce, partial(cl_array.maximum, queue=queue)), - partial(cl_array.max, queue=queue), + _rec_max, a) if not self._array_context._force_device_scalars: diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index f17a4abbc13035bf63956f96dafe005afdc6606f..f89bc451279785072598f658690037a83908718b 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -85,22 +85,22 @@ class PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace): def where(self, criterion, then, else_): return rec_multimap_array_container(pt.where, criterion, then, else_) - def sum(self, a, dtype=None): + def sum(self, a, axis=None, dtype=None): def _pt_sum(ary): if dtype not in [ary.dtype, None]: raise NotImplementedError - return pt.sum(ary) + return pt.sum(ary, axis=axis) return rec_map_reduce_array_container(sum, _pt_sum, a) - def min(self, a): + def min(self, a, axis=None): return rec_map_reduce_array_container( - partial(reduce, pt.minimum), pt.amin, a) + partial(reduce, pt.minimum), partial(pt.amin, axis=axis), a) - def max(self, a): + def max(self, a, axis=None): return rec_map_reduce_array_container( - partial(reduce, pt.maximum), pt.amax, a) + partial(reduce, pt.maximum), partial(pt.amax, axis=axis), a) def stack(self, arrays, axis=0): return rec_multimap_array_container(