From 2d1642d9fa802db588ccc69df9c41084868b6cd6 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Wed, 11 Aug 2021 17:30:58 -0500 Subject: [PATCH] support pt.(any|all) --- pytato/__init__.py | 5 +++-- pytato/reductions.py | 26 ++++++++++++++++++++++++++ pytato/scalar_expr.py | 4 ++-- pytato/target/loopy/codegen.py | 2 ++ test/test_codegen.py | 2 +- 5 files changed, 34 insertions(+), 5 deletions(-) diff --git a/pytato/__init__.py b/pytato/__init__.py index 5eb6acd..4a89b01 100644 --- a/pytato/__init__.py +++ b/pytato/__init__.py @@ -47,7 +47,7 @@ from pytato.array import ( dot, vdot, ) -from pytato.reductions import sum, amax, amin, prod +from pytato.reductions import sum, amax, amin, prod, any, all from pytato.cmath import (abs, sin, cos, tan, arcsin, arccos, arctan, sinh, cosh, tanh, exp, log, log10, isnan, sqrt, conj, arctan2, real, imag) @@ -86,7 +86,8 @@ __all__ = ( "logical_or", "logical_and", "logical_not", - "sum", "amax", "amin", "prod", + "sum", "amax", "amin", "prod", "all", "any", + "real", "imag", "dot", "vdot", diff --git a/pytato/reductions.py b/pytato/reductions.py index 43fa492..16b9d5e 100644 --- a/pytato/reductions.py +++ b/pytato/reductions.py @@ -41,6 +41,8 @@ __doc__ = """ .. autofunction:: amin .. autofunction:: amax .. autofunction:: prod +.. autofunction:: all +.. autofunction:: any """ # }}} @@ -196,6 +198,30 @@ def prod(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array: """ return _make_reduction_lambda("product", a, axis) + +def all(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array: + """ + Returns the logical-and array *a*'s elements along the *axis* axes. + + :arg a: The :class:`pytato.Array` on which to perform the reduction. + + :arg axis: The axes along which the elements are to be product-reduced. + Defaults to all axes of the input array. + """ + return _make_reduction_lambda("all", a, axis) + + +def any(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array: + """ + Returns the logical-or of array *a*'s elements along the *axis* axes. + + :arg a: The :class:`pytato.Array` on which to perform the reduction. + + :arg axis: The axes along which the elements are to be product-reduced. + Defaults to all axes of the input array. + """ + return _make_reduction_lambda("any", a, axis) + # }}} # vim: foldmethod=marker diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index 67f1025..16d00d7 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -227,7 +227,7 @@ class Reduce(ExpressionBase): .. attribute:: op - One of ``"sum"``, ``"product"``, ``"max"``, ``"min"``. + One of ``"sum"``, ``"product"``, ``"max"``, ``"min"``,``"all"``, ``"any"``. .. attribute:: bounds @@ -240,7 +240,7 @@ class Reduce(ExpressionBase): def __init__(self, inner_expr: ScalarExpression, op: str, bounds: Any) -> None: self.inner_expr = inner_expr - if op not in ["sum", "product", "max", "min"]: + if op not in {"sum", "product", "max", "min", "all", "any"}: raise ValueError(f"unsupported op: {op}") self.op = op self.bounds = bounds diff --git a/pytato/target/loopy/codegen.py b/pytato/target/loopy/codegen.py index e9e7dc2..dc09e90 100644 --- a/pytato/target/loopy/codegen.py +++ b/pytato/target/loopy/codegen.py @@ -520,6 +520,8 @@ PYTATO_REDUCTION_TO_LOOPY_REDUCTION = { "product": "product", "max": "max", "min": "min", + "all": "all", + "any": "any", } diff --git a/test/test_codegen.py b/test/test_codegen.py index 46c4efb..6ca17b4 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -779,7 +779,7 @@ def test_call_loopy_with_scalar_array_inputs(ctx_factory): @pytest.mark.parametrize("axis", (None, 1, 0)) -@pytest.mark.parametrize("redn", ("sum", "amax", "amin", "prod")) +@pytest.mark.parametrize("redn", ("sum", "amax", "amin", "prod", "all", "any")) @pytest.mark.parametrize("shape", [(2, 2), (1, 2, 1), (3, 4, 5)]) def test_reductions(ctx_factory, axis, redn, shape): queue = cl.CommandQueue(ctx_factory()) -- GitLab