diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index aafcfd8fba2eef655e2bf8d75c69b7ed1c950c14..76242ef45afa83aca5d734ac39a59c77e8f58ea8 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -41,7 +41,8 @@ from .container import ( ArrayContainer, is_array_container, is_array_container_type, get_container_context, get_container_context_recursively, - serialize_container, deserialize_container) + serialize_container, deserialize_container, + register_multivector_as_array_container) from .container.arithmetic import with_container_arithmetic from .container.dataclass import dataclass_array_container @@ -78,6 +79,7 @@ __all__ = ( "is_array_container", "is_array_container_type", "get_container_context", "get_container_context_recursively", "serialize_container", "deserialize_container", + "register_multivector_as_array_container", "with_container_arithmetic", "dataclass_array_container", diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py index ea208cb33fcb1c517baca9f6be0675ed614e522c..e92b527a789c6275f4e4c918ef9299cf7e279064 100644 --- a/arraycontext/container/__init__.py +++ b/arraycontext/container/__init__.py @@ -30,6 +30,11 @@ Context retrieval ----------------- .. autofunction:: get_container_context .. autofunction:: get_container_context_recursively + +:class:`~pymbolic.geometric_algebra.MultiVector` support +--------------------------------------------------------- + +.. autofunction:: register_multivector_as_array_container """ @@ -59,13 +64,16 @@ THE SOFTWARE. from functools import singledispatch from arraycontext.context import ArrayContext -from typing import Any, Iterable, Tuple, TypeVar, Optional, Union +from typing import Any, Iterable, Tuple, TypeVar, Optional, Union, TYPE_CHECKING import numpy as np ArrayT = TypeVar("ArrayT") ContainerT = TypeVar("ContainerT") ArrayOrContainerT = Union[ArrayT, ContainerT] +if TYPE_CHECKING: + from pymbolic.geometric_algebra import MultiVector + # {{{ ArrayContainer @@ -248,4 +256,40 @@ def get_container_context_recursively(ary: Any) -> Optional[ArrayContext]: # }}} +# {{{ MultiVector support, see pymbolic.geometric_algebra + +# FYI: This doesn't, and never should, make arraycontext directly depend on pymbolic. +# (Though clearly there exists a dependency via loopy.) + +def _serialize_multivec_as_container(mv: "MultiVector") -> Iterable[Tuple[Any, Any]]: + return list(mv.data.items()) + + +def _deserialize_multivec_as_container(template: "MultiVector", + iterable: Iterable[Tuple[Any, Any]]) -> "MultiVector": + from pymbolic.geometric_algebra import MultiVector + return MultiVector(dict(iterable), space=template.space) + + +def _get_container_context_from_multivec(mv: "MultiVector") -> None: + return None + + +def register_multivector_as_array_container() -> None: + """Registers :class:`~pymbolic.geometric_algebra.MultiVector` as an + :class:`ArrayContainer`. This function may be called multiple times. The + second and subsequent calls have no effect. + """ + from pymbolic.geometric_algebra import MultiVector + if MultiVector not in serialize_container.registry: + serialize_container.register(MultiVector)(_serialize_multivec_as_container) + deserialize_container.register(MultiVector)( + _deserialize_multivec_as_container) + get_container_context.register(MultiVector)( + _get_container_context_from_multivec) + assert MultiVector in serialize_container.registry + +# }}} + + # vim: foldmethod=marker diff --git a/doc/conf.py b/doc/conf.py index ed6f89e08c18a8b75ff5c9c17a146ca43280f4d1..bee0e10b98630bf89addaabe7931fd5eeaf86478 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -18,6 +18,7 @@ intersphinx_mapping = { "https://docs.python.org/3/": None, "https://numpy.org/doc/stable/": None, "https://documen.tician.de/pytools": None, + "https://documen.tician.de/pymbolic": None, "https://documen.tician.de/pyopencl": None, "https://documen.tician.de/pytato": None, "https://documen.tician.de/loopy": None,