From 2a0006cefb85bf562762c25d22eecf118b5cffcb Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 13 Apr 2022 00:28:07 -0500 Subject: [PATCH] with_container_arithmetic: Default cls_has_array_context_attr based on presence of attribute --- arraycontext/container/arithmetic.py | 143 ++++++++++++++++++++++----- test/test_arraycontext.py | 24 +++-- 2 files changed, 134 insertions(+), 33 deletions(-) diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index df5b662..5e2ade2 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -31,7 +31,8 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from typing import Any, Callable, Optional, Tuple, TypeVar, Union +from warnings import warn +from typing import Any, Callable, Optional, Tuple, TypeVar, Union, Type import numpy as np @@ -124,6 +125,10 @@ def _format_binary_op_str(op_str: str, return op_str.format(arg1, arg2) +class _FailSafe: + pass + + def with_container_arithmetic( *, bcast_number: bool = True, @@ -135,7 +140,7 @@ def with_container_arithmetic( matmul: bool = False, bitwise: bool = False, shift: bool = False, - _cls_has_array_context_attr: bool = False, + _cls_has_array_context_attr: Optional[bool] = None, eq_comparison: Optional[bool] = None, rel_comparison: Optional[bool] = None) -> Callable[[type], type]: """A class decorator that implements built-in operators for array containers @@ -172,6 +177,8 @@ def with_container_arithmetic( class has an ``array_context`` attribute. If so, and if :data:`__debug__` is *True*, an additional check is performed in binary operators to ensure that both containers use the same array context. + If *None* (the default), this value is set based on whether the class + has an ``array_context`` attribute. Consider this argument an unstable interface. It may disappear at any moment. Each operator class also includes the "reverse" operators if applicable. @@ -215,16 +222,6 @@ def with_container_arithmetic( if not bcast_obj_array and bcast_numpy_array: raise TypeError("bcast_obj_array must be set if bcast_numpy_array is") - if _bcast_actx_array_type is None: - if _cls_has_array_context_attr: - _bcast_actx_array_type = bcast_number - else: - _bcast_actx_array_type = False - else: - if _bcast_actx_array_type and not _cls_has_array_context_attr: - raise TypeError("_bcast_actx_array_type can be True only if " - "_cls_has_array_context_attr is set.") - if bcast_numpy_array: def numpy_pred(name: str) -> str: return f"isinstance({name}, np.ndarray)" @@ -260,6 +257,42 @@ def with_container_arithmetic( # }}} def wrap(cls: Any) -> Any: + cls_has_array_context_attr: Optional[Union[bool, Type[_FailSafe]]] = \ + _cls_has_array_context_attr + bcast_actx_array_type: Optional[Union[bool, Type[_FailSafe]]] = \ + _bcast_actx_array_type + + if cls_has_array_context_attr is None: + if hasattr(cls, "array_context"): + cls_has_array_context_attr = _FailSafe + warn(f"{cls} has an .array_context attribute, but it does not " + "set _cls_has_array_context_attr to True when calling " + "with_container_arithmetic. This is being interpreted " + "as .array_context being permitted to fail. Tolerating " + "these failures comes at a substantial cost. It is " + "deprecated and will stop working in 2023. " + "Having a working .array_context attribute is desirable " + "to enable arithmetic with other array types supported " + "by the array context." + f"If {cls}.array_context will not fail, pass " + "_cls_has_array_context_attr=True. " + "If you do not want container arithmetic to make " + "use of the array context, set " + "_cls_has_array_context_attr=False.", + stacklevel=2) + + if bcast_actx_array_type is None: + if cls_has_array_context_attr: + if bcast_number: + # copy over _FailSafe if present + bcast_actx_array_type = cls_has_array_context_attr + else: + bcast_actx_array_type = False + else: + if bcast_actx_array_type and not cls_has_array_context_attr: + raise TypeError("_bcast_actx_array_type can be True only if " + "_cls_has_array_context_attr is set.") + if (not hasattr(cls, "_serialize_init_arrays_code") or not hasattr(cls, "_deserialize_init_arrays_code")): raise TypeError(f"class '{cls.__name__}' must provide serialization " @@ -268,18 +301,66 @@ def with_container_arithmetic( "'_deserialize_init_arrays_code'. If this is a dataclass, " "use the 'dataclass_array_container' decorator first.") + if cls_has_array_context_attr is _FailSafe: + def actx_getter_code(arg: str) -> str: + return f"_get_actx({arg})" + else: + def actx_getter_code(arg: str) -> str: + return f"{arg}.array_context" + from pytools.codegen import CodeGenerator, Indentation gen = CodeGenerator() gen(""" from numbers import Number import numpy as np - from arraycontext import ArrayContainer + from arraycontext import ( + ArrayContainer, get_container_context_recursively) + from warnings import warn def _raise_if_actx_none(actx): if actx is None: raise ValueError("array containers with frozen arrays " "cannot be operated upon") return actx + + def _get_actx(ary): + try: + return ary.array_context + except Exception as e: + warn(f"Accessing '{type(ary).__name__}.array_context' failed " + f"({type(e)}: {e}). This should not happen and is " + "deprecated. " + "Please fix the implementation of " + f"'{type(ary).__name__}.array_context' " + "and then set _cls_has_array_context_attr=True when " + "calling with_container_arithmetic to avoid the run time " + "cost of the check that gave you this warning. " + "Using expensive recovery for now.", + DeprecationWarning, stacklevel=3) + + return get_container_context_recursively(ary) + + def _get_actx_array_types_failsafe(ary): + try: + actx = ary.array_context + except Exception as e: + warn(f"Accessing '{type(ary).__name__}.array_context' failed " + f"({type(e)}: {e}). This should not happen and is " + "deprecated. " + "Please fix the implementation of " + f"'{type(ary).__name__}.array_context' " + "and then set _cls_has_array_context_attr=True when " + "calling with_container_arithmetic to avoid the run time " + "cost of the check that gave you this warning. " + "Using expensive recovery for now.", + DeprecationWarning, stacklevel=3) + + actx = get_container_context_recursively(ary) + + if actx is None: + return () + + return actx.array_types """) gen("") @@ -368,16 +449,18 @@ def with_container_arithmetic( with Indentation(gen): gen("if arg2.__class__ is cls:") with Indentation(gen): - if __debug__ and _cls_has_array_context_attr: - gen(""" - if arg1.array_context is not arg2.array_context: + if __debug__ and cls_has_array_context_attr: + gen(f""" + arg1_actx = {actx_getter_code("arg1")} + arg2_actx = {actx_getter_code("arg2")} + if arg1_actx is not arg2_actx: msg = ("array contexts of both arguments " "must match") - if arg1.array_context is None: + if arg1_actx is None: raise ValueError(msg + ": left operand is frozen " "(i.e. has no array context)") - elif arg2.array_context is None: + elif arg2_actx is None: raise ValueError(msg + ": right operand is frozen " "(i.e. has no array context)") @@ -385,12 +468,17 @@ def with_container_arithmetic( raise ValueError(msg)""") gen(f"return cls({zip_init_args})") - if _bcast_actx_array_type: + if bcast_actx_array_type is _FailSafe: + bcast_actx_ary_types: Tuple[str, ...] = ( + "*_get_actx_array_types_failsafe(arg1)",) + elif bcast_actx_array_type: if __debug__: - bcast_actx_ary_types: Tuple[str, ...] = ( - "*_raise_if_actx_none(arg1.array_context).array_types",) + bcast_actx_ary_types = ( + "*_raise_if_actx_none(" + f"{actx_getter_code('arg1')}).array_types",) else: - bcast_actx_ary_types = ("*arg1.array_context.array_types",) + bcast_actx_ary_types = ( + f"*{actx_getter_code('arg1')}.array_types",) else: bcast_actx_ary_types = () @@ -423,12 +511,17 @@ def with_container_arithmetic( cls._serialize_init_arrays_code("arg2").items() }) - if _bcast_actx_array_type: + if bcast_actx_array_type is _FailSafe: + bcast_actx_ary_types = ( + "*_get_actx_array_types_failsafe(arg2)",) + elif bcast_actx_array_type: if __debug__: bcast_actx_ary_types = ( - "*_raise_if_actx_none(arg2.array_context).array_types",) + "*_raise_if_actx_none(" + f"{actx_getter_code('arg2')}).array_types",) else: - bcast_actx_ary_types = ("*arg2.array_context.array_types",) + bcast_actx_ary_types = ( + f"*{actx_getter_code('arg2')}.array_types",) else: bcast_actx_ary_types = () diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 649f5fb..cd61120 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -672,36 +672,44 @@ def test_array_context_einsum_array_tripleprod(actx_factory, spec): # {{{ array container classes for test @with_container_arithmetic(bcast_obj_array=False, - eq_comparison=False, rel_comparison=False) + eq_comparison=False, rel_comparison=False, + _cls_has_array_context_attr=True) @dataclass_array_container @dataclass(frozen=True) class MyContainer: name: str - mass: DOFArray + mass: DOFArray # or np.ndarray momentum: np.ndarray - enthalpy: DOFArray + enthalpy: DOFArray # or np.ndarray @property def array_context(self): - return self.mass.array_context + if isinstance(self.mass, np.ndarray): + return next(iter(self.mass)).array_context + else: + return self.mass.array_context @with_container_arithmetic( bcast_obj_array=False, bcast_container_types=(DOFArray, np.ndarray), matmul=True, - rel_comparison=True,) + rel_comparison=True, + _cls_has_array_context_attr=True) @dataclass_array_container @dataclass(frozen=True) class MyContainerDOFBcast: name: str - mass: DOFArray + mass: DOFArray # or np.ndarray momentum: np.ndarray - enthalpy: DOFArray + enthalpy: DOFArray # or np.ndarray @property def array_context(self): - return self.mass.array_context + if isinstance(self.mass, np.ndarray): + return next(iter(self.mass)).array_context + else: + return self.mass.array_context def _get_test_containers(actx, ambient_dim=2, shapes=50_000): -- GitLab