From a5fe5b22ba167b03a16cc2e60542b384fb6adb8f Mon Sep 17 00:00:00 2001 From: Matthias Diener <mdiener@illinois.edu> Date: Sun, 8 Oct 2023 23:15:28 -0500 Subject: [PATCH] replace pyrsistent.pmap with immutabledict (#248) * replace pyrsistent.pmap with immutabledict * change type annotation to 'Mapping' --- arraycontext/impl/pytato/compile.py | 14 ++++++-------- setup.py | 2 +- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 5fe9e7b..dd21ad4 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -36,7 +36,7 @@ from dataclasses import dataclass, field from typing import Any, Callable, Dict, FrozenSet, Mapping, Tuple, Type import numpy as np -from pyrsistent import PMap, pmap +from immutabledict import immutabledict import pytato as pt from pytools import ProcessLogger @@ -131,11 +131,9 @@ def _ary_container_key_stringifier(keys: Tuple[Any, ...]) -> str: def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...], kwargs: Mapping[str, Any] - ) -> "Tuple[PMap[Tuple[Any, ...],\ - Any],\ - PMap[Tuple[Any, ...],\ - AbstractInputDescriptor]\ - ]": + ) -> \ + Tuple[Mapping[Tuple[Any, ...], Any], + Mapping[Tuple[Any, ...], AbstractInputDescriptor]]: """ Helper for :meth:`BaseLazilyCompilingFunctionCaller.__call__`. Extracts mappings from argument id to argument values and from argument id to @@ -171,7 +169,7 @@ def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...], " either a scalar, pt.Array or an array container. Got" f" '{arg}'.") - return pmap(arg_id_to_arg), pmap(arg_id_to_descr) + return immutabledict(arg_id_to_arg), immutabledict(arg_id_to_descr) def _to_input_for_compiled(ary: ArrayT, actx: PytatoPyOpenCLArrayContext): @@ -259,7 +257,7 @@ class BaseLazilyCompilingFunctionCaller: actx: _BasePytatoArrayContext f: Callable[..., Any] - program_cache: Dict["PMap[Tuple[Any, ...], AbstractInputDescriptor]", + program_cache: Dict[Mapping[Tuple[Any, ...], AbstractInputDescriptor], "CompiledFunction"] = field(default_factory=lambda: {}) # {{{ abstract interface diff --git a/setup.py b/setup.py index b943796..6563130 100644 --- a/setup.py +++ b/setup.py @@ -42,7 +42,7 @@ def main(): # https://github.com/inducer/arraycontext/pull/147 "pytools>=2022.1.3", - + "immutabledict", "loopy>=2019.1", ], extras_require={ -- GitLab