From 387f7ad44f97a2e4ffb6ffa6b807a705dafca04f Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Mon, 20 Sep 2021 16:21:05 -0500
Subject: [PATCH] support reductions with targeted axes

---
 arraycontext/impl/pyopencl/fake_numpy.py | 43 +++++++++++++++++++-----
 arraycontext/impl/pytato/fake_numpy.py   | 12 +++----
 2 files changed, 40 insertions(+), 15 deletions(-)

diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py
index 01054ba..a984ef3 100644
--- a/arraycontext/impl/pyopencl/fake_numpy.py
+++ b/arraycontext/impl/pyopencl/fake_numpy.py
@@ -105,32 +105,57 @@ class PyOpenCLFakeNumpyNamespace(BaseFakeNumpyNamespace):
 
         return rec_multimap_array_container(where_inner, criterion, then, else_)
 
-    def sum(self, a, dtype=None):
-        result = rec_map_reduce_array_container(
-                sum,
-                partial(cl_array.sum, dtype=dtype, queue=self._array_context.queue),
-                a)
+    def sum(self, a, axis=None, dtype=None):
+
+        if isinstance(axis, int):
+            axis = axis,
+
+        def _rec_sum(ary):
+            if axis not in [None, tuple(range(ary.ndim))]:
+                raise NotImplementedError(f"Sum over '{axis}' axes not supported.")
+
+            return cl_array.sum(ary, dtype=dtype, queue=self._array_context.queue)
+
+        result = rec_map_reduce_array_container(sum, _rec_sum, a)
 
         if not self._array_context._force_device_scalars:
             result = result.get()[()]
         return result
 
-    def min(self, a):
+    def min(self, a, axis=None):
         queue = self._array_context.queue
+
+        if isinstance(axis, int):
+            axis = axis,
+
+        def _rec_min(ary):
+            if axis not in [None, tuple(range(ary.ndim))]:
+                raise NotImplementedError(f"Min. over '{axis}' axes not supported.")
+            return cl_array.min(ary, queue=queue)
+
         result = rec_map_reduce_array_container(
                 partial(reduce, partial(cl_array.minimum, queue=queue)),
-                partial(cl_array.min, queue=queue),
+                _rec_min,
                 a)
 
         if not self._array_context._force_device_scalars:
             result = result.get()[()]
         return result
 
-    def max(self, a):
+    def max(self, a, axis=None):
         queue = self._array_context.queue
+
+        if isinstance(axis, int):
+            axis = axis,
+
+        def _rec_max(ary):
+            if axis not in [None, tuple(range(ary.ndim))]:
+                raise NotImplementedError(f"Max. over '{axis}' axes not supported.")
+            return cl_array.max(ary, queue=queue)
+
         result = rec_map_reduce_array_container(
                 partial(reduce, partial(cl_array.maximum, queue=queue)),
-                partial(cl_array.max, queue=queue),
+                _rec_max,
                 a)
 
         if not self._array_context._force_device_scalars:
diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py
index f17a4ab..f89bc45 100644
--- a/arraycontext/impl/pytato/fake_numpy.py
+++ b/arraycontext/impl/pytato/fake_numpy.py
@@ -85,22 +85,22 @@ class PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace):
     def where(self, criterion, then, else_):
         return rec_multimap_array_container(pt.where, criterion, then, else_)
 
-    def sum(self, a, dtype=None):
+    def sum(self, a, axis=None, dtype=None):
         def _pt_sum(ary):
             if dtype not in [ary.dtype, None]:
                 raise NotImplementedError
 
-            return pt.sum(ary)
+            return pt.sum(ary, axis=axis)
 
         return rec_map_reduce_array_container(sum, _pt_sum, a)
 
-    def min(self, a):
+    def min(self, a, axis=None):
         return rec_map_reduce_array_container(
-                partial(reduce, pt.minimum), pt.amin, a)
+                partial(reduce, pt.minimum), partial(pt.amin, axis=axis), a)
 
-    def max(self, a):
+    def max(self, a, axis=None):
         return rec_map_reduce_array_container(
-                partial(reduce, pt.maximum), pt.amax, a)
+                partial(reduce, pt.maximum), partial(pt.amax, axis=axis), a)
 
     def stack(self, arrays, axis=0):
         return rec_multimap_array_container(
-- 
GitLab