From 458e2eaf3eae1c2a839583b68c281a525ee6827e Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Tue, 1 Mar 2022 15:32:59 -0600
Subject: [PATCH] Make get_container_context_recursively return actx or error

---
 arraycontext/__init__.py           | 10 +++++--
 arraycontext/container/__init__.py | 44 ++++++++++++++++++++++++------
 arraycontext/context.py            |  2 +-
 test/test_arraycontext.py          | 21 +++++++-------
 4 files changed, 55 insertions(+), 22 deletions(-)

diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py
index 6f7308d..de51ee2 100644
--- a/arraycontext/__init__.py
+++ b/arraycontext/__init__.py
@@ -40,7 +40,8 @@ from .metadata import _FirstAxisIsElementsTag
 from .container import (
         ArrayContainer, NotAnArrayContainerError,
         is_array_container, is_array_container_type,
-        get_container_context, get_container_context_recursively,
+        get_container_context_opt,
+        get_container_context_recursively, get_container_context_recursively_opt,
         serialize_container, deserialize_container,
         register_multivector_as_array_container)
 from .container.arithmetic import with_container_arithmetic
@@ -81,7 +82,9 @@ __all__ = (
 
         "ArrayContainer", "NotAnArrayContainerError",
         "is_array_container", "is_array_container_type",
-        "get_container_context", "get_container_context_recursively",
+        "get_container_context_opt",
+        "get_container_context_recursively_opt",
+        "get_container_context_recursively",
         "serialize_container", "deserialize_container",
         "register_multivector_as_array_container",
         "with_container_arithmetic",
@@ -122,6 +125,8 @@ def _deprecated_acf():
 
 
 _depr_name_to_replacement_and_obj = {
+        "get_container_context": ("get_container_context_opt",
+            get_container_context_opt),
         "FirstAxisIsElementsTag":
         ("meshmode.transform_metadata.FirstAxisIsElementsTag",
             _FirstAxisIsElementsTag),
@@ -145,6 +150,7 @@ if sys.version_info >= (3, 7):
 else:
     FirstAxisIsElementsTag = _FirstAxisIsElementsTag
     _acf = _deprecated_acf
+    get_container_context = get_container_context_opt
 
 # }}}
 
diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py
index fcf3ed1..789dd29 100644
--- a/arraycontext/container/__init__.py
+++ b/arraycontext/container/__init__.py
@@ -30,8 +30,9 @@ Serialization/deserialization
 
 Context retrieval
 -----------------
-.. autofunction:: get_container_context
+.. autofunction:: get_container_context_opt
 .. autofunction:: get_container_context_recursively
+.. autofunction:: get_container_context_recursively_opt
 
 :class:`~pymbolic.geometric_algebra.MultiVector` support
 ---------------------------------------------------------
@@ -92,7 +93,7 @@ class ArrayContainer:
       of the array.
     * :func:`deserialize_container` for deserialization, which constructs a
       container from a set of components.
-    * :func:`get_container_context` retrieves the :class:`ArrayContext` from
+    * :func:`get_container_context_opt` retrieves the :class:`ArrayContext` from
       a container, if it has one.
 
     This allows enumeration of the component arrays in a container and the
@@ -198,7 +199,7 @@ def is_array_container(ary: Any) -> bool:
 
 
 @singledispatch
-def get_container_context(ary: ArrayContainer) -> Optional[ArrayContext]:
+def get_container_context_opt(ary: ArrayContainer) -> Optional[ArrayContext]:
     """Retrieves the :class:`ArrayContext` from the container, if any.
 
     This function is not recursive, so it will only search at the root level
@@ -249,15 +250,18 @@ def _deserialize_ndarray_container(
 
 # {{{ get_container_context_recursively
 
-def get_container_context_recursively(ary: Any) -> Optional[ArrayContext]:
+def get_container_context_recursively_opt(
+        ary: ArrayContainer) -> Optional[ArrayContext]:
     """Walks the :class:`ArrayContainer` hierarchy to find an
     :class:`ArrayContext` associated with it.
 
     If different components that have different array contexts are found at
     any level, an assertion error is raised.
+
+    Returns *None* if no array context was found.
     """
     # try getting the array context directly
-    actx = get_container_context(ary)
+    actx = get_container_context_opt(ary)
     if actx is not None:
         return actx
 
@@ -267,7 +271,7 @@ def get_container_context_recursively(ary: Any) -> Optional[ArrayContext]:
         return actx
     else:
         for _, subary in iterable:
-            context = get_container_context_recursively(subary)
+            context = get_container_context_recursively_opt(subary)
             if context is None:
                 continue
 
@@ -280,6 +284,28 @@ def get_container_context_recursively(ary: Any) -> Optional[ArrayContext]:
 
         return actx
 
+
+def get_container_context_recursively(ary: ArrayContainer) -> Optional[ArrayContext]:
+    """Walks the :class:`ArrayContainer` hierarchy to find an
+    :class:`ArrayContext` associated with it.
+
+    If different components that have different array contexts are found at
+    any level, an assertion error is raised.
+
+    Raises an error if no array container is found.
+    """
+    actx = get_container_context_recursively_opt(ary)
+    if actx is None:
+        # raise ValueError("no array context was found")
+        from warnings import warn
+        warn("No array context was found. This will be an error starting in "
+                "July of 2022. If you would like the function to return "
+                "None if no array context was found, use "
+                "get_container_context_recursively_opt.",
+                DeprecationWarning, stacklevel=2)
+
+    return actx
+
 # }}}
 
 
@@ -298,7 +324,7 @@ def _deserialize_multivec_as_container(template: "MultiVector",
     return MultiVector(dict(iterable), space=template.space)
 
 
-def _get_container_context_from_multivec(mv: "MultiVector") -> None:
+def _get_container_context_opt_from_multivec(mv: "MultiVector") -> None:
     return None
 
 
@@ -312,8 +338,8 @@ def register_multivector_as_array_container() -> None:
         serialize_container.register(MultiVector)(_serialize_multivec_as_container)
         deserialize_container.register(MultiVector)(
                 _deserialize_multivec_as_container)
-        get_container_context.register(MultiVector)(
-                _get_container_context_from_multivec)
+        get_container_context_opt.register(MultiVector)(
+                _get_container_context_opt_from_multivec)
         assert MultiVector in serialize_container.registry
 
 # }}}
diff --git a/arraycontext/context.py b/arraycontext/context.py
index 2bff0ac..a127b16 100644
--- a/arraycontext/context.py
+++ b/arraycontext/context.py
@@ -39,7 +39,7 @@ Here are some rules of thumb to use when dealing with thawing and freezing:
 -   Note that array contexts need not necessarily be passed as a separate
     argument. Passing thawed data as an argument to a function suffices
     to supply an array context. The array context can be extracted from
-    a thawed argument using, e.g., :func:`~arraycontext.get_container_context`
+    a thawed argument using, e.g., :func:`~arraycontext.get_container_context_opt`
     or :func:`~arraycontext.get_container_context_recursively`.
 
 What does this mean concretely?
diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index 97f853e..92195f5 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -941,24 +941,25 @@ def test_container_freeze_thaw(actx_factory):
 
     # {{{ check
 
-    from arraycontext import get_container_context
-    from arraycontext import get_container_context_recursively
+    from arraycontext import (
+            get_container_context_opt,
+            get_container_context_recursively_opt)
 
-    assert get_container_context(ary_of_dofs) is None
-    assert get_container_context(mat_of_dofs) is None
-    assert get_container_context(ary_dof) is actx
-    assert get_container_context(dc_of_dofs) is actx
+    assert get_container_context_opt(ary_of_dofs) is None
+    assert get_container_context_opt(mat_of_dofs) is None
+    assert get_container_context_opt(ary_dof) is actx
+    assert get_container_context_opt(dc_of_dofs) is actx
 
-    assert get_container_context_recursively(ary_of_dofs) is actx
-    assert get_container_context_recursively(mat_of_dofs) is actx
+    assert get_container_context_recursively_opt(ary_of_dofs) is actx
+    assert get_container_context_recursively_opt(mat_of_dofs) is actx
 
     for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]:
         frozen_ary = freeze(ary)
         thawed_ary = thaw(frozen_ary, actx)
         frozen_ary = freeze(thawed_ary)
 
-        assert get_container_context_recursively(frozen_ary) is None
-        assert get_container_context_recursively(thawed_ary) is actx
+        assert get_container_context_recursively_opt(frozen_ary) is None
+        assert get_container_context_recursively_opt(thawed_ary) is actx
 
     actx2 = actx.clone()
 
-- 
GitLab