From a4d6eb224822675d14813af84f36cd3614c35afd Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 31 Jul 2024 16:57:52 -0500
Subject: [PATCH] Rework dataclass array container arithmetic

- Deprecate automatic broadcasting of array context arrays
- Warn about uses of numpy array broadcasting, deprecated earlier
- Clarify documentation, warning wording
---
 arraycontext/__init__.py             |   5 +-
 arraycontext/container/arithmetic.py | 202 +++++++++++++++++++++------
 arraycontext/container/traversal.py  |   4 +-
 test/test_arraycontext.py            |  58 +++-----
 4 files changed, 185 insertions(+), 84 deletions(-)

diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py
index e8e6e9f..4e0ba83 100644
--- a/arraycontext/__init__.py
+++ b/arraycontext/__init__.py
@@ -43,7 +43,9 @@ from .container import (
     register_multivector_as_array_container,
     serialize_container,
 )
-from .container.arithmetic import with_container_arithmetic
+from .container.arithmetic import (
+    with_container_arithmetic,
+)
 from .container.dataclass import dataclass_array_container
 from .container.traversal import (
     flat_size_and_dtype,
@@ -151,7 +153,6 @@ __all__ = (
     "unflatten",
     "with_array_context",
     "with_container_arithmetic",
-    "with_container_arithmetic"
 )
 
 
diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py
index dbfdd5a..b085a7d 100644
--- a/arraycontext/container/arithmetic.py
+++ b/arraycontext/container/arithmetic.py
@@ -2,13 +2,12 @@
 from __future__ import annotations
 
 
-"""
+__doc__ = """
 .. currentmodule:: arraycontext
+
 .. autofunction:: with_container_arithmetic
 """
 
-import enum
-
 
 __copyright__ = """
 Copyright (C) 2020-1 University of Illinois Board of Trustees
@@ -34,7 +33,9 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
+import enum
 from typing import Any, Callable, Optional, Tuple, TypeVar, Union
+from warnings import warn
 
 import numpy as np
 
@@ -99,8 +100,8 @@ def _format_unary_op_str(op_str: str, arg1: Union[Tuple[str, ...], str]) -> str:
 
 
 def _format_binary_op_str(op_str: str,
-        arg1: Union[Tuple[str, ...], str],
-        arg2: Union[Tuple[str, ...], str]) -> str:
+        arg1: Union[Tuple[str, str], str],
+        arg2: Union[Tuple[str, str], str]) -> str:
     if isinstance(arg1, tuple) and isinstance(arg2, tuple):
         import sys
         if sys.version_info >= (3, 10):
@@ -127,6 +128,36 @@ def _format_binary_op_str(op_str: str,
         return op_str.format(arg1, arg2)
 
 
+class NumpyObjectArrayMetaclass(type):
+    def __instancecheck__(cls, instance: Any) -> bool:
+        return isinstance(instance, np.ndarray) and instance.dtype == object
+
+
+class NumpyObjectArray(metaclass=NumpyObjectArrayMetaclass):
+    pass
+
+
+class ComplainingNumpyNonObjectArrayMetaclass(type):
+    def __instancecheck__(cls, instance: Any) -> bool:
+        if isinstance(instance, np.ndarray) and instance.dtype != object:
+            # Example usage site:
+            # https://github.com/illinois-ceesd/mirgecom/blob/f5d0d97c41e8c8a05546b1d1a6a2979ec8ea3554/mirgecom/inviscid.py#L148-L149
+            # where normal is passed in by test_lfr_flux as a 'custom-made'
+            # numpy array of dtype float64.
+            warn(
+                 "Broadcasting container against non-object numpy array. "
+                 "This was never documented to work and will now stop working in "
+                 "2025. Convert the array to an object array to preserve the "
+                 "current semantics.", DeprecationWarning, stacklevel=3)
+            return True
+        else:
+            return False
+
+
+class ComplainingNumpyNonObjectArray(metaclass=ComplainingNumpyNonObjectArrayMetaclass):
+    pass
+
+
 def with_container_arithmetic(
         *,
         bcast_number: bool = True,
@@ -146,22 +177,16 @@ def with_container_arithmetic(
 
     :arg bcast_number: If *True*, numbers broadcast over the container
         (with the container as the 'outer' structure).
-    :arg _bcast_actx_array_type: If *True*, instances of base array types of the
-        container's array context are broadcasted over the container. Can be
-        *True* only if the container has *_cls_has_array_context_attr* set.
-        Defaulted to *bcast_number* if *_cls_has_array_context_attr* is set,
-        else *False*.
-    :arg bcast_obj_array: If *True*, :mod:`numpy` object arrays broadcast over
-        the container.  (with the container as the 'inner' structure)
-    :arg bcast_numpy_array: If *True*, any :class:`numpy.ndarray` will broadcast
-        over the container.  (with the container as the 'inner' structure)
-        If this is set to *True*, *bcast_obj_array* must also be *True*.
+    :arg bcast_obj_array: If *True*, this container will be broadcast
+        across :mod:`numpy` object arrays
+        (with the object array as the 'outer' structure).
+        Add :class:`numpy.ndarray` to *bcast_container_types* to achieve
+        the 'reverse' broadcasting.
     :arg bcast_container_types: A sequence of container types that will broadcast
-        over this container (with this container as the 'outer' structure).
+        across this container, with this container as the 'outer' structure.
         :class:`numpy.ndarray` is permitted to be part of this sequence to
-        indicate that, in such broadcasting situations, this container should
-        be the 'outer' structure. In this case, *bcast_obj_array*
-        (and consequently *bcast_numpy_array*) must be *False*.
+        indicate that object arrays (and *only* object arrays) will be broadcast.
+        In this case, *bcast_obj_array* must be *False*.
     :arg arithmetic: Implement the conventional arithmetic operators, including
         ``**``, :func:`divmod`, and ``//``. Also includes ``+`` and ``-`` as well as
         :func:`abs`.
@@ -203,6 +228,17 @@ def with_container_arithmetic(
     should nest "outside" :func:dataclass_array_container`.
     """
 
+    # Hard-won design lessons:
+    #
+    # - Anything that special-cases np.ndarray by type is broken by design because:
+    #   - np.ndarray is an array context array.
+    #   - numpy object arrays can be array containers.
+    #   Using NumpyObjectArray and ComplainingNumpyNonObjectArray *may* be better?
+    #   They're new, so there is no operational experience with them.
+    #
+    # - Broadcast rules are hard to change once established, particularly
+    #   because one cannot grep for their use.
+
     # {{{ handle inputs
 
     if bcast_obj_array is None:
@@ -212,9 +248,8 @@ def with_container_arithmetic(
         raise TypeError("rel_comparison must be specified")
 
     if bcast_numpy_array:
-        from warnings import warn
         warn("'bcast_numpy_array=True' is deprecated and will be unsupported"
-             " from December 2021", DeprecationWarning, stacklevel=2)
+             " from 2025.", DeprecationWarning, stacklevel=2)
 
         if _bcast_actx_array_type:
             raise ValueError("'bcast_numpy_array' and '_bcast_actx_array_type'"
@@ -231,7 +266,7 @@ def with_container_arithmetic(
 
     if bcast_numpy_array:
         def numpy_pred(name: str) -> str:
-            return f"isinstance({name}, np.ndarray)"
+            return f"is_numpy_array({name})"
     elif bcast_obj_array:
         def numpy_pred(name: str) -> str:
             return f"isinstance({name}, np.ndarray) and {name}.dtype.char == 'O'"
@@ -241,12 +276,21 @@ def with_container_arithmetic(
 
     if bcast_container_types is None:
         bcast_container_types = ()
-    bcast_container_types_count = len(bcast_container_types)
 
     if np.ndarray in bcast_container_types and bcast_obj_array:
         raise ValueError("If numpy.ndarray is part of bcast_container_types, "
                 "bcast_obj_array must be False.")
 
+    numpy_check_types: list[type] = [NumpyObjectArray, ComplainingNumpyNonObjectArray]
+    bcast_container_types = tuple(
+        new_ct
+        for old_ct in bcast_container_types
+        for new_ct in
+        (numpy_check_types
+        if old_ct is np.ndarray
+        else [old_ct])
+    )
+
     desired_op_classes = set()
     if arithmetic:
         desired_op_classes.add(_OpClass.ARITHMETIC)
@@ -264,10 +308,15 @@ def with_container_arithmetic(
     # }}}
 
     def wrap(cls: Any) -> Any:
-        cls_has_array_context_attr: bool | None = \
-                _cls_has_array_context_attr
-        bcast_actx_array_type: bool | None = \
-                _bcast_actx_array_type
+        if not hasattr(cls, "__array_ufunc__"):
+            warn(f"{cls} does not have __array_ufunc__ set. "
+                 "This will cause numpy to attempt broadcasting, in a way that "
+                 "is likely undesired. "
+                 f"To avoid this, set __array_ufunc__ = None in {cls}.",
+                 stacklevel=2)
+
+        cls_has_array_context_attr: bool | None = _cls_has_array_context_attr
+        bcast_actx_array_type: bool | None = _bcast_actx_array_type
 
         if cls_has_array_context_attr is None:
             if hasattr(cls, "array_context"):
@@ -275,8 +324,8 @@ def with_container_arithmetic(
                         f"{cls} has an 'array_context' attribute, but it does not "
                         "set '_cls_has_array_context_attr' to *True* when calling "
                         "with_container_arithmetic. This is being interpreted "
-                        "as 'array_context' being permitted to fail with an exception, "
-                        "which is no longer allowed. "
+                        "as '.array_context' being permitted to fail "
+                        "with an exception, which is no longer allowed. "
                         f"If {cls.__name__}.array_context will not fail, pass "
                         "'_cls_has_array_context_attr=True'. "
                         "If you do not want container arithmetic to make "
@@ -294,6 +343,30 @@ def with_container_arithmetic(
                 raise TypeError("_bcast_actx_array_type can be True only if "
                                 "_cls_has_array_context_attr is set.")
 
+        if bcast_actx_array_type:
+            if _bcast_actx_array_type:
+                warn(
+                    f"Broadcasting array context array types across {cls} "
+                    "has been explicitly "
+                    "enabled. As of 2025, this will stop working. "
+                    "There is no replacement as of right now. "
+                    "See the discussion in "
+                    "https://github.com/inducer/arraycontext/pull/190. "
+                    "To opt out now (and avoid this warning), "
+                    "pass _bcast_actx_array_type=False. ",
+                    DeprecationWarning, stacklevel=2)
+            else:
+                warn(
+                    f"Broadcasting array context array types across {cls} "
+                    "has been implicitly "
+                    "enabled. As of 2025, this will no longer work. "
+                    "There is no replacement as of right now. "
+                    "See the discussion in "
+                    "https://github.com/inducer/arraycontext/pull/190. "
+                    "To opt out now (and avoid this warning), "
+                    "pass _bcast_actx_array_type=False.",
+                    DeprecationWarning, stacklevel=2)
+
         if (not hasattr(cls, "_serialize_init_arrays_code")
                 or not hasattr(cls, "_deserialize_init_arrays_code")):
             raise TypeError(f"class '{cls.__name__}' must provide serialization "
@@ -304,7 +377,7 @@ def with_container_arithmetic(
 
         from pytools.codegen import CodeGenerator, Indentation
         gen = CodeGenerator()
-        gen("""
+        gen(f"""
             from numbers import Number
             import numpy as np
             from arraycontext import ArrayContainer
@@ -315,6 +388,24 @@ def with_container_arithmetic(
                     raise ValueError("array containers with frozen arrays "
                         "cannot be operated upon")
                 return actx
+
+            def is_numpy_array(arg):
+                if isinstance(arg, np.ndarray):
+                    if arg.dtype != "O":
+                        warn("Operand is a non-object numpy array, "
+                            "and the broadcasting behavior of this array container "
+                            "({cls}) "
+                            "is influenced by this because of its use of "
+                            "the deprecated bcast_numpy_array. This broadcasting "
+                            "behavior will change in 2025. If you would like the "
+                            "broadcasting behavior to stay the same, make sure "
+                            "to convert the passed numpy array to an "
+                            "object array.",
+                            DeprecationWarning, stacklevel=3)
+                    return True
+                else:
+                    return False
+
             """)
         gen("")
 
@@ -323,7 +414,7 @@ def with_container_arithmetic(
                 gen(f"from {bct.__module__} import {bct.__qualname__} as _bctype{i}")
             gen("")
         outer_bcast_type_names = tuple([
-                f"_bctype{i}" for i in range(bcast_container_types_count)
+                f"_bctype{i}" for i in range(len(bcast_container_types))
                 ])
         if bcast_number:
             outer_bcast_type_names += ("Number",)
@@ -384,8 +475,6 @@ def with_container_arithmetic(
 
                 continue
 
-            # {{{ "forward" binary operators
-
             zip_init_args = cls._deserialize_init_arrays_code("arg1", {
                     same_key(key_arg1, key_arg2):
                     _format_binary_op_str(op_str, expr_arg1, expr_arg2)
@@ -393,11 +482,18 @@ def with_container_arithmetic(
                         cls._serialize_init_arrays_code("arg1").items(),
                         cls._serialize_init_arrays_code("arg2").items())
                     })
-            bcast_same_cls_init_args = cls._deserialize_init_arrays_code("arg1", {
+            bcast_init_args_arg1_is_outer = cls._deserialize_init_arrays_code("arg1", {
                     key_arg1: _format_binary_op_str(op_str, expr_arg1, "arg2")
                     for key_arg1, expr_arg1 in
                     cls._serialize_init_arrays_code("arg1").items()
                     })
+            bcast_init_args_arg2_is_outer = cls._deserialize_init_arrays_code("arg2", {
+                    key_arg2: _format_binary_op_str(op_str, "arg1", expr_arg2)
+                    for key_arg2, expr_arg2 in
+                    cls._serialize_init_arrays_code("arg2").items()
+                    })
+
+            # {{{ "forward" binary operators
 
             gen(f"def {fname}(arg1, arg2):")
             with Indentation(gen):
@@ -424,7 +520,7 @@ def with_container_arithmetic(
 
                 if bcast_actx_array_type:
                     if __debug__:
-                        bcast_actx_ary_types = (
+                        bcast_actx_ary_types: tuple[str, ...] = (
                             "*_raise_if_actx_none("
                             "arg1.array_context).array_types",)
                     else:
@@ -444,7 +540,19 @@ def with_container_arithmetic(
                     if isinstance(arg2,
                                   {tup_str(outer_bcast_type_names
                                            + bcast_actx_ary_types)}):
-                        return cls({bcast_same_cls_init_args})
+                        if __debug__:
+                            if isinstance(arg2, {tup_str(bcast_actx_ary_types)}):
+                                warn("Broadcasting {cls} over array "
+                                    f"context array type {{type(arg2)}} is deprecated "
+                                    "and will no longer work in 2025. "
+                                    "There is no replacement as of right now. "
+                                    "See the discussion in "
+                                    "https://github.com/inducer/arraycontext/"
+                                    "pull/190. ",
+                                    DeprecationWarning, stacklevel=2)
+
+                        return cls({bcast_init_args_arg1_is_outer})
+
                 return NotImplemented
                 """)
             gen(f"cls.__{dunder_name}__ = {fname}")
@@ -456,12 +564,6 @@ def with_container_arithmetic(
 
             if reversible:
                 fname = f"_{cls.__name__.lower()}_r{dunder_name}"
-                bcast_init_args = cls._deserialize_init_arrays_code("arg2", {
-                        key_arg2: _format_binary_op_str(
-                            op_str, "arg1", expr_arg2)
-                        for key_arg2, expr_arg2 in
-                        cls._serialize_init_arrays_code("arg2").items()
-                        })
 
                 if bcast_actx_array_type:
                     if __debug__:
@@ -487,7 +589,21 @@ def with_container_arithmetic(
                             if isinstance(arg1,
                                           {tup_str(outer_bcast_type_names
                                                    + bcast_actx_ary_types)}):
-                                return cls({bcast_init_args})
+                                if __debug__:
+                                    if isinstance(arg1,
+                                            {tup_str(bcast_actx_ary_types)}):
+                                        warn("Broadcasting {cls} over array "
+                                            f"context array type {{type(arg1)}} "
+                                            "is deprecated "
+                                            "and will no longer work in 2025. "
+                                            "There is no replacement as of right now. "
+                                            "See the discussion in "
+                                            "https://github.com/inducer/arraycontext/"
+                                            "pull/190. ",
+                                            DeprecationWarning, stacklevel=2)
+
+                                return cls({bcast_init_args_arg2_is_outer})
+
                         return NotImplemented
 
                     cls.__r{dunder_name}__ = {fname}""")
diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py
index 31b3bcf..100f077 100644
--- a/arraycontext/container/traversal.py
+++ b/arraycontext/container/traversal.py
@@ -43,6 +43,8 @@ Algebraic operations
 
 from __future__ import annotations
 
+from arraycontext.container.arithmetic import NumpyObjectArray
+
 
 __copyright__ = """
 Copyright (C) 2020-1 University of Illinois Board of Trustees
@@ -964,7 +966,7 @@ def outer(a: Any, b: Any) -> Any:
             return (
                 not isinstance(x, np.ndarray)
                 # This condition is whether "ndarrays should broadcast inside x".
-                and np.ndarray not in x.__class__._outer_bcast_types)
+                and NumpyObjectArray not in x.__class__._outer_bcast_types)
 
     if treat_as_scalar(a) or treat_as_scalar(b):
         return a*b
diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index 94d7d74..868790f 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -22,6 +22,7 @@ THE SOFTWARE.
 
 import logging
 from dataclasses import dataclass
+from functools import partial
 from typing import Union
 
 import numpy as np
@@ -34,6 +35,7 @@ from arraycontext import (
     ArrayContainer,
     ArrayContext,
     EagerJAXArrayContext,
+    NumpyArrayContext,
     PyOpenCLArrayContext,
     PytatoPyOpenCLArrayContext,
     dataclass_array_container,
@@ -116,10 +118,10 @@ def _acf():
 
 @with_container_arithmetic(
         bcast_obj_array=True,
-        bcast_numpy_array=True,
         bitwise=True,
         rel_comparison=True,
-        _cls_has_array_context_attr=True)
+        _cls_has_array_context_attr=True,
+        _bcast_actx_array_type=False)
 class DOFArray:
     def __init__(self, actx, data):
         if not (actx is None or isinstance(actx, ArrayContext)):
@@ -207,7 +209,8 @@ def _with_actx_dofarray(ary: DOFArray, actx: ArrayContext) -> DOFArray:  # type:
 
 @with_container_arithmetic(bcast_obj_array=False,
         eq_comparison=False, rel_comparison=False,
-        _cls_has_array_context_attr=True)
+        _cls_has_array_context_attr=True,
+        _bcast_actx_array_type=False)
 @dataclass_array_container
 @dataclass(frozen=True)
 class MyContainer:
@@ -229,7 +232,8 @@ class MyContainer:
         bcast_container_types=(DOFArray, np.ndarray),
         matmul=True,
         rel_comparison=True,
-        _cls_has_array_context_attr=True)
+        _cls_has_array_context_attr=True,
+        _bcast_actx_array_type=False)
 @dataclass_array_container
 @dataclass(frozen=True)
 class MyContainerDOFBcast:
@@ -936,8 +940,6 @@ def test_container_arithmetic(actx_factory):
     def _check_allclose(f, arg1, arg2, atol=5.0e-14):
         assert np.linalg.norm(actx.to_numpy(f(arg1) - arg2)) < atol
 
-    from functools import partial
-
     from arraycontext import rec_multimap_array_container
     for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]:
         rec_multimap_array_container(
@@ -1350,13 +1352,13 @@ def test_container_equality(actx_factory):
 # }}}
 
 
-# {{{ test_leaf_array_type_broadcasting
+# {{{ test_no_leaf_array_type_broadcasting
 
 @with_container_arithmetic(
     bcast_obj_array=True,
-    bcast_numpy_array=True,
     rel_comparison=True,
-    _cls_has_array_context_attr=True)
+    _cls_has_array_context_attr=True,
+    _bcast_actx_array_type=False)
 @dataclass_array_container
 @dataclass(frozen=True)
 class Foo:
@@ -1369,39 +1371,19 @@ class Foo:
         return self.u.array_context
 
 
-def test_leaf_array_type_broadcasting(actx_factory):
-    # test support for https://github.com/inducer/arraycontext/issues/49
+def test_no_leaf_array_type_broadcasting(actx_factory):
+    # test lack of support for https://github.com/inducer/arraycontext/issues/49
     actx = actx_factory()
 
-    foo = Foo(DOFArray(actx, (actx.np.zeros(3, dtype=np.float64) + 41, )))
-    bar = foo + 4
-    baz = foo + actx.from_numpy(4*np.ones((3, )))
-    qux = actx.from_numpy(4*np.ones((3, ))) + foo
-
-    np.testing.assert_allclose(actx.to_numpy(bar.u[0]),
-                               actx.to_numpy(baz.u[0]))
-
-    np.testing.assert_allclose(actx.to_numpy(bar.u[0]),
-                               actx.to_numpy(qux.u[0]))
-
-    def _actx_allows_scalar_broadcast(actx):
-        if not isinstance(actx, PyOpenCLArrayContext):
-            return True
-        else:
-            import pyopencl as cl
-
-            # See https://github.com/inducer/pyopencl/issues/498
-            return cl.version.VERSION > (2021, 2, 5)
-
-    if _actx_allows_scalar_broadcast(actx):
-        quux = foo + actx.from_numpy(np.array(4))
-        quuz = actx.from_numpy(np.array(4)) + foo
+    dof_ary = DOFArray(actx, (actx.np.zeros(3, dtype=np.float64) + 41, ))
+    foo = Foo(dof_ary)
 
-        np.testing.assert_allclose(actx.to_numpy(bar.u[0]),
-                                   actx.to_numpy(quux.u[0]))
+    actx_ary = actx.from_numpy(4*np.ones((3, )))
+    with pytest.raises(TypeError):
+        foo + actx_ary
 
-        np.testing.assert_allclose(actx.to_numpy(bar.u[0]),
-                                   actx.to_numpy(quuz.u[0]))
+    with pytest.raises(TypeError):
+        foo + actx.from_numpy(np.array(4))
 
 # }}}
 
-- 
GitLab