From 66fd9ed4d99480b44770f48322d9db08e49df8df Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Fri, 26 May 2023 10:09:40 -0500 Subject: [PATCH] homogenize dipatching into pytato routines by characterizing into unary vs multi-ary --- arraycontext/impl/pytato/fake_numpy.py | 47 ++++++-------------------- 1 file changed, 11 insertions(+), 36 deletions(-) diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index a97f1c5..67dc60b 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -50,20 +50,29 @@ class PytatoFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace): :ref:`Pytato docs <pytato:memory-layout>` for more on this. """ - _pt_funcs = frozenset({ + _pt_unary_funcs = frozenset({ "sin", "cos", "tan", "arcsin", "arccos", "arctan", "sinh", "cosh", "tanh", "exp", "log", "log10", "sqrt", "abs", "isnan", "real", "imag", "conj", }) + _pt_multi_ary_funcs = frozenset({ + "arctan2", "equal", "greater", "greater_equal", "less", "less_equal", + "not_equal", "minimum", "maximum", "where", + }) + def _get_fake_numpy_linalg_namespace(self): return PytatoFakeNumpyLinalgNamespace(self._array_context) def __getattr__(self, name): - if name in self._pt_funcs: + if name in self._pt_unary_funcs: from functools import partial return partial(rec_map_array_container, getattr(pt, name)) + if name in self._pt_multi_ary_funcs: + from functools import partial + return partial(rec_multimap_array_container, getattr(pt, name)) + return super().__getattr__(name) # NOTE: the order of these follows the order in numpy docs @@ -175,31 +184,10 @@ class PytatoFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace): return rec_equal(a, b) - def greater(self, x, y): - return rec_multimap_array_container(pt.greater, x, y) - - def greater_equal(self, x, y): - return rec_multimap_array_container(pt.greater_equal, x, y) - - def less(self, x, y): - return rec_multimap_array_container(pt.less, x, y) - - def less_equal(self, x, y): - return rec_multimap_array_container(pt.less_equal, x, y) - - def equal(self, x, y): - return rec_multimap_array_container(pt.equal, x, y) - - def not_equal(self, x, y): - return rec_multimap_array_container(pt.not_equal, x, y) - # }}} # {{{ mathematical functions - def arctan2(self, y, x): - return rec_multimap_array_container(pt.arctan2, y, x) - def sum(self, a, axis=None, dtype=None): def _pt_sum(ary): if dtype not in [ary.dtype, None]: @@ -209,18 +197,12 @@ class PytatoFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace): return rec_map_reduce_array_container(sum, _pt_sum, a) - def maximum(self, x, y): - return rec_multimap_array_container(pt.maximum, x, y) - def amax(self, a, axis=None): return rec_map_reduce_array_container( partial(reduce, pt.maximum), partial(pt.amax, axis=axis), a) max = amax - def minimum(self, x, y): - return rec_multimap_array_container(pt.minimum, x, y) - def amin(self, a, axis=None): return rec_map_reduce_array_container( partial(reduce, pt.minimum), partial(pt.amin, axis=axis), a) @@ -231,10 +213,3 @@ class PytatoFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace): return self.abs(a) # }}} - - # {{{ sorting, searching, and counting - - def where(self, criterion, then, else_): - return rec_multimap_array_container(pt.where, criterion, then, else_) - - # }}} -- GitLab