From a5fe5b22ba167b03a16cc2e60542b384fb6adb8f Mon Sep 17 00:00:00 2001
From: Matthias Diener <mdiener@illinois.edu>
Date: Sun, 8 Oct 2023 23:15:28 -0500
Subject: [PATCH] replace pyrsistent.pmap with immutabledict (#248)

* replace pyrsistent.pmap with immutabledict

* change type annotation to 'Mapping'
---
 arraycontext/impl/pytato/compile.py | 14 ++++++--------
 setup.py                            |  2 +-
 2 files changed, 7 insertions(+), 9 deletions(-)

diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py
index 5fe9e7b..dd21ad4 100644
--- a/arraycontext/impl/pytato/compile.py
+++ b/arraycontext/impl/pytato/compile.py
@@ -36,7 +36,7 @@ from dataclasses import dataclass, field
 from typing import Any, Callable, Dict, FrozenSet, Mapping, Tuple, Type
 
 import numpy as np
-from pyrsistent import PMap, pmap
+from immutabledict import immutabledict
 
 import pytato as pt
 from pytools import ProcessLogger
@@ -131,11 +131,9 @@ def _ary_container_key_stringifier(keys: Tuple[Any, ...]) -> str:
 
 def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...],
                                            kwargs: Mapping[str, Any]
-                                           ) -> "Tuple[PMap[Tuple[Any, ...],\
-                                                            Any],\
-                                                       PMap[Tuple[Any, ...],\
-                                                            AbstractInputDescriptor]\
-                                                       ]":
+                                           ) -> \
+            Tuple[Mapping[Tuple[Any, ...], Any],
+                  Mapping[Tuple[Any, ...], AbstractInputDescriptor]]:
     """
     Helper for :meth:`BaseLazilyCompilingFunctionCaller.__call__`. Extracts
     mappings from argument id to argument values and from argument id to
@@ -171,7 +169,7 @@ def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...],
                              " either a scalar, pt.Array or an array container. Got"
                              f" '{arg}'.")
 
-    return pmap(arg_id_to_arg), pmap(arg_id_to_descr)
+    return immutabledict(arg_id_to_arg), immutabledict(arg_id_to_descr)
 
 
 def _to_input_for_compiled(ary: ArrayT, actx: PytatoPyOpenCLArrayContext):
@@ -259,7 +257,7 @@ class BaseLazilyCompilingFunctionCaller:
 
     actx: _BasePytatoArrayContext
     f: Callable[..., Any]
-    program_cache: Dict["PMap[Tuple[Any, ...], AbstractInputDescriptor]",
+    program_cache: Dict[Mapping[Tuple[Any, ...], AbstractInputDescriptor],
                         "CompiledFunction"] = field(default_factory=lambda: {})
 
     # {{{ abstract interface
diff --git a/setup.py b/setup.py
index b943796..6563130 100644
--- a/setup.py
+++ b/setup.py
@@ -42,7 +42,7 @@ def main():
 
             # https://github.com/inducer/arraycontext/pull/147
             "pytools>=2022.1.3",
-
+            "immutabledict",
             "loopy>=2019.1",
         ],
         extras_require={
-- 
GitLab