From a4d6eb224822675d14813af84f36cd3614c35afd Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Wed, 31 Jul 2024 16:57:52 -0500 Subject: [PATCH] Rework dataclass array container arithmetic - Deprecate automatic broadcasting of array context arrays - Warn about uses of numpy array broadcasting, deprecated earlier - Clarify documentation, warning wording --- arraycontext/__init__.py | 5 +- arraycontext/container/arithmetic.py | 202 +++++++++++++++++++++------ arraycontext/container/traversal.py | 4 +- test/test_arraycontext.py | 58 +++----- 4 files changed, 185 insertions(+), 84 deletions(-) diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index e8e6e9f..4e0ba83 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -43,7 +43,9 @@ from .container import ( register_multivector_as_array_container, serialize_container, ) -from .container.arithmetic import with_container_arithmetic +from .container.arithmetic import ( + with_container_arithmetic, +) from .container.dataclass import dataclass_array_container from .container.traversal import ( flat_size_and_dtype, @@ -151,7 +153,6 @@ __all__ = ( "unflatten", "with_array_context", "with_container_arithmetic", - "with_container_arithmetic" ) diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index dbfdd5a..b085a7d 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -2,13 +2,12 @@ from __future__ import annotations -""" +__doc__ = """ .. currentmodule:: arraycontext + .. autofunction:: with_container_arithmetic """ -import enum - __copyright__ = """ Copyright (C) 2020-1 University of Illinois Board of Trustees @@ -34,7 +33,9 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
""" +import enum from typing import Any, Callable, Optional, Tuple, TypeVar, Union +from warnings import warn import numpy as np @@ -99,8 +100,8 @@ def _format_unary_op_str(op_str: str, arg1: Union[Tuple[str, ...], str]) -> str: def _format_binary_op_str(op_str: str, - arg1: Union[Tuple[str, ...], str], - arg2: Union[Tuple[str, ...], str]) -> str: + arg1: Union[Tuple[str, str], str], + arg2: Union[Tuple[str, str], str]) -> str: if isinstance(arg1, tuple) and isinstance(arg2, tuple): import sys if sys.version_info >= (3, 10): @@ -127,6 +128,36 @@ def _format_binary_op_str(op_str: str, return op_str.format(arg1, arg2) +class NumpyObjectArrayMetaclass(type): + def __instancecheck__(cls, instance: Any) -> bool: + return isinstance(instance, np.ndarray) and instance.dtype == object + + +class NumpyObjectArray(metaclass=NumpyObjectArrayMetaclass): + pass + + +class ComplainingNumpyNonObjectArrayMetaclass(type): + def __instancecheck__(cls, instance: Any) -> bool: + if isinstance(instance, np.ndarray) and instance.dtype != object: + # Example usage site: + # https://github.com/illinois-ceesd/mirgecom/blob/f5d0d97c41e8c8a05546b1d1a6a2979ec8ea3554/mirgecom/inviscid.py#L148-L149 + # where normal is passed in by test_lfr_flux as a 'custom-made' + # numpy array of dtype float64. + warn( + "Broadcasting container against non-object numpy array. " + "This was never documented to work and will now stop working in " + "2025. Convert the array to an object array to preserve the " + "current semantics.", DeprecationWarning, stacklevel=3) + return True + else: + return False + + +class ComplainingNumpyNonObjectArray(metaclass=ComplainingNumpyNonObjectArrayMetaclass): + pass + + def with_container_arithmetic( *, bcast_number: bool = True, @@ -146,22 +177,16 @@ def with_container_arithmetic( :arg bcast_number: If *True*, numbers broadcast over the container (with the container as the 'outer' structure). 
- :arg _bcast_actx_array_type: If *True*, instances of base array types of the - container's array context are broadcasted over the container. Can be - *True* only if the container has *_cls_has_array_context_attr* set. - Defaulted to *bcast_number* if *_cls_has_array_context_attr* is set, - else *False*. - :arg bcast_obj_array: If *True*, :mod:`numpy` object arrays broadcast over - the container. (with the container as the 'inner' structure) - :arg bcast_numpy_array: If *True*, any :class:`numpy.ndarray` will broadcast - over the container. (with the container as the 'inner' structure) - If this is set to *True*, *bcast_obj_array* must also be *True*. + :arg bcast_obj_array: If *True*, this container will be broadcast + across :mod:`numpy` object arrays + (with the object array as the 'outer' structure). + Add :class:`numpy.ndarray` to *bcast_container_types* to achieve + the 'reverse' broadcasting. :arg bcast_container_types: A sequence of container types that will broadcast - over this container (with this container as the 'outer' structure). + across this container, with this container as the 'outer' structure. :class:`numpy.ndarray` is permitted to be part of this sequence to - indicate that, in such broadcasting situations, this container should - be the 'outer' structure. In this case, *bcast_obj_array* - (and consequently *bcast_numpy_array*) must be *False*. + indicate that object arrays (and *only* object arrays) will be broadcast. + In this case, *bcast_obj_array* must be *False*. :arg arithmetic: Implement the conventional arithmetic operators, including ``**``, :func:`divmod`, and ``//``. Also includes ``+`` and ``-`` as well as :func:`abs`. @@ -203,6 +228,17 @@ def with_container_arithmetic( should nest "outside" :func:dataclass_array_container`. """ + # Hard-won design lessons: + # + # - Anything that special-cases np.ndarray by type is broken by design because: + # - np.ndarray is an array context array. 
+ # - numpy object arrays can be array containers. + # Using NumpyObjectArray and ComplainingNumpyNonObjectArray *may* be better? + # They're new, so there is no operational experience with them. + # + # - Broadcast rules are hard to change once established, particularly + # because one cannot grep for their use. + # {{{ handle inputs if bcast_obj_array is None: @@ -212,9 +248,8 @@ def with_container_arithmetic( raise TypeError("rel_comparison must be specified") if bcast_numpy_array: - from warnings import warn warn("'bcast_numpy_array=True' is deprecated and will be unsupported" - " from December 2021", DeprecationWarning, stacklevel=2) + " from 2025.", DeprecationWarning, stacklevel=2) if _bcast_actx_array_type: raise ValueError("'bcast_numpy_array' and '_bcast_actx_array_type'" @@ -231,7 +266,7 @@ def with_container_arithmetic( if bcast_numpy_array: def numpy_pred(name: str) -> str: - return f"isinstance({name}, np.ndarray)" + return f"is_numpy_array({name})" elif bcast_obj_array: def numpy_pred(name: str) -> str: return f"isinstance({name}, np.ndarray) and {name}.dtype.char == 'O'" @@ -241,12 +276,21 @@ def with_container_arithmetic( if bcast_container_types is None: bcast_container_types = () - bcast_container_types_count = len(bcast_container_types) if np.ndarray in bcast_container_types and bcast_obj_array: raise ValueError("If numpy.ndarray is part of bcast_container_types, " "bcast_obj_array must be False.") + numpy_check_types: list[type] = [NumpyObjectArray, ComplainingNumpyNonObjectArray] + bcast_container_types = tuple( + new_ct + for old_ct in bcast_container_types + for new_ct in + (numpy_check_types + if old_ct is np.ndarray + else [old_ct]) + ) + desired_op_classes = set() if arithmetic: desired_op_classes.add(_OpClass.ARITHMETIC) @@ -264,10 +308,15 @@ def with_container_arithmetic( # }}} def wrap(cls: Any) -> Any: - cls_has_array_context_attr: bool | None = \ - _cls_has_array_context_attr - bcast_actx_array_type: bool | None = \ - _bcast_actx_array_type + 
if not hasattr(cls, "__array_ufunc__"): + warn(f"{cls} does not have __array_ufunc__ set. " + "This will cause numpy to attempt broadcasting, in a way that " + "is likely undesired. " + f"To avoid this, set __array_ufunc__ = None in {cls}.", + stacklevel=2) + + cls_has_array_context_attr: bool | None = _cls_has_array_context_attr + bcast_actx_array_type: bool | None = _bcast_actx_array_type if cls_has_array_context_attr is None: if hasattr(cls, "array_context"): @@ -275,8 +324,8 @@ def with_container_arithmetic( f"{cls} has an 'array_context' attribute, but it does not " "set '_cls_has_array_context_attr' to *True* when calling " "with_container_arithmetic. This is being interpreted " - "as 'array_context' being permitted to fail with an exception, " - "which is no longer allowed. " + "as '.array_context' being permitted to fail " + "with an exception, which is no longer allowed. " f"If {cls.__name__}.array_context will not fail, pass " "'_cls_has_array_context_attr=True'. " "If you do not want container arithmetic to make " @@ -294,6 +343,30 @@ def with_container_arithmetic( raise TypeError("_bcast_actx_array_type can be True only if " "_cls_has_array_context_attr is set.") + if bcast_actx_array_type: + if _bcast_actx_array_type: + warn( + f"Broadcasting array context array types across {cls} " + "has been explicitly " + "enabled. As of 2025, this will stop working. " + "There is no replacement as of right now. " + "See the discussion in " + "https://github.com/inducer/arraycontext/pull/190. " + "To opt out now (and avoid this warning), " + "pass _bcast_actx_array_type=False. ", + DeprecationWarning, stacklevel=2) + else: + warn( + f"Broadcasting array context array types across {cls} " + "has been implicitly " + "enabled. As of 2025, this will no longer work. " + "There is no replacement as of right now. " + "See the discussion in " + "https://github.com/inducer/arraycontext/pull/190. 
" + "To opt out now (and avoid this warning), " + "pass _bcast_actx_array_type=False.", + DeprecationWarning, stacklevel=2) + if (not hasattr(cls, "_serialize_init_arrays_code") or not hasattr(cls, "_deserialize_init_arrays_code")): raise TypeError(f"class '{cls.__name__}' must provide serialization " @@ -304,7 +377,7 @@ def with_container_arithmetic( from pytools.codegen import CodeGenerator, Indentation gen = CodeGenerator() - gen(""" + gen(f""" from numbers import Number import numpy as np from arraycontext import ArrayContainer @@ -315,6 +388,24 @@ def with_container_arithmetic( raise ValueError("array containers with frozen arrays " "cannot be operated upon") return actx + + def is_numpy_array(arg): + if isinstance(arg, np.ndarray): + if arg.dtype != "O": + warn("Operand is a non-object numpy array, " + "and the broadcasting behavior of this array container " + "({cls}) " + "is influenced by this because of its use of " + "the deprecated bcast_numpy_array. This broadcasting " + "behavior will change in 2025. 
If you would like the " + "broadcasting behavior to stay the same, make sure " + "to convert the passed numpy array to an " + "object array.", + DeprecationWarning, stacklevel=3) + return True + else: + return False + """) gen("") @@ -323,7 +414,7 @@ def with_container_arithmetic( gen(f"from {bct.__module__} import {bct.__qualname__} as _bctype{i}") gen("") outer_bcast_type_names = tuple([ - f"_bctype{i}" for i in range(bcast_container_types_count) + f"_bctype{i}" for i in range(len(bcast_container_types)) ]) if bcast_number: outer_bcast_type_names += ("Number",) @@ -384,8 +475,6 @@ def with_container_arithmetic( continue - # {{{ "forward" binary operators - zip_init_args = cls._deserialize_init_arrays_code("arg1", { same_key(key_arg1, key_arg2): _format_binary_op_str(op_str, expr_arg1, expr_arg2) @@ -393,11 +482,18 @@ def with_container_arithmetic( cls._serialize_init_arrays_code("arg1").items(), cls._serialize_init_arrays_code("arg2").items()) }) - bcast_same_cls_init_args = cls._deserialize_init_arrays_code("arg1", { + bcast_init_args_arg1_is_outer = cls._deserialize_init_arrays_code("arg1", { key_arg1: _format_binary_op_str(op_str, expr_arg1, "arg2") for key_arg1, expr_arg1 in cls._serialize_init_arrays_code("arg1").items() }) + bcast_init_args_arg2_is_outer = cls._deserialize_init_arrays_code("arg2", { + key_arg2: _format_binary_op_str(op_str, "arg1", expr_arg2) + for key_arg2, expr_arg2 in + cls._serialize_init_arrays_code("arg2").items() + }) + + # {{{ "forward" binary operators gen(f"def {fname}(arg1, arg2):") with Indentation(gen): @@ -424,7 +520,7 @@ def with_container_arithmetic( if bcast_actx_array_type: if __debug__: - bcast_actx_ary_types = ( + bcast_actx_ary_types: tuple[str, ...] 
= ( "*_raise_if_actx_none(" "arg1.array_context).array_types",) else: @@ -444,7 +540,19 @@ def with_container_arithmetic( if isinstance(arg2, {tup_str(outer_bcast_type_names + bcast_actx_ary_types)}): - return cls({bcast_same_cls_init_args}) + if __debug__: + if isinstance(arg2, {tup_str(bcast_actx_ary_types)}): + warn("Broadcasting {cls} over array " + f"context array type {{type(arg2)}} is deprecated " + "and will no longer work in 2025. " + "There is no replacement as of right now. " + "See the discussion in " + "https://github.com/inducer/arraycontext/" + "pull/190. ", + DeprecationWarning, stacklevel=2) + + return cls({bcast_init_args_arg1_is_outer}) + return NotImplemented """) gen(f"cls.__{dunder_name}__ = {fname}") @@ -456,12 +564,6 @@ def with_container_arithmetic( if reversible: fname = f"_{cls.__name__.lower()}_r{dunder_name}" - bcast_init_args = cls._deserialize_init_arrays_code("arg2", { - key_arg2: _format_binary_op_str( - op_str, "arg1", expr_arg2) - for key_arg2, expr_arg2 in - cls._serialize_init_arrays_code("arg2").items() - }) if bcast_actx_array_type: if __debug__: @@ -487,7 +589,21 @@ def with_container_arithmetic( if isinstance(arg1, {tup_str(outer_bcast_type_names + bcast_actx_ary_types)}): - return cls({bcast_init_args}) + if __debug__: + if isinstance(arg1, + {tup_str(bcast_actx_ary_types)}): + warn("Broadcasting {cls} over array " + f"context array type {{type(arg1)}} " + "is deprecated " + "and will no longer work in 2025." + "There is no replacement as of right now. " + "See the discussion in " + "https://github.com/inducer/arraycontext/" + "pull/190. 
", + DeprecationWarning, stacklevel=2) + + return cls({bcast_init_args_arg2_is_outer}) + return NotImplemented cls.__r{dunder_name}__ = {fname}""") diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 31b3bcf..100f077 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -43,6 +43,8 @@ Algebraic operations from __future__ import annotations +from arraycontext.container.arithmetic import NumpyObjectArray + __copyright__ = """ Copyright (C) 2020-1 University of Illinois Board of Trustees @@ -964,7 +966,7 @@ def outer(a: Any, b: Any) -> Any: return ( not isinstance(x, np.ndarray) # This condition is whether "ndarrays should broadcast inside x". - and np.ndarray not in x.__class__._outer_bcast_types) + and NumpyObjectArray not in x.__class__._outer_bcast_types) if treat_as_scalar(a) or treat_as_scalar(b): return a*b diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 94d7d74..868790f 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -22,6 +22,7 @@ THE SOFTWARE. 
import logging from dataclasses import dataclass +from functools import partial from typing import Union import numpy as np @@ -34,6 +35,7 @@ from arraycontext import ( ArrayContainer, ArrayContext, EagerJAXArrayContext, + NumpyArrayContext, PyOpenCLArrayContext, PytatoPyOpenCLArrayContext, dataclass_array_container, @@ -116,10 +118,10 @@ def _acf(): @with_container_arithmetic( bcast_obj_array=True, - bcast_numpy_array=True, bitwise=True, rel_comparison=True, - _cls_has_array_context_attr=True) + _cls_has_array_context_attr=True, + _bcast_actx_array_type=False) class DOFArray: def __init__(self, actx, data): if not (actx is None or isinstance(actx, ArrayContext)): @@ -207,7 +209,8 @@ def _with_actx_dofarray(ary: DOFArray, actx: ArrayContext) -> DOFArray: # type: @with_container_arithmetic(bcast_obj_array=False, eq_comparison=False, rel_comparison=False, - _cls_has_array_context_attr=True) + _cls_has_array_context_attr=True, + _bcast_actx_array_type=False) @dataclass_array_container @dataclass(frozen=True) class MyContainer: @@ -229,7 +232,8 @@ class MyContainer: bcast_container_types=(DOFArray, np.ndarray), matmul=True, rel_comparison=True, - _cls_has_array_context_attr=True) + _cls_has_array_context_attr=True, + _bcast_actx_array_type=False) @dataclass_array_container @dataclass(frozen=True) class MyContainerDOFBcast: @@ -936,8 +940,6 @@ def test_container_arithmetic(actx_factory): def _check_allclose(f, arg1, arg2, atol=5.0e-14): assert np.linalg.norm(actx.to_numpy(f(arg1) - arg2)) < atol - from functools import partial - from arraycontext import rec_multimap_array_container for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]: rec_multimap_array_container( @@ -1350,13 +1352,13 @@ def test_container_equality(actx_factory): # }}} -# {{{ test_leaf_array_type_broadcasting +# {{{ test_no_leaf_array_type_broadcasting @with_container_arithmetic( bcast_obj_array=True, - bcast_numpy_array=True, rel_comparison=True, - _cls_has_array_context_attr=True) + 
_cls_has_array_context_attr=True, + _bcast_actx_array_type=False) @dataclass_array_container @dataclass(frozen=True) class Foo: @@ -1369,39 +1371,19 @@ class Foo: return self.u.array_context -def test_leaf_array_type_broadcasting(actx_factory): - # test support for https://github.com/inducer/arraycontext/issues/49 +def test_no_leaf_array_type_broadcasting(actx_factory): + # test lack of support for https://github.com/inducer/arraycontext/issues/49 actx = actx_factory() - foo = Foo(DOFArray(actx, (actx.np.zeros(3, dtype=np.float64) + 41, ))) - bar = foo + 4 - baz = foo + actx.from_numpy(4*np.ones((3, ))) - qux = actx.from_numpy(4*np.ones((3, ))) + foo - - np.testing.assert_allclose(actx.to_numpy(bar.u[0]), - actx.to_numpy(baz.u[0])) - - np.testing.assert_allclose(actx.to_numpy(bar.u[0]), - actx.to_numpy(qux.u[0])) - - def _actx_allows_scalar_broadcast(actx): - if not isinstance(actx, PyOpenCLArrayContext): - return True - else: - import pyopencl as cl - - # See https://github.com/inducer/pyopencl/issues/498 - return cl.version.VERSION > (2021, 2, 5) - - if _actx_allows_scalar_broadcast(actx): - quux = foo + actx.from_numpy(np.array(4)) - quuz = actx.from_numpy(np.array(4)) + foo + dof_ary = DOFArray(actx, (actx.np.zeros(3, dtype=np.float64) + 41, )) + foo = Foo(dof_ary) - np.testing.assert_allclose(actx.to_numpy(bar.u[0]), - actx.to_numpy(quux.u[0])) + actx_ary = actx.from_numpy(4*np.ones((3, ))) + with pytest.raises(TypeError): + foo + actx_ary - np.testing.assert_allclose(actx.to_numpy(bar.u[0]), - actx.to_numpy(quuz.u[0])) + with pytest.raises(TypeError): + foo + actx.from_numpy(np.array(4)) # }}} -- GitLab