diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 43488c4316f9fb7392bf0c354ef0fe9ab2f16a29..d83ece6ed68011a38e299b61d3b3081f88d4bcd3 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -38,7 +38,13 @@ import pyopencl.array as cla import pytato as pt +# {{{ helper classes: AbstractInputDescriptor + class AbstractInputDescriptor: + """ + Used internally in :class:`LazilyCompilingFunctionCaller` to characterize + an input. + """ def __eq__(self, other): raise NotImplementedError @@ -51,10 +57,23 @@ class ScalarInputDescriptor(AbstractInputDescriptor): dtype: np.dtype +@dataclass(frozen=True, eq=True) +class LeafArrayDescriptor: + dtype: np.dtype + shape: Tuple[int, ...] + + @dataclass(frozen=True, eq=True) class ArrayContainerInputDescriptor(AbstractInputDescriptor): - id_to_ary_descr: "PMap[Tuple[Any, ...], Tuple[np.dtype, \ - Tuple[int, ...]]]" + """ + .. attribute id_to_ary_descr:: + + A mapping from keys of leaf arrays of an array container to their + :class:`LeafArrayDescriptor`. + """ + id_to_ary_descr: "PMap[Tuple[Any, ...], LeafArrayDescriptor]" + +# }}} def _ary_container_key_stringifier(keys: Tuple[Any, ...]) -> str: @@ -116,8 +135,8 @@ class LazilyCompilingFunctionCaller: id_to_ary_descr = {} def id_collector(keys, ary): - id_to_ary_descr[keys] = (np.dtype(ary.dtype), - ary.shape) + id_to_ary_descr[keys] = LeafArrayDescriptor(np.dtype(ary.dtype), + ary.shape) return ary rec_keyed_map_array_container(id_collector, arg)