diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 5fe9e7b3d6641a72011f42315b5d16daecd61e70..dd21ad4c262063c39ddd14270c0b9e29bccb40b2 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -36,7 +36,7 @@ from dataclasses import dataclass, field from typing import Any, Callable, Dict, FrozenSet, Mapping, Tuple, Type import numpy as np -from pyrsistent import PMap, pmap +from immutabledict import immutabledict import pytato as pt from pytools import ProcessLogger @@ -131,11 +131,9 @@ def _ary_container_key_stringifier(keys: Tuple[Any, ...]) -> str: def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...], kwargs: Mapping[str, Any] - ) -> "Tuple[PMap[Tuple[Any, ...],\ - Any],\ - PMap[Tuple[Any, ...],\ - AbstractInputDescriptor]\ - ]": + ) -> \ + Tuple[Mapping[Tuple[Any, ...], Any], + Mapping[Tuple[Any, ...], AbstractInputDescriptor]]: """ Helper for :meth:`BaseLazilyCompilingFunctionCaller.__call__`. Extracts mappings from argument id to argument values and from argument id to @@ -171,7 +169,7 @@ def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...], " either a scalar, pt.Array or an array container. Got" f" '{arg}'.") - return pmap(arg_id_to_arg), pmap(arg_id_to_descr) + return immutabledict(arg_id_to_arg), immutabledict(arg_id_to_descr) def _to_input_for_compiled(ary: ArrayT, actx: PytatoPyOpenCLArrayContext): @@ -259,7 +257,7 @@ class BaseLazilyCompilingFunctionCaller: actx: _BasePytatoArrayContext f: Callable[..., Any] - program_cache: Dict["PMap[Tuple[Any, ...], AbstractInputDescriptor]", + program_cache: Dict[Mapping[Tuple[Any, ...], AbstractInputDescriptor], "CompiledFunction"] = field(default_factory=lambda: {}) # {{{ abstract interface diff --git a/setup.py b/setup.py index b9437965986407807cfa8e329c3bbb444efba99d..6563130307c32e19fc238125f0c7c53f6e38eaf6 100644 --- a/setup.py +++ b/setup.py @@ -42,7 +42,7 @@ def main(): # https://github.com/inducer/arraycontext/pull/147 "pytools>=2022.1.3", - + "immutabledict", "loopy>=2019.1", ], extras_require={