From d964b2833d495076366c07a1f73c750a75ac2b7c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Andreas=20Kl=C3=B6ckner?= <inform@tiker.net>
Date: Thu, 10 Jun 2021 22:13:30 -0500
Subject: [PATCH] Improve freeze/thaw usability (#22)

* Add a longer explainer on freeze and thaw

* Add ArrayContext.clone

* with_container_arithmetic: Add _same_cls_check

* Add test for error-on-mixed-array-contexts

* Improvements/fixes to freeze/thaw explainer

Co-authored-by: Alex Fikl <alexfikl@gmail.com>

* Remove trailing whitespace (flake8)

* Add link to lazy eval functionality under the hood of freeze/thaw

* Stop using code injection for actx match checking in with_container_arithmetic

* Make usage guidelines for freeze/thaw a separate section

Co-authored-by: Thomas H. Gibson <gibsonthomas1120@hotmail.com>

* Tweak phrasing in freeze/thaw usage guidelines, add anchors

* Remove an extraneous word in the freeze/thaw description

Co-authored-by: Alex Fikl <alexfikl@gmail.com>
Co-authored-by: Thomas H. Gibson <gibsonthomas1120@hotmail.com>
---
 arraycontext/container/arithmetic.py | 44 ++++++++-----
 arraycontext/context.py              | 93 ++++++++++++++++++++++++++++
 arraycontext/impl/pyopencl.py        |  3 +
 test/test_arraycontext.py            | 10 ++-
 4 files changed, 134 insertions(+), 16 deletions(-)

diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py
index db989c1..84cfa20 100644
--- a/arraycontext/container/arithmetic.py
+++ b/arraycontext/container/arithmetic.py
@@ -133,6 +133,7 @@ def with_container_arithmetic(
         matmul: bool = False,
         bitwise: bool = False,
         shift: bool = False,
+        _cls_has_array_context_attr: bool = False,
         eq_comparison: Optional[bool] = None,
         rel_comparison: Optional[bool] = None) -> Callable[[type], type]:
     """A class decorator that implements built-in operators for array containers
@@ -160,6 +161,11 @@ def with_container_arithmetic(
     :arg rel_comparison: If *True*, implement ``<``, ``<=``, ``>``, ``>=``.
         In that case, if *eq_comparison* is unspecified, it is also set to
         *True*.
+    :arg _cls_has_array_context_attr: A flag indicating whether the decorated
+        class has an ``array_context`` attribute. If so, and if :data:`__debug__`
+        is *True*, an additional check is performed in binary operators
+        to ensure that both containers use the same array context.
+        Consider this argument an unstable interface. It may disappear at any moment.
 
     Each operator class also includes the "reverse" operators if applicable.
 
@@ -245,7 +251,7 @@ def with_container_arithmetic(
                     "'_deserialize_init_arrays_code'. If this is a dataclass, "
                     "use the 'dataclass_array_container' decorator first.")
 
-        from pytools.codegen import CodeGenerator
+        from pytools.codegen import CodeGenerator, Indentation
         gen = CodeGenerator()
         gen("""
             from numbers import Number
@@ -317,20 +323,28 @@ def with_container_arithmetic(
                     cls._serialize_init_arrays_code("arg1").items()
                     })
 
-            gen(f"""
-                def {fname}(arg1, arg2):
-                    if arg2.__class__ is cls:
-                        return cls({zip_init_args})
-                    if {bool(outer_bcast_type_names)}:  # optimized away
-                        if isinstance(arg2, {tup_str(outer_bcast_type_names)}):
-                            return cls({bcast_init_args})
-                    if {numpy_pred("arg2")}:  # optimized away
-                        result = np.empty_like(arg2, dtype=object)
-                        for i in np.ndindex(arg2.shape):
-                            result[i] = {op_str.format("arg1", "arg2[i]")}
-                        return result
-                    return NotImplemented
-                cls.__{dunder_name}__ = {fname}""")
+            gen(f"def {fname}(arg1, arg2):")
+            with Indentation(gen):
+                gen("if arg2.__class__ is cls:")
+                with Indentation(gen):
+                    if __debug__ and _cls_has_array_context_attr:
+                        gen("""
+                            if arg1.array_context is not arg2.array_context:
+                                raise ValueError("array contexts of both arguments "
+                                    "must match")""")
+                    gen(f"return cls({zip_init_args})")
+                gen(f"""
+                if {bool(outer_bcast_type_names)}:  # optimized away
+                    if isinstance(arg2, {tup_str(outer_bcast_type_names)}):
+                        return cls({bcast_init_args})
+                if {numpy_pred("arg2")}:  # optimized away
+                    result = np.empty_like(arg2, dtype=object)
+                    for i in np.ndindex(arg2.shape):
+                        result[i] = {op_str.format("arg1", "arg2[i]")}
+                    return result
+                return NotImplemented
+                """)
+            gen(f"cls.__{dunder_name}__ = {fname}")
             gen("")
 
             # }}}
diff --git a/arraycontext/context.py b/arraycontext/context.py
index a4b2402..838f969 100644
--- a/arraycontext/context.py
+++ b/arraycontext/context.py
@@ -1,4 +1,78 @@
 """
+.. _freeze-thaw:
+
+Freezing and thawing
+--------------------
+
+One of the central concepts introduced by the array context formalism is
+the notion of :meth:`~arraycontext.ArrayContext.freeze` and
+:meth:`~arraycontext.ArrayContext.thaw`. Each array handled by the array context
+is either "thawed" or "frozen". Unlike the real-world concept of freezing and
+thawing, these operations leave the original array alone; instead, a semantically
+separate array in the desired state is returned.
+
+*   "Thawed" arrays are associated with an array context. They use that context
+    to carry out operations (arithmetic, function calls).
+
+*   "Frozen" arrays are static data. They are not associated with an array context,
+    and no operations can be performed on them.
+
+Freezing and thawing may be used to move arrays from one array context to another,
+as long as both array contexts use identical in-memory data representation.
+Otherwise, a common format must be agreed upon, for example using
+:mod:`numpy` through :meth:`~arraycontext.ArrayContext.to_numpy` and
+:meth:`~arraycontext.ArrayContext.from_numpy`.
+
+.. _freeze-thaw-guidelines:
+
+Usage guidelines
+^^^^^^^^^^^^^^^^
+Here are some rules of thumb to use when dealing with thawing and freezing:
+
+-   Any array that is stored for a long time needs to be frozen.
+    "Memoized" data (cf. :func:`pytools.memoize` and friends) is a good example
+    of long-lived data that should be frozen.
+
+-   Within a function, if the user did not supply an array context,
+    then any data returned to the user should be frozen.
+
+-   Note that array contexts need not necessarily be passed as a separate
+    argument. Passing thawed data as an argument to a function suffices
+    to supply an array context. The array context can be extracted from
+    a thawed argument using, e.g., :func:`~arraycontext.get_container_context`
+    or :func:`~arraycontext.get_container_context_recursively`.
+
+What does this mean concretely?
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Freezing and thawing are abstract names for concrete operations. It may be helpful
+to understand what these operations mean in the concrete case of various
+actual array contexts:
+
+-   Each :class:`~arraycontext.PyOpenCLArrayContext` is associated with a
+    :class:`pyopencl.CommandQueue`. In order to operate on array data,
+    such a command queue is necessary; it is the main means of synchronization
+    between the host program and the compute device. "Thawing" here
+    means associating an array with a command queue, and "freezing" means
+    ensuring that the array data is fully computed in memory and
+    decoupling the array from the command queue. It is not valid to "mix"
+    arrays associated with multiple queues within an operation: if it were allowed,
+    a dependent operation might begin computing before an input is fully
+    available. (Since bugs of this nature would be very difficult to
+    find, :class:`pyopencl.array.Array` and
+    :class:`~meshmode.dof_array.DOFArray` will not allow them.)
+
+-   For the lazily-evaluating array context based on :mod:`pytato`,
+    "thawing" corresponds to the creation of a symbolic "handle"
+    (specifically, a :class:`pytato.array.DataWrapper`) representing
+    the array that can then be used in computation, and "freezing"
+    corresponds to triggering (code generation and) evaluation of
+    an array expression that has been built up by the user
+    (using, e.g. :func:`pytato.generate_loopy`).
+
+The interface of an array context
+---------------------------------
+
 .. currentmodule:: arraycontext
 .. autoclass:: ArrayContext
 """
@@ -256,6 +330,25 @@ class ArrayContext(ABC):
             prg, **{arg_names[i]: arg for i, arg in enumerate(args)}
         )["out"]
 
+    @abstractmethod
+    def clone(self):
+        """If possible, return a version of *self* that is semantically
+        equivalent (i.e. implements all array operations in the same way)
+        but is a separate object. May return *self* if that is not possible.
+
+        .. note::
+
+            The main objective of this semi-documented method is to help
+            flag errors more clearly when array contexts are mixed that
+            should not be. For example, at the time of this writing,
+            :class:`meshmode.meshmode.Discretization` objects have a private
+            array context that is only to be used for setup-related tasks.
+            By using :meth:`clone` to make this a separate array context,
+            and by checking that arithmetic does not mix array contexts,
+            it becomes easier to detect and flag if unfrozen data attached to a
+            "setup-only" array context "leaks" into the application.
+        """
+
 # }}}
 
 # vim: foldmethod=marker
diff --git a/arraycontext/impl/pyopencl.py b/arraycontext/impl/pyopencl.py
index 1906509..c2726ad 100644
--- a/arraycontext/impl/pyopencl.py
+++ b/arraycontext/impl/pyopencl.py
@@ -388,6 +388,9 @@ class PyOpenCLArrayContext(ArrayContext):
         # Sorry, not capable.
         return array
 
+    def clone(self):
+        return type(self)(self.queue, self.allocator, self._wait_event_queue_length)
+
 # }}}
 
 # vim: foldmethod=marker
diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index aeadb70..30b7103 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -46,7 +46,8 @@ logger = logging.getLogger(__name__)
 @with_container_arithmetic(
         bcast_obj_array=True,
         bcast_numpy_array=True,
-        rel_comparison=True)
+        rel_comparison=True,
+        _cls_has_array_context_attr=True)
 class DOFArray:
     def __init__(self, actx, data):
         if not (actx is None or isinstance(actx, ArrayContext)):
@@ -624,6 +625,13 @@ def test_container_freeze_thaw(actx_factory):
         assert get_container_context_recursively(frozen_ary) is None
         assert get_container_context_recursively(thawed_ary) is actx
 
+    actx2 = actx.clone()
+
+    ary_dof_2 = thaw(freeze(ary_dof), actx2)
+
+    with pytest.raises(ValueError):
+        ary_dof + ary_dof_2
+
     # }}}
 
 
-- 
GitLab