From 458e2eaf3eae1c2a839583b68c281a525ee6827e Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Tue, 1 Mar 2022 15:32:59 -0600 Subject: [PATCH] Make get_container_context_recursively return actx or error --- arraycontext/__init__.py | 10 +++++-- arraycontext/container/__init__.py | 44 ++++++++++++++++++++++++------ arraycontext/context.py | 2 +- test/test_arraycontext.py | 21 +++++++------- 4 files changed, 55 insertions(+), 22 deletions(-) diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 6f7308d..de51ee2 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -40,7 +40,8 @@ from .metadata import _FirstAxisIsElementsTag from .container import ( ArrayContainer, NotAnArrayContainerError, is_array_container, is_array_container_type, - get_container_context, get_container_context_recursively, + get_container_context_opt, + get_container_context_recursively, get_container_context_recursively_opt, serialize_container, deserialize_container, register_multivector_as_array_container) from .container.arithmetic import with_container_arithmetic @@ -81,7 +82,9 @@ __all__ = ( "ArrayContainer", "NotAnArrayContainerError", "is_array_container", "is_array_container_type", - "get_container_context", "get_container_context_recursively", + "get_container_context_opt", + "get_container_context_recursively_opt", + "get_container_context_recursively", "serialize_container", "deserialize_container", "register_multivector_as_array_container", "with_container_arithmetic", @@ -122,6 +125,8 @@ def _deprecated_acf(): _depr_name_to_replacement_and_obj = { + "get_container_context": ("get_container_context_opt", + get_container_context_opt), "FirstAxisIsElementsTag": ("meshmode.transform_metadata.FirstAxisIsElementsTag", _FirstAxisIsElementsTag), @@ -145,6 +150,7 @@ if sys.version_info >= (3, 7): else: FirstAxisIsElementsTag = _FirstAxisIsElementsTag _acf = _deprecated_acf + get_container_context = get_container_context_opt # }}} diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py index fcf3ed1..789dd29 100644 --- a/arraycontext/container/__init__.py +++ b/arraycontext/container/__init__.py @@ -30,8 +30,9 @@ Serialization/deserialization Context retrieval ----------------- -.. autofunction:: get_container_context +.. autofunction:: get_container_context_opt .. autofunction:: get_container_context_recursively +.. autofunction:: get_container_context_recursively_opt :class:`~pymbolic.geometric_algebra.MultiVector` support --------------------------------------------------------- @@ -92,7 +93,7 @@ class ArrayContainer: of the array. * :func:`deserialize_container` for deserialization, which constructs a container from a set of components. - * :func:`get_container_context` retrieves the :class:`ArrayContext` from + * :func:`get_container_context_opt` retrieves the :class:`ArrayContext` from a container, if it has one. This allows enumeration of the component arrays in a container and the @@ -198,7 +199,7 @@ def is_array_container(ary: Any) -> bool: @singledispatch -def get_container_context(ary: ArrayContainer) -> Optional[ArrayContext]: +def get_container_context_opt(ary: ArrayContainer) -> Optional[ArrayContext]: """Retrieves the :class:`ArrayContext` from the container, if any. This function is not recursive, so it will only search at the root level @@ -249,15 +250,18 @@ def _deserialize_ndarray_container( # {{{ get_container_context_recursively -def get_container_context_recursively(ary: Any) -> Optional[ArrayContext]: +def get_container_context_recursively_opt( + ary: ArrayContainer) -> Optional[ArrayContext]: """Walks the :class:`ArrayContainer` hierarchy to find an :class:`ArrayContext` associated with it. If different components that have different array contexts are found at any level, an assertion error is raised. + + Returns *None* if no array context was found. """ # try getting the array context directly - actx = get_container_context(ary) + actx = get_container_context_opt(ary) if actx is not None: return actx @@ -267,7 +271,7 @@ def get_container_context_recursively(ary: Any) -> Optional[ArrayContext]: return actx else: for _, subary in iterable: - context = get_container_context_recursively(subary) + context = get_container_context_recursively_opt(subary) if context is None: continue @@ -280,6 +284,28 @@ def get_container_context_recursively(ary: Any) -> Optional[ArrayContext]: return actx + +def get_container_context_recursively(ary: ArrayContainer) -> Optional[ArrayContext]: + """Walks the :class:`ArrayContainer` hierarchy to find an + :class:`ArrayContext` associated with it. + + If different components that have different array contexts are found at + any level, an assertion error is raised. + + Raises an error if no array container is found. + """ + actx = get_container_context_recursively_opt(ary) + if actx is None: + # raise ValueError("no array context was found") + from warnings import warn + warn("No array context was found. This will be an error starting in " + "July of 2022. If you would like the function to return " + "None if no array context was found, use " + "get_container_context_recursively_opt.", + DeprecationWarning, stacklevel=2) + + return actx + # }}} @@ -298,7 +324,7 @@ def _deserialize_multivec_as_container(template: "MultiVector", return MultiVector(dict(iterable), space=template.space) -def _get_container_context_from_multivec(mv: "MultiVector") -> None: +def _get_container_context_opt_from_multivec(mv: "MultiVector") -> None: return None @@ -312,8 +338,8 @@ def register_multivector_as_array_container() -> None: serialize_container.register(MultiVector)(_serialize_multivec_as_container) deserialize_container.register(MultiVector)( _deserialize_multivec_as_container) - get_container_context.register(MultiVector)( - _get_container_context_from_multivec) + get_container_context_opt.register(MultiVector)( + _get_container_context_opt_from_multivec) assert MultiVector in serialize_container.registry # }}} diff --git a/arraycontext/context.py b/arraycontext/context.py index 2bff0ac..a127b16 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -39,7 +39,7 @@ Here are some rules of thumb to use when dealing with thawing and freezing: - Note that array contexts need not necessarily be passed as a separate argument. Passing thawed data as an argument to a function suffices to supply an array context. The array context can be extracted from - a thawed argument using, e.g., :func:`~arraycontext.get_container_context` + a thawed argument using, e.g., :func:`~arraycontext.get_container_context_opt` or :func:`~arraycontext.get_container_context_recursively`. What does this mean concretely? diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 97f853e..92195f5 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -941,24 +941,25 @@ def test_container_freeze_thaw(actx_factory): # {{{ check - from arraycontext import get_container_context - from arraycontext import get_container_context_recursively + from arraycontext import ( + get_container_context_opt, + get_container_context_recursively_opt) - assert get_container_context(ary_of_dofs) is None - assert get_container_context(mat_of_dofs) is None - assert get_container_context(ary_dof) is actx - assert get_container_context(dc_of_dofs) is actx + assert get_container_context_opt(ary_of_dofs) is None + assert get_container_context_opt(mat_of_dofs) is None + assert get_container_context_opt(ary_dof) is actx + assert get_container_context_opt(dc_of_dofs) is actx - assert get_container_context_recursively(ary_of_dofs) is actx - assert get_container_context_recursively(mat_of_dofs) is actx + assert get_container_context_recursively_opt(ary_of_dofs) is actx + assert get_container_context_recursively_opt(mat_of_dofs) is actx for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]: frozen_ary = freeze(ary) thawed_ary = thaw(frozen_ary, actx) frozen_ary = freeze(thawed_ary) - assert get_container_context_recursively(frozen_ary) is None - assert get_container_context_recursively(thawed_ary) is actx + assert get_container_context_recursively_opt(frozen_ary) is None + assert get_container_context_recursively_opt(thawed_ary) is actx actx2 = actx.clone() -- GitLab