diff --git a/pytato/__init__.py b/pytato/__init__.py index dcfaf28c87db410713c6d97ba67fa8565baeaf62..f47a61c6c13c0c0f14408c06401e10aecdb917e7 100644 --- a/pytato/__init__.py +++ b/pytato/__init__.py @@ -26,7 +26,7 @@ THE SOFTWARE. from pytato.array import ( Namespace, Array, DictOfNamedArrays, Tag, UniqueTag, - DottedName, Placeholder, make_placeholder, + DottedName, Placeholder, make_placeholder, IndexLambda ) from pytato.codegen import generate_loopy @@ -35,6 +35,7 @@ from pytato.program import Target, PyOpenCLTarget __all__ = ( "DottedName", "Namespace", "Array", "DictOfNamedArrays", "Tag", "UniqueTag", "Placeholder", "make_placeholder", + "IndexLambda", "generate_loopy", "Target", "PyOpenCLTarget", diff --git a/pytato/array.py b/pytato/array.py index 737ea463a82bb6f54e629e6368640137237b5699..73cac6fafc3eebbbee13c8c88a7f621a0ba047ce 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -180,7 +180,8 @@ class Namespace(Mapping[str, "Array"]): return len(self._symbol_table) def copy(self) -> Namespace: - raise NotImplementedError + from pytato.transform import CopyMapper, copy_namespace + return copy_namespace(self, CopyMapper(Namespace())) def assign(self, name: str, value: Array) -> str: """Declare a new array. diff --git a/pytato/transform.py b/pytato/transform.py index 13546653188a6da64a5f65b1ef3ed68832e115b3..739081a347f413b03c59095c72ef410737e95161 100644 --- a/pytato/transform.py +++ b/pytato/transform.py @@ -24,9 +24,9 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from typing import Any, Callable +from typing import Any, Callable, Dict -from pytato.array import Array +from pytato.array import Array, IndexLambda, Namespace, Placeholder __doc__ = """ .. currentmodule:: pytato.transform @@ -35,6 +35,8 @@ Transforming Computations ------------------------- .. autoclass:: Mapper +.. autoclass:: CopyMapper +.. autofunction:: copy_namespace """ @@ -74,7 +76,56 @@ class Mapper: rec = __call__ + +class CopyMapper(Mapper): + namespace: Namespace + + def __init__(self, namespace: Namespace): + self.namespace = namespace + self.cache: Dict[Array, Array] = {} + + def rec(self, expr: Array) -> Array: # type: ignore[override] + if expr in self.cache: + return self.cache[expr] + result: Array = super().rec(expr) + self.cache[expr] = result + return result + + def __call__(self, expr: Array) -> Array: # type: ignore[override] + return self.rec(expr) + + def map_index_lambda(self, expr: IndexLambda) -> Array: + bindings = { + name: self.rec(subexpr) + for name, subexpr in expr.bindings.items()} + return IndexLambda(self.namespace, + expr=expr.expr, + shape=expr.shape, + dtype=expr.dtype, + bindings=bindings) + + def map_placeholder(self, expr: Placeholder) -> Array: + return Placeholder(self.namespace, expr.name, expr.shape, expr.dtype, + expr.tags) + # }}} +# {{{ mapper frontends + +def copy_namespace(namespace: Namespace, copy_mapper: CopyMapper) -> Namespace: + """Copy the elements of *namespace* into a new namespace. + + :param namespace: The source namespace + :param copy_mapper: A mapper that performs copies into a new namespace + :returns: The new namespace + """ + for name, val in namespace.items(): + mapped_val = copy_mapper(val) + if name not in copy_mapper.namespace: + copy_mapper.namespace.assign(name, mapped_val) + return copy_mapper.namespace + +# }}} + # vim: foldmethod=marker diff --git a/test/test_transform.py b/test/test_transform.py new file mode 100755 index 0000000000000000000000000000000000000000..cb9f757b25bd37351a67502e62cdf2efc53bb099 --- /dev/null +++ b/test/test_transform.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python + +__copyright__ = "Copyright (C) 2020 Andreas Kloeckner" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +import sys + +import numpy as np +import pyopencl as cl # noqa +import pyopencl.array as cl_array # noqa +import pyopencl.cltypes as cltypes # noqa +import pyopencl.tools as cl_tools # noqa +from pyopencl.tools import ( # noqa + pytest_generate_tests_for_pyopencl as pytest_generate_tests) +import pytest # noqa + +import pytato as pt + + +def test_copy_namespace(): + namespace = pt.Namespace() + x = pt.Placeholder(namespace, "x", (5,), np.int) + namespace.assign("xsquared", x * x) + + namespace_copy = namespace.copy() + assert len(namespace_copy) == 2 + assert isinstance(namespace_copy["x"], pt.Placeholder) + assert isinstance(namespace_copy["xsquared"], pt.IndexLambda) + + +if __name__ == "__main__": + if len(sys.argv) > 1: + exec(sys.argv[1]) + else: + from pytest import main + main([__file__]) + +# vim: filetype=pyopencl:fdm=marker