From 66fd9ed4d99480b44770f48322d9db08e49df8df Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Fri, 26 May 2023 10:09:40 -0500
Subject: [PATCH] homogenize dipatching into pytato routines by characterizing
 into unary vs multi-ary

---
 arraycontext/impl/pytato/fake_numpy.py | 47 ++++++--------------------
 1 file changed, 11 insertions(+), 36 deletions(-)

diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py
index a97f1c5..67dc60b 100644
--- a/arraycontext/impl/pytato/fake_numpy.py
+++ b/arraycontext/impl/pytato/fake_numpy.py
@@ -50,20 +50,29 @@ class PytatoFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace):
         :ref:`Pytato docs <pytato:memory-layout>` for more on this.
     """
 
-    _pt_funcs = frozenset({
+    _pt_unary_funcs = frozenset({
         "sin", "cos", "tan", "arcsin", "arccos", "arctan",
         "sinh", "cosh", "tanh", "exp", "log", "log10",
         "sqrt", "abs", "isnan", "real", "imag", "conj",
         })
 
+    _pt_multi_ary_funcs = frozenset({
+        "arctan2", "equal", "greater", "greater_equal", "less", "less_equal",
+        "not_equal", "minimum", "maximum", "where",
+    })
+
     def _get_fake_numpy_linalg_namespace(self):
         return PytatoFakeNumpyLinalgNamespace(self._array_context)
 
     def __getattr__(self, name):
-        if name in self._pt_funcs:
+        if name in self._pt_unary_funcs:
             from functools import partial
             return partial(rec_map_array_container, getattr(pt, name))
 
+        if name in self._pt_multi_ary_funcs:
+            from functools import partial
+            return partial(rec_multimap_array_container, getattr(pt, name))
+
         return super().__getattr__(name)
 
     # NOTE: the order of these follows the order in numpy docs
@@ -175,31 +184,10 @@ class PytatoFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace):
 
         return rec_equal(a, b)
 
-    def greater(self, x, y):
-        return rec_multimap_array_container(pt.greater, x, y)
-
-    def greater_equal(self, x, y):
-        return rec_multimap_array_container(pt.greater_equal, x, y)
-
-    def less(self, x, y):
-        return rec_multimap_array_container(pt.less, x, y)
-
-    def less_equal(self, x, y):
-        return rec_multimap_array_container(pt.less_equal, x, y)
-
-    def equal(self, x, y):
-        return rec_multimap_array_container(pt.equal, x, y)
-
-    def not_equal(self, x, y):
-        return rec_multimap_array_container(pt.not_equal, x, y)
-
     # }}}
 
     # {{{ mathematical functions
 
-    def arctan2(self, y, x):
-        return rec_multimap_array_container(pt.arctan2, y, x)
-
     def sum(self, a, axis=None, dtype=None):
         def _pt_sum(ary):
             if dtype not in [ary.dtype, None]:
@@ -209,18 +197,12 @@ class PytatoFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace):
 
         return rec_map_reduce_array_container(sum, _pt_sum, a)
 
-    def maximum(self, x, y):
-        return rec_multimap_array_container(pt.maximum, x, y)
-
     def amax(self, a, axis=None):
         return rec_map_reduce_array_container(
                 partial(reduce, pt.maximum), partial(pt.amax, axis=axis), a)
 
     max = amax
 
-    def minimum(self, x, y):
-        return rec_multimap_array_container(pt.minimum, x, y)
-
     def amin(self, a, axis=None):
         return rec_map_reduce_array_container(
                 partial(reduce, pt.minimum), partial(pt.amin, axis=axis), a)
@@ -231,10 +213,3 @@ class PytatoFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace):
         return self.abs(a)
 
     # }}}
-
-    # {{{ sorting, searching, and counting
-
-    def where(self, criterion, then, else_):
-        return rec_multimap_array_container(pt.where, criterion, then, else_)
-
-    # }}}
-- 
GitLab