From 5187a0a5ab1be78b801f3f6941570f4d57f3fdd5 Mon Sep 17 00:00:00 2001 From: Matthias Diener <mdiener@illinois.edu> Date: Sat, 17 Jul 2021 08:19:26 -0500 Subject: [PATCH] Cache codegen result in freeze() (#56) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * adds utils for normalizing an array expr Co-authored-by: Andreas Kloeckner <andreask@illinois.edu> * PytatoArrayContext: hold a cache of frozen arrays to programs Co-authored-by: Matthias Diener <mdiener@illinois.edu> * bugfix: change order of pt.make_placeholder * Clarify hashing -> caching in normalization function docstring Co-authored-by: Kaushik Kulkarni <kaushikcfd@gmail.com> Co-authored-by: Andreas Kloeckner <andreask@illinois.edu> Co-authored-by: Andreas Klöckner <inform@tiker.net> --- arraycontext/impl/pytato/__init__.py | 18 ++++-- arraycontext/impl/pytato/utils.py | 82 ++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+), 6 deletions(-) create mode 100644 arraycontext/impl/pytato/utils.py diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 82309d9..dfcdc23 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -70,6 +70,7 @@ class PytatoPyOpenCLArrayContext(ArrayContext): self.queue = queue self.allocator = allocator self.array_types = (pt.Array, ) + self._freeze_prg_cache = {} # unused, but necessary to keep the context alive self.context = self.queue.context @@ -113,9 +114,6 @@ class PytatoPyOpenCLArrayContext(ArrayContext): return call_loopy(program, kwargs, entrypoint) def freeze(self, array): - # TODO: This should store a cache of pytato DAG -> build pyopencl - # program instead of re-compiling the DAG for every freeze. - import pytato as pt import pyopencl.array as cla @@ -125,10 +123,18 @@ class PytatoPyOpenCLArrayContext(ArrayContext): raise TypeError("PytatoPyOpenCLArrayContext.freeze invoked with " f"non-pytato array of type '{type(array)}'") - pt_prg = pt.generate_loopy(array, cl_device=self.queue.device) - pt_prg = pt_prg.with_transformed_program(self.transform_loopy_program) + from arraycontext.impl.pytato.utils import _normalize_pt_expr + normalized_expr, bound_arguments = _normalize_pt_expr(array) + + try: + pt_prg = self._freeze_prg_cache[normalized_expr] + except KeyError: + pt_prg = pt.generate_loopy(normalized_expr, cl_device=self.queue.device) + pt_prg = pt_prg.with_transformed_program(self.transform_loopy_program) + self._freeze_prg_cache[normalized_expr] = pt_prg - evt, (cl_array,) = pt_prg(self.queue) + assert len(pt_prg.bound_arguments) == 0 + evt, (cl_array,) = pt_prg(self.queue, **bound_arguments) evt.wait() return cl_array.with_queue(None) diff --git a/arraycontext/impl/pytato/utils.py b/arraycontext/impl/pytato/utils.py new file mode 100644 index 0000000..1e00c8d --- /dev/null +++ b/arraycontext/impl/pytato/utils.py @@ -0,0 +1,82 @@ +__copyright__ = """ +Copyright (C) 2021 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + + +from typing import Any, Dict, Set, Tuple, Mapping +from pytato.array import SizeParam, Placeholder +from pytato.array import Array, DataWrapper +from pytato.transform import CopyMapper +from pytools import UniqueNameGenerator + + +class _DatawrapperToBoundPlaceholderMapper(CopyMapper): + """ + Helper mapper for :func:`normalize_pt_expr`. Every + :class:`pytato.DataWrapper` is replaced with a deterministic copy of + :class:`Placeholder`. + """ + def __init__(self) -> None: + super().__init__() + self.bound_arguments: Dict[str, Any] = {} + self.vng = UniqueNameGenerator() + self.seen_inputs: Set[str] = set() + + def map_data_wrapper(self, expr: DataWrapper) -> Array: + if expr.name is not None: + if expr.name in self.seen_inputs: + raise ValueError("Got multiple inputs with the name" + f"{expr.name} => Illegal.") + self.seen_inputs.add(expr.name) + + # Normalizing names so that we more arrays can have the normalized DAG. + name = self.vng("_actx_dw") + self.bound_arguments[name] = expr.data + return Placeholder(name=name, + shape=tuple(self.rec(s) if isinstance(s, Array) else s + for s in expr.shape), + dtype=expr.dtype, + tags=expr.tags) + + def map_size_param(self, expr: SizeParam) -> Array: + raise NotImplementedError + + def map_placeholder(self, expr: Placeholder) -> Array: + raise ValueError("Placeholders cannot appear in" + " DatawrapperToBoundPlaceholderMapper.") + + +def _normalize_pt_expr(expr: Array) -> Tuple[Array, + Mapping[str, Any]]: + """ + Returns ``(normalized_expr, bound_arguments)``. *normalized_expr* is a + normalized form of *expr*, with all instances of + :class:`pytato.DataWrapper` replaced with instances of :class:`Placeholder` + named in a deterministic manner. The data corresponding to the placeholders + in *normalized_expr* is recorded in the mapping *bound_arguments*. + Deterministic naming of placeholders permits more effective caching of + equivalent graphs. + """ + normalize_mapper = _DatawrapperToBoundPlaceholderMapper() + normalized_expr = normalize_mapper(expr) + return normalized_expr, normalize_mapper.bound_arguments -- GitLab