From bbcc7becf07b4d1534060f8f23fad0a920b18b9c Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Fri, 12 Jul 2024 16:37:34 -0500 Subject: [PATCH] Array container arithemtic: drop deprecated fail-safe actx retrieval --- arraycontext/container/arithmetic.py | 107 ++++++--------------------- 1 file changed, 23 insertions(+), 84 deletions(-) diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index 663cdde..dbfdd5a 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -1,4 +1,6 @@ # mypy: disallow-untyped-defs +from __future__ import annotations + """ .. currentmodule:: arraycontext @@ -32,7 +34,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from typing import Any, Callable, Optional, Tuple, Type, TypeVar, Union +from typing import Any, Callable, Optional, Tuple, TypeVar, Union import numpy as np @@ -125,10 +127,6 @@ def _format_binary_op_str(op_str: str, return op_str.format(arg1, arg2) -class _FailSafe: - pass - - def with_container_arithmetic( *, bcast_number: bool = True, @@ -266,34 +264,28 @@ def with_container_arithmetic( # }}} def wrap(cls: Any) -> Any: - cls_has_array_context_attr: Optional[Union[bool, Type[_FailSafe]]] = \ + cls_has_array_context_attr: bool | None = \ _cls_has_array_context_attr - bcast_actx_array_type: Optional[Union[bool, Type[_FailSafe]]] = \ + bcast_actx_array_type: bool | None = \ _bcast_actx_array_type if cls_has_array_context_attr is None: if hasattr(cls, "array_context"): - cls_has_array_context_attr = _FailSafe - warn(f"{cls} has an 'array_context' attribute, but it does not " - "set '_cls_has_array_context_attr' to 'True' when calling " - "'with_container_arithmetic'. This is being interpreted " - "as 'array_context' being permitted to fail. Tolerating " - "these failures comes at a substantial cost. It is " - "deprecated and will stop working in 2023. " - "Having a working 'array_context' attribute is desirable " - "to enable arithmetic with other array types supported " - "by the array context. " - f"If '{cls.__name__}.array_context' will not fail, pass " + raise TypeError( + f"{cls} has an 'array_context' attribute, but it does not " + "set '_cls_has_array_context_attr' to *True* when calling " + "with_container_arithmetic. This is being interpreted " + "as 'array_context' being permitted to fail with an exception, " + "which is no longer allowed. " + f"If {cls.__name__}.array_context will not fail, pass " "'_cls_has_array_context_attr=True'. " "If you do not want container arithmetic to make " "use of the array context, set " - "'_cls_has_array_context_attr=False'.", - stacklevel=2) + "'_cls_has_array_context_attr=False'.") if bcast_actx_array_type is None: if cls_has_array_context_attr: if bcast_number: - # copy over _FailSafe if present bcast_actx_array_type = cls_has_array_context_attr else: bcast_actx_array_type = False @@ -310,20 +302,12 @@ def with_container_arithmetic( "'_deserialize_init_arrays_code'. If this is a dataclass, " "use the 'dataclass_array_container' decorator first.") - if cls_has_array_context_attr is _FailSafe: - def actx_getter_code(arg: str) -> str: - return f"_get_actx({arg})" - else: - def actx_getter_code(arg: str) -> str: - return f"{arg}.array_context" - from pytools.codegen import CodeGenerator, Indentation gen = CodeGenerator() gen(""" from numbers import Number import numpy as np - from arraycontext import ( - ArrayContainer, get_container_context_recursively) + from arraycontext import ArrayContainer from warnings import warn def _raise_if_actx_none(actx): @@ -331,45 +315,6 @@ def with_container_arithmetic( raise ValueError("array containers with frozen arrays " "cannot be operated upon") return actx - - def _get_actx(ary): - try: - return ary.array_context - except Exception as e: - warn(f"Accessing '{type(ary).__name__}.array_context' failed " - f"({type(e)}: {e}). This should not happen and is " - "deprecated. " - "Please fix the implementation of " - f"'{type(ary).__name__}.array_context' " - "and then set _cls_has_array_context_attr=True when " - "calling with_container_arithmetic to avoid the run time " - "cost of the check that gave you this warning. " - "Using expensive recovery for now.", - DeprecationWarning, stacklevel=3) - - return get_container_context_recursively(ary) - - def _get_actx_array_types_failsafe(ary): - try: - actx = ary.array_context - except Exception as e: - warn(f"Accessing '{type(ary).__name__}.array_context' failed " - f"({type(e)}: {e}). This should not happen and is " - "deprecated. " - "Please fix the implementation of " - f"'{type(ary).__name__}.array_context' " - "and then set _cls_has_array_context_attr=True when " - "calling with_container_arithmetic to avoid the run time " - "cost of the check that gave you this warning. " - "Using expensive recovery for now.", - DeprecationWarning, stacklevel=3) - - actx = get_container_context_recursively(ary) - - if actx is None: - return () - - return actx.array_types """) gen("") @@ -459,9 +404,9 @@ def with_container_arithmetic( gen("if arg2.__class__ is cls:") with Indentation(gen): if __debug__ and cls_has_array_context_attr: - gen(f""" - arg1_actx = {actx_getter_code("arg1")} - arg2_actx = {actx_getter_code("arg2")} + gen(""" + arg1_actx = arg1.array_context + arg2_actx = arg2.array_context if arg1_actx is not arg2_actx: msg = ("array contexts of both arguments " "must match") @@ -477,17 +422,14 @@ def with_container_arithmetic( raise ValueError(msg)""") gen(f"return cls({zip_init_args})") - if bcast_actx_array_type is _FailSafe: - bcast_actx_ary_types: Tuple[str, ...] = ( - "*_get_actx_array_types_failsafe(arg1)",) - elif bcast_actx_array_type: + if bcast_actx_array_type: if __debug__: bcast_actx_ary_types = ( "*_raise_if_actx_none(" - f"{actx_getter_code('arg1')}).array_types",) + "arg1.array_context).array_types",) else: bcast_actx_ary_types = ( - f"*{actx_getter_code('arg1')}.array_types",) + "*arg1.array_context.array_types",) else: bcast_actx_ary_types = () @@ -521,17 +463,14 @@ def with_container_arithmetic( cls._serialize_init_arrays_code("arg2").items() }) - if bcast_actx_array_type is _FailSafe: - bcast_actx_ary_types = ( - "*_get_actx_array_types_failsafe(arg2)",) - elif bcast_actx_array_type: + if bcast_actx_array_type: if __debug__: bcast_actx_ary_types = ( "*_raise_if_actx_none(" - f"{actx_getter_code('arg2')}).array_types",) + "arg2.array_context).array_types",) else: bcast_actx_ary_types = ( - f"*{actx_getter_code('arg2')}.array_types",) + "*arg2.array_context.array_types",) else: bcast_actx_ary_types = () -- GitLab