diff --git a/pytato/__init__.py b/pytato/__init__.py index 5eb6acd23a789a2cb47e82728bc25a7fddd8cafb..4a89b01bb775e83fda9a49fce886643c01afb624 100644 --- a/pytato/__init__.py +++ b/pytato/__init__.py @@ -47,7 +47,7 @@ from pytato.array import ( dot, vdot, ) -from pytato.reductions import sum, amax, amin, prod +from pytato.reductions import sum, amax, amin, prod, any, all from pytato.cmath import (abs, sin, cos, tan, arcsin, arccos, arctan, sinh, cosh, tanh, exp, log, log10, isnan, sqrt, conj, arctan2, real, imag) @@ -86,7 +86,8 @@ __all__ = ( "logical_or", "logical_and", "logical_not", - "sum", "amax", "amin", "prod", + "sum", "amax", "amin", "prod", "all", "any", + "real", "imag", "dot", "vdot", diff --git a/pytato/reductions.py b/pytato/reductions.py index 43fa492d7dd60aeb72ce318269563704d6bcec8c..16b9d5e881de26855617ee80e6f873ebe99a159d 100644 --- a/pytato/reductions.py +++ b/pytato/reductions.py @@ -41,6 +41,8 @@ __doc__ = """ .. autofunction:: amin .. autofunction:: amax .. autofunction:: prod +.. autofunction:: all +.. autofunction:: any """ # }}} @@ -196,6 +198,30 @@ def prod(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array: """ return _make_reduction_lambda("product", a, axis) + +def all(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array: + """ + Returns the logical-and array *a*'s elements along the *axis* axes. + + :arg a: The :class:`pytato.Array` on which to perform the reduction. + + :arg axis: The axes along which the elements are to be product-reduced. + Defaults to all axes of the input array. + """ + return _make_reduction_lambda("all", a, axis) + + +def any(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array: + """ + Returns the logical-or of array *a*'s elements along the *axis* axes. + + :arg a: The :class:`pytato.Array` on which to perform the reduction. + + :arg axis: The axes along which the elements are to be product-reduced. + Defaults to all axes of the input array. + """ + return _make_reduction_lambda("any", a, axis) + # }}} # vim: foldmethod=marker diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index 67f102585cb457acb5f62519f03f447ae97fcaa1..16d00d7622b6aa7b7fa9af56e641893e80c82b18 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -227,7 +227,7 @@ class Reduce(ExpressionBase): .. attribute:: op - One of ``"sum"``, ``"product"``, ``"max"``, ``"min"``. + One of ``"sum"``, ``"product"``, ``"max"``, ``"min"``,``"all"``, ``"any"``. .. attribute:: bounds @@ -240,7 +240,7 @@ class Reduce(ExpressionBase): def __init__(self, inner_expr: ScalarExpression, op: str, bounds: Any) -> None: self.inner_expr = inner_expr - if op not in ["sum", "product", "max", "min"]: + if op not in {"sum", "product", "max", "min", "all", "any"}: raise ValueError(f"unsupported op: {op}") self.op = op self.bounds = bounds diff --git a/pytato/target/loopy/codegen.py b/pytato/target/loopy/codegen.py index e9e7dc2918c0398bc577d5373e94864f131e205f..dc09e9051fa19666cc66945d6d8bbda605188767 100644 --- a/pytato/target/loopy/codegen.py +++ b/pytato/target/loopy/codegen.py @@ -520,6 +520,8 @@ PYTATO_REDUCTION_TO_LOOPY_REDUCTION = { "product": "product", "max": "max", "min": "min", + "all": "all", + "any": "any", } diff --git a/test/test_codegen.py b/test/test_codegen.py index 46c4efbe6c08228088524af122bc7aba6df7c82a..6ca17b4f08708e4bf4928252ac0d277ce21954d6 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -779,7 +779,7 @@ def test_call_loopy_with_scalar_array_inputs(ctx_factory): @pytest.mark.parametrize("axis", (None, 1, 0)) -@pytest.mark.parametrize("redn", ("sum", "amax", "amin", "prod")) +@pytest.mark.parametrize("redn", ("sum", "amax", "amin", "prod", "all", "any")) @pytest.mark.parametrize("shape", [(2, 2), (1, 2, 1), (3, 4, 5)]) def test_reductions(ctx_factory, axis, redn, shape): queue = cl.CommandQueue(ctx_factory())