From 5187a0a5ab1be78b801f3f6941570f4d57f3fdd5 Mon Sep 17 00:00:00 2001
From: Matthias Diener <mdiener@illinois.edu>
Date: Sat, 17 Jul 2021 08:19:26 -0500
Subject: [PATCH] Cache codegen result in freeze() (#56)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* adds utils for normalizing an array expr

Co-authored-by: Andreas Kloeckner <andreask@illinois.edu>

* PytatoArrayContext: hold a cache of frozen arrays to programs

Co-authored-by: Matthias Diener <mdiener@illinois.edu>

* bugfix: change order of pt.make_placeholder

* Clarify hashing -> caching in normalization function docstring

Co-authored-by: Kaushik Kulkarni <kaushikcfd@gmail.com>
Co-authored-by: Andreas Kloeckner <andreask@illinois.edu>
Co-authored-by: Andreas Klöckner <inform@tiker.net>
---
 arraycontext/impl/pytato/__init__.py | 18 ++++--
 arraycontext/impl/pytato/utils.py    | 82 ++++++++++++++++++++++++++++
 2 files changed, 94 insertions(+), 6 deletions(-)
 create mode 100644 arraycontext/impl/pytato/utils.py

diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py
index 82309d9..dfcdc23 100644
--- a/arraycontext/impl/pytato/__init__.py
+++ b/arraycontext/impl/pytato/__init__.py
@@ -70,6 +70,7 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
         self.queue = queue
         self.allocator = allocator
         self.array_types = (pt.Array, )
+        self._freeze_prg_cache = {}
 
         # unused, but necessary to keep the context alive
         self.context = self.queue.context
@@ -113,9 +114,6 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
         return call_loopy(program, kwargs, entrypoint)
 
     def freeze(self, array):
-        # TODO: This should store a cache of pytato DAG -> build pyopencl
-        # program instead of re-compiling the DAG for every freeze.
-
         import pytato as pt
         import pyopencl.array as cla
 
@@ -125,10 +123,18 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
             raise TypeError("PytatoPyOpenCLArrayContext.freeze invoked with "
                             f"non-pytato array of type '{type(array)}'")
 
-        pt_prg = pt.generate_loopy(array, cl_device=self.queue.device)
-        pt_prg = pt_prg.with_transformed_program(self.transform_loopy_program)
+        from arraycontext.impl.pytato.utils import _normalize_pt_expr
+        normalized_expr, bound_arguments = _normalize_pt_expr(array)
+
+        try:
+            pt_prg = self._freeze_prg_cache[normalized_expr]
+        except KeyError:
+            pt_prg = pt.generate_loopy(normalized_expr, cl_device=self.queue.device)
+            pt_prg = pt_prg.with_transformed_program(self.transform_loopy_program)
+            self._freeze_prg_cache[normalized_expr] = pt_prg
 
-        evt, (cl_array,) = pt_prg(self.queue)
+        assert len(pt_prg.bound_arguments) == 0
+        evt, (cl_array,) = pt_prg(self.queue, **bound_arguments)
         evt.wait()
 
         return cl_array.with_queue(None)
diff --git a/arraycontext/impl/pytato/utils.py b/arraycontext/impl/pytato/utils.py
new file mode 100644
index 0000000..1e00c8d
--- /dev/null
+++ b/arraycontext/impl/pytato/utils.py
@@ -0,0 +1,82 @@
+__copyright__ = """
+Copyright (C) 2021 University of Illinois Board of Trustees
+"""
+
+__license__ = """
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+"""
+
+
+from typing import Any, Dict, Set, Tuple, Mapping
+from pytato.array import SizeParam, Placeholder
+from pytato.array import Array, DataWrapper
+from pytato.transform import CopyMapper
+from pytools import UniqueNameGenerator
+
+
+class _DatawrapperToBoundPlaceholderMapper(CopyMapper):
+    """
+    Helper mapper for :func:`normalize_pt_expr`. Every
+    :class:`pytato.DataWrapper` is replaced with a deterministic copy of
+    :class:`Placeholder`.
+    """
+    def __init__(self) -> None:
+        super().__init__()
+        self.bound_arguments: Dict[str, Any] = {}
+        self.vng = UniqueNameGenerator()
+        self.seen_inputs: Set[str] = set()
+
+    def map_data_wrapper(self, expr: DataWrapper) -> Array:
+        if expr.name is not None:
+            if expr.name in self.seen_inputs:
+                raise ValueError("Got multiple inputs with the name"
+                                 f"{expr.name} => Illegal.")
+            self.seen_inputs.add(expr.name)
+
+        # Normalizing names so that we more arrays can have the normalized DAG.
+        name = self.vng("_actx_dw")
+        self.bound_arguments[name] = expr.data
+        return Placeholder(name=name,
+                           shape=tuple(self.rec(s) if isinstance(s, Array) else s
+                                       for s in expr.shape),
+                           dtype=expr.dtype,
+                           tags=expr.tags)
+
+    def map_size_param(self, expr: SizeParam) -> Array:
+        raise NotImplementedError
+
+    def map_placeholder(self, expr: Placeholder) -> Array:
+        raise ValueError("Placeholders cannot appear in"
+                         " DatawrapperToBoundPlaceholderMapper.")
+
+
+def _normalize_pt_expr(expr: Array) -> Tuple[Array,
+                                            Mapping[str, Any]]:
+    """
+    Returns ``(normalized_expr, bound_arguments)``.  *normalized_expr* is a
+    normalized form of *expr*, with all instances of
+    :class:`pytato.DataWrapper` replaced with instances of :class:`Placeholder`
+    named in a deterministic manner. The data corresponding to the placeholders
+    in *normalized_expr* is recorded in the mapping *bound_arguments*.
+    Deterministic naming of placeholders permits more effective caching of
+    equivalent graphs.
+    """
+    normalize_mapper = _DatawrapperToBoundPlaceholderMapper()
+    normalized_expr = normalize_mapper(expr)
+    return normalized_expr, normalize_mapper.bound_arguments
-- 
GitLab