From a3c7eead940d9ca29b6807f2785b3e4cd83e0430 Mon Sep 17 00:00:00 2001
From: Matthias Diener <mdiener@illinois.edu>
Date: Thu, 30 Jan 2025 09:03:34 -0600
Subject: [PATCH] replace pyrsistent.PMap, immutables.Map, immutabledict with
 constantdict (#884)

* replace pyrsistent.PMap, immutables.Map with immutabledict

* go back to immutables.Map for Tree

* ruff fixes

* spelling

* switch to constantdict

* lint fixes

* doc fix

* work around doc failure

* remove some spurious changes

* clean up types a bit

* simplify with_kernel
---
 doc/conf.py                             |  6 +-----
 loopy/codegen/__init__.py               |  8 +++----
 loopy/frontend/fortran/translator.py    |  4 ++--
 loopy/kernel/__init__.py                |  4 ++--
 loopy/kernel/data.py                    |  4 ++--
 loopy/kernel/function_interface.py      | 12 ++++++-----
 loopy/kernel/instruction.py             | 16 +++++++-------
 loopy/kernel/tools.py                   |  4 ++--
 loopy/preprocess.py                     | 12 +++++------
 loopy/schedule/__init__.py              |  4 ++--
 loopy/schedule/tools.py                 |  4 ++--
 loopy/schedule/tree.py                  |  2 +-
 loopy/symbolic.py                       |  8 +++----
 loopy/target/c/c_execution.py           |  4 ++--
 loopy/target/execution.py               | 28 ++++++++++++-------------
 loopy/target/pyopencl_execution.py      |  4 ++--
 loopy/tools.py                          | 16 ++++++--------
 loopy/transform/buffer.py               |  6 +++---
 loopy/transform/callable.py             |  6 +++---
 loopy/transform/data.py                 |  4 ++--
 loopy/transform/fusion.py               |  8 +++----
 loopy/transform/pack_and_unpack_args.py |  4 ++--
 loopy/transform/precompute.py           |  4 ++--
 loopy/transform/realize_reduction.py    |  4 ++--
 loopy/transform/save.py                 |  4 ++--
 loopy/translation_unit.py               | 28 ++++++++++++-------------
 pyproject.toml                          |  4 +---
 27 files changed, 101 insertions(+), 111 deletions(-)

diff --git a/doc/conf.py b/doc/conf.py
index b23ce311..d12eb17b 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -32,7 +32,7 @@ intersphinx_mapping = {
         "pyopencl": ("https://documen.tician.de/pyopencl", None),
         "cgen": ("https://documen.tician.de/cgen", None),
         "pymbolic": ("https://documen.tician.de/pymbolic", None),
-        "pyrsistent": ("https://pyrsistent.readthedocs.io/en/latest/", None),
+        "constantdict": ("https://matthiasdiener.github.io/constantdict/", None),
         }
 
 nitpicky = True
@@ -43,10 +43,6 @@ nitpick_ignore_regex = [
         ["py:class", r"numpy\.float[0-9]+"],
         ["py:class", r"numpy\.complex[0-9]+"],
 
-        # As of 2022-06-22, it doesn't look like there's sphinx documentation
-        # available.
-        ["py:class", r"immutables\.(.+)"],
-
         # Reference not found from "<unknown>"? I'm not even sure where to look.
         ["py:class", r"ExpressionNode"],
 
diff --git a/loopy/codegen/__init__.py b/loopy/codegen/__init__.py
index 3c3b42f3..3de36e24 100644
--- a/loopy/codegen/__init__.py
+++ b/loopy/codegen/__init__.py
@@ -32,7 +32,7 @@ from typing import (
     Sequence,
 )
 
-import immutables
+import constantdict
 
 
 logger = logging.getLogger(__name__)
@@ -168,7 +168,7 @@ class CodeGenerationState:
     seen_functions: set[SeenFunction]
     seen_atomic_dtypes: set[LoopyType]
 
-    var_subst_map: immutables.Map[str, Expression]
+    var_subst_map: constantdict.constantdict[str, Expression]
     allow_complex: bool
     callables_table: CallablesTable
     is_entrypoint: bool
@@ -381,7 +381,7 @@ def generate_code_for_a_single_kernel(kernel, callables_table, target,
             seen_dtypes=seen_dtypes,
             seen_functions=seen_functions,
             seen_atomic_dtypes=seen_atomic_dtypes,
-            var_subst_map=immutables.Map(),
+            var_subst_map=constantdict.constantdict(),
             allow_complex=allow_complex,
             var_name_generator=kernel.get_var_name_generator(),
             is_generating_device_code=False,
@@ -482,7 +482,7 @@ def diverge_callee_entrypoints(program):
 
         new_callables[name] = clbl
 
-    return program.copy(callables_table=immutables.Map(new_callables))
+    return program.copy(callables_table=constantdict.constantdict(new_callables))
 
 
 @dataclass(frozen=True)
diff --git a/loopy/frontend/fortran/translator.py b/loopy/frontend/fortran/translator.py
index 5000abf8..4c55b30f 100644
--- a/loopy/frontend/fortran/translator.py
+++ b/loopy/frontend/fortran/translator.py
@@ -29,7 +29,7 @@ from typing import ClassVar
 from warnings import warn
 
 import numpy as np
-from immutables import Map
+from constantdict import constantdict
 
 import islpy as isl
 from islpy import dim_type
@@ -334,7 +334,7 @@ def specialize_fortran_division(t_unit):
 
         new_callables[name] = clbl
 
-    return t_unit.copy(callables_table=Map(new_callables))
+    return t_unit.copy(callables_table=constantdict(new_callables))
 
 # }}}
 
diff --git a/loopy/kernel/__init__.py b/loopy/kernel/__init__.py
index d612b5db..f487078c 100644
--- a/loopy/kernel/__init__.py
+++ b/loopy/kernel/__init__.py
@@ -48,7 +48,7 @@ from typing import (
 from warnings import warn
 
 import numpy as np
-from immutables import Map
+from constantdict import constantdict
 
 import islpy  # to help out Sphinx
 import islpy as isl
@@ -183,7 +183,7 @@ class LoopKernel(Taggable):
             Callable[[LoopKernel, str], tuple[LoopyType, str] | None]] = ()
     linearization: Sequence[ScheduleItem] | None = None
     iname_slab_increments: Mapping[InameStr, tuple[int, int]] = field(
-            default_factory=Map)
+            default_factory=constantdict)
     """
     A mapping from inames to (lower_incr,
     upper_incr) tuples that will be separated out in the execution to generate
diff --git a/loopy/kernel/data.py b/loopy/kernel/data.py
index d10401e5..3dd1cf82 100644
--- a/loopy/kernel/data.py
+++ b/loopy/kernel/data.py
@@ -64,7 +64,7 @@ from loopy.typing import Expression, ShapeType, auto
 
 
 if TYPE_CHECKING:
-    from immutables import Map
+    from collections.abc import Mapping
 
     from pymbolic import ArithmeticExpression, Variable
 
@@ -437,7 +437,7 @@ class _ArraySeparationInfo:
     should be used to realize this array.
     """
     sep_axis_indices_set: frozenset[int]
-    subarray_names: Map[tuple[int, ...], str]
+    subarray_names: Mapping[tuple[int, ...], str]
 
 
 class ArrayArg(ArrayBase, KernelArgument):
diff --git a/loopy/kernel/function_interface.py b/loopy/kernel/function_interface.py
index 146d40f4..799e5d91 100644
--- a/loopy/kernel/function_interface.py
+++ b/loopy/kernel/function_interface.py
@@ -27,7 +27,7 @@ from dataclasses import dataclass, replace
 from typing import TYPE_CHECKING, Any, Callable, TypeVar
 from warnings import warn
 
-from immutabledict import immutabledict
+from constantdict import constantdict
 from typing_extensions import Self
 
 from loopy.diagnostic import LoopyError
@@ -348,7 +348,8 @@ class InKernelCallable(ABC):
         try:
             hash(arg_id_to_dtype)
         except TypeError:
-            arg_id_to_dtype = immutabledict(arg_id_to_dtype)
+            assert arg_id_to_dtype is not None
+            arg_id_to_dtype = constantdict(arg_id_to_dtype)
             warn("arg_id_to_dtype passed to InKernelCallable was not hashable. "
                  "This usage is deprecated and will stop working in 2026.",
                  DeprecationWarning, stacklevel=3)
@@ -356,7 +357,8 @@ class InKernelCallable(ABC):
         try:
             hash(arg_id_to_descr)
         except TypeError:
-            arg_id_to_descr = immutabledict(arg_id_to_descr)
+            assert arg_id_to_descr is not None
+            arg_id_to_descr = constantdict(arg_id_to_descr)
             warn("arg_id_to_descr passed to InKernelCallable was not hashable. "
                  "This usage is deprecated and will stop working in 2026.",
                  DeprecationWarning, stacklevel=3)
@@ -773,7 +775,7 @@ class CallableKernel(InKernelCallable):
         # Return the kernel call with specialized subkernel and the corresponding
         # new arg_id_to_dtype
         return self.copy(subkernel=specialized_kernel,
-                arg_id_to_dtype=immutabledict(new_arg_id_to_dtype)), callables_table
+                arg_id_to_dtype=constantdict(new_arg_id_to_dtype)), callables_table
 
     def with_descrs(self, arg_id_to_descr, clbl_inf_ctx):
 
@@ -848,7 +850,7 @@ class CallableKernel(InKernelCallable):
         # }}}
 
         return (self.copy(subkernel=subkernel,
-                          arg_id_to_descr=immutabledict(arg_id_to_descr)),
+                          arg_id_to_descr=constantdict(arg_id_to_descr)),
                 clbl_inf_ctx)
 
     def with_added_arg(self, arg_dtype, arg_descr):
diff --git a/loopy/kernel/instruction.py b/loopy/kernel/instruction.py
index f882c09f..604b581e 100644
--- a/loopy/kernel/instruction.py
+++ b/loopy/kernel/instruction.py
@@ -292,7 +292,7 @@ class InstructionBase(ImmutableRecord, Taggable):
                  *,
                  depends_on: frozenset[str] | str | None = None,
                  ) -> None:
-        from immutabledict import immutabledict
+        from constantdict import constantdict
 
         if predicates is None:
             predicates = frozenset()
@@ -324,29 +324,29 @@ class InstructionBase(ImmutableRecord, Taggable):
             raise LoopyError("Setting depends_on_is_final to True requires "
                     "actually specifying happens_after/depends_on")
 
-        if isinstance(happens_after, immutabledict):
+        if isinstance(happens_after, constantdict):
             pass
         elif happens_after is None:
-            happens_after = immutabledict()
+            happens_after = constantdict()
         elif isinstance(happens_after, str):
             warn("Passing a string for happens_after/depends_on is deprecated and "
                  "will stop working in 2025. Instead, pass a full-fledged "
                  "happens_after data structure.", DeprecationWarning, stacklevel=2)
 
-            happens_after = immutabledict({
+            happens_after = constantdict({
                     after_id.strip(): HappensAfter(
                         variable_name=None,
                         instances_rel=None)
                     for after_id in happens_after.split(",")
                     if after_id.strip()})
         elif isinstance(happens_after, frozenset):
-            happens_after = immutabledict({
+            happens_after = constantdict({
                     after_id: HappensAfter(
                         variable_name=None,
                         instances_rel=None)
                     for after_id in happens_after})
         elif isinstance(happens_after, dict):
-            happens_after = immutabledict(happens_after)
+            happens_after = constantdict(happens_after)
         else:
             raise TypeError("'happens_after' has unexpected type: "
                             f"{type(happens_after)}")
@@ -569,13 +569,13 @@ class InstructionBase(ImmutableRecord, Taggable):
     def __setstate__(self, val):
         super().__setstate__(val)
 
-        from immutabledict import immutabledict
+        from constantdict import constantdict
 
         from loopy.tools import intern_frozenset_of_ids
 
         if self.id is not None:  # pylint:disable=access-member-before-definition
             self.id = intern(self.id)
-        self.happens_after = immutabledict({
+        self.happens_after = constantdict({
                 intern(after_id): ha
                 for after_id, ha in self.happens_after.items()})
         self.groups = intern_frozenset_of_ids(self.groups)
diff --git a/loopy/kernel/tools.py b/loopy/kernel/tools.py
index c48da4be..856ba19c 100644
--- a/loopy/kernel/tools.py
+++ b/loopy/kernel/tools.py
@@ -2089,7 +2089,7 @@ def get_call_graph(t_unit, only_kernel_callables=False):
 
     :arg t_unit: An instance of :class:`TranslationUnit`.
     """
-    from pyrsistent import pmap
+    from constantdict import constantdict
 
     from loopy.kernel import KernelState
 
@@ -2116,7 +2116,7 @@ def get_call_graph(t_unit, only_kernel_callables=False):
                 call_graph[name] = clbl.get_called_callables(t_unit.callables_table,
                                                              recursive=False)
 
-    return pmap(call_graph)
+    return constantdict(call_graph)
 
 # }}}
 
diff --git a/loopy/preprocess.py b/loopy/preprocess.py
index aee4044b..7600c97b 100644
--- a/loopy/preprocess.py
+++ b/loopy/preprocess.py
@@ -32,7 +32,7 @@ logger = logging.getLogger(__name__)
 from functools import partial
 
 import numpy as np
-from immutables import Map
+from constantdict import constantdict
 
 from pytools import ProcessLogger
 
@@ -197,7 +197,7 @@ def make_arrays_for_sep_arrays(kernel: LoopKernel) -> LoopKernel:
 
         sep_info = _ArraySeparationInfo(
                 sep_axis_indices_set=sep_axis_indices_set,
-                subarray_names=Map({
+                subarray_names=constantdict({
                     ind: vng(f"{arg.name}_s{'_'.join(str(i) for i in ind)}")
                     for ind in np.ndindex(*cast("List[int]", sep_shape))}))
 
@@ -605,8 +605,6 @@ class ArgDescrInferenceMapper(RuleAwareIdentityMapper):
         raise NotImplementedError
 
     def __call__(self, expr, kernel, insn, assignees=None):
-        import immutables
-
         from loopy.kernel.data import InstructionBase
         from loopy.symbolic import ExpansionState, UncachedIdentityMapper
         assert insn is None or isinstance(insn, InstructionBase)
@@ -616,7 +614,7 @@ class ArgDescrInferenceMapper(RuleAwareIdentityMapper):
                     kernel=kernel,
                     instruction=insn,
                     stack=(),
-                    arg_context=immutables.Map()), assignees=assignees)
+                    arg_context=constantdict()), assignees=assignees)
 
     def map_kernel(self, kernel):
 
@@ -750,7 +748,7 @@ def filter_reachable_callables(t_unit):
                                                                  t_unit.entrypoints)
     new_callables = {name: clbl for name, clbl in t_unit.callables_table.items()
                      if name in (reachable_function_ids | t_unit.entrypoints)}
-    return t_unit.copy(callables_table=Map(new_callables))
+    return t_unit.copy(callables_table=constantdict(new_callables))
 
 
 def _preprocess_single_kernel(kernel: LoopKernel, is_entrypoint: bool) -> LoopKernel:
@@ -875,7 +873,7 @@ def preprocess_program(t_unit: TranslationUnit) -> TranslationUnit:
 
         new_callables[func_id] = in_knl_callable
 
-    t_unit = t_unit.copy(callables_table=Map(new_callables))
+    t_unit = t_unit.copy(callables_table=constantdict(new_callables))
 
     # }}}
 
diff --git a/loopy/schedule/__init__.py b/loopy/schedule/__init__.py
index a009fc9c..8de619fe 100644
--- a/loopy/schedule/__init__.py
+++ b/loopy/schedule/__init__.py
@@ -34,7 +34,7 @@ from typing import (
     TypeVar,
 )
 
-from immutables import Map
+from constantdict import constantdict
 
 import islpy as isl
 from pytools import ImmutableRecord, MinRecursionLimit, ProcessLogger
@@ -2480,7 +2480,7 @@ def linearize(t_unit: TranslationUnit) -> TranslationUnit:
         else:
             raise NotImplementedError(type(clbl))
 
-    return t_unit.copy(callables_table=Map(new_callables))
+    return t_unit.copy(callables_table=constantdict(new_callables))
 
 
 # vim: foldmethod=marker
diff --git a/loopy/schedule/tools.py b/loopy/schedule/tools.py
index b659ee7b..709e3705 100644
--- a/loopy/schedule/tools.py
+++ b/loopy/schedule/tools.py
@@ -65,7 +65,7 @@ from dataclasses import dataclass
 from functools import cached_property, reduce
 from typing import TYPE_CHECKING, AbstractSet, Sequence
 
-from immutables import Map
+from constantdict import constantdict
 from typing_extensions import TypeAlias
 
 import islpy as isl
@@ -1062,7 +1062,7 @@ def _get_iname_to_tree_node_id_from_partial_loop_nest_tree(
         for iname in node:
             iname_to_tree_node_id[iname] = node
 
-    return Map(iname_to_tree_node_id)
+    return constantdict(iname_to_tree_node_id)
 
 
 def get_loop_tree(kernel: LoopKernel) -> LoopTree:
diff --git a/loopy/schedule/tree.py b/loopy/schedule/tree.py
index 3861aa75..fbd26477 100644
--- a/loopy/schedule/tree.py
+++ b/loopy/schedule/tree.py
@@ -40,7 +40,7 @@ from dataclasses import dataclass
 from functools import cached_property, reduce
 from typing import Generic, TypeVar
 
-from immutables import Map
+from constantdict import constantdict as Map  # noqa: N812
 
 from pytools import memoize_method
 
diff --git a/loopy/symbolic.py b/loopy/symbolic.py
index 7b4b9805..20ff55fe 100644
--- a/loopy/symbolic.py
+++ b/loopy/symbolic.py
@@ -42,8 +42,8 @@ from typing import (
 )
 from warnings import warn
 
-import immutables
 import numpy as np
+from constantdict import constantdict
 from typing_extensions import Self
 
 import islpy as isl
@@ -1114,7 +1114,7 @@ class ExpansionState:
     kernel: LoopKernel
     instruction: InstructionBase
     stack: tuple[tuple[str, Tag], ...]
-    arg_context: immutables.Map[str, Expression]
+    arg_context: Mapping[str, Expression]
 
     def __post_init__(self) -> None:
         hash(self.arg_context)
@@ -1352,7 +1352,7 @@ class RuleAwareIdentityMapper(IdentityMapper[Concatenate[ExpansionState, P]]):
 
         from pymbolic.mapper.substitutor import make_subst_func
         arg_subst_map = SubstitutionMapper(make_subst_func(arg_context))
-        return immutables.Map({
+        return constantdict({
             formal_arg_name: arg_subst_map(arg_value)
             for formal_arg_name, arg_value in zip(arg_names, arguments)})
 
@@ -1398,7 +1398,7 @@ class RuleAwareIdentityMapper(IdentityMapper[Concatenate[ExpansionState, P]]):
                     kernel=kernel,
                     instruction=insn,
                     stack=(),
-                    arg_context=immutables.Map()))
+                    arg_context=constantdict()))
 
     def map_instruction(self, kernel, insn):
         return insn
diff --git a/loopy/target/c/c_execution.py b/loopy/target/c/c_execution.py
index 270c3d0d..b1566d07 100644
--- a/loopy/target/c/c_execution.py
+++ b/loopy/target/c/c_execution.py
@@ -48,7 +48,7 @@ from loopy.types import LoopyType
 
 
 if TYPE_CHECKING:
-    from immutables import Map
+    from constantdict import constantdict
 
     from loopy.codegen.result import GeneratedProgram
     from loopy.kernel import LoopKernel
@@ -500,7 +500,7 @@ class CExecutor(ExecutorBase):
 
     @memoize_method
     def translation_unit_info(self,
-            arg_to_dtype: Map[str, LoopyType] | None = None) -> _KernelInfo:
+            arg_to_dtype: constantdict[str, LoopyType] | None = None) -> _KernelInfo:
         t_unit = self.get_typed_and_scheduled_translation_unit(arg_to_dtype)
 
         from loopy.codegen import generate_code_v2
diff --git a/loopy/target/execution.py b/loopy/target/execution.py
index cb737f95..2bde4ef9 100644
--- a/loopy/target/execution.py
+++ b/loopy/target/execution.py
@@ -36,7 +36,7 @@ from typing import (
     cast,
 )
 
-from immutables import Map
+from constantdict import constantdict
 
 from pymbolic import Variable, var
 from pytools.codegen import CodeGenerator, Indentation
@@ -817,7 +817,7 @@ class ExecutorBase:
                 "your argument.")
 
     def get_typed_and_scheduled_translation_unit_uncached(
-            self, arg_to_dtype: Map[str, LoopyType] | None
+            self, arg_to_dtype: constantdict[str, LoopyType] | None
             ) -> TranslationUnit:
         t_unit = self.t_unit
 
@@ -827,15 +827,15 @@ class ExecutorBase:
             # FIXME: This is not so nice. This transfers types from the
             # subarrays of sep-tagged arrays to the 'main' array, because
             # type inference fails otherwise.
-            with arg_to_dtype.mutate() as mm:
-                for name, sep_info in self.sep_info.items():
-                    if entry_knl.arg_dict[name].dtype is None:
-                        for sep_name in sep_info.subarray_names.values():
-                            if sep_name in arg_to_dtype:
-                                mm.set(name, arg_to_dtype[sep_name])
-                                del mm[sep_name]
+            mm = arg_to_dtype.mutate()
+            for name, sep_info in self.sep_info.items():
+                if entry_knl.arg_dict[name].dtype is None:
+                    for sep_name in sep_info.subarray_names.values():
+                        if sep_name in arg_to_dtype:
+                            mm[name] = arg_to_dtype[sep_name]
+                            del mm[sep_name]
 
-                arg_to_dtype = mm.finish()
+            arg_to_dtype = mm.finish()
 
             from loopy.kernel.tools import add_dtypes
             t_unit = t_unit.with_kernel(add_dtypes(entry_knl, arg_to_dtype))
@@ -854,7 +854,7 @@ class ExecutorBase:
         return t_unit
 
     def get_typed_and_scheduled_translation_unit(
-            self, arg_to_dtype: Map[str, LoopyType] | None
+            self, arg_to_dtype: constantdict[str, LoopyType] | None
             ) -> TranslationUnit:
         from loopy import CACHING_ENABLED
 
@@ -876,7 +876,7 @@ class ExecutorBase:
 
         return t_unit
 
-    def arg_to_dtype(self, kwargs) -> Map[str, LoopyType] | None:
+    def arg_to_dtype(self, kwargs) -> constantdict[str, LoopyType] | None:
         if not self.has_runtime_typed_args:
             return None
 
@@ -893,7 +893,7 @@ class ExecutorBase:
                 else:
                     arg_to_dtype[arg_name] = NumpyType(dtype)
 
-        return Map(arg_to_dtype)
+        return constantdict(arg_to_dtype)
 
     # {{{ debugging aids
 
@@ -904,7 +904,7 @@ class ExecutorBase:
 
     def get_code(
             self, entrypoint: str,
-            arg_to_dtype: Map[str, LoopyType] | None = None) -> str:
+            arg_to_dtype: constantdict[str, LoopyType] | None = None) -> str:
         kernel = self.get_typed_and_scheduled_translation_unit(arg_to_dtype)
 
         from loopy.codegen import generate_code_v2
diff --git a/loopy/target/pyopencl_execution.py b/loopy/target/pyopencl_execution.py
index c9191e1d..8c368cf8 100644
--- a/loopy/target/pyopencl_execution.py
+++ b/loopy/target/pyopencl_execution.py
@@ -42,7 +42,7 @@ logger = logging.getLogger(__name__)
 
 
 if TYPE_CHECKING:
-    from immutables import Map
+    from constantdict import constantdict
 
     import pyopencl as cl
 
@@ -311,7 +311,7 @@ class PyOpenCLExecutor(ExecutorBase):
     @memoize_method
     def translation_unit_info(
             self,
-            arg_to_dtype: Map[str, LoopyType] | None = None) -> _KernelInfo:
+            arg_to_dtype: constantdict[str, LoopyType] | None = None) -> _KernelInfo:
         t_unit = self.get_typed_and_scheduled_translation_unit(arg_to_dtype)
 
         # FIXME: now just need to add the types to the arguments
diff --git a/loopy/tools.py b/loopy/tools.py
index 6faaceb9..e9f9932b 100644
--- a/loopy/tools.py
+++ b/loopy/tools.py
@@ -29,7 +29,7 @@ from functools import cached_property
 from sys import intern
 
 import numpy as np
-from immutables import Map
+from constantdict import constantdict
 
 import islpy as isl
 from pytools import ProcessLogger, memoize_method
@@ -70,8 +70,8 @@ class LoopyKeyBuilder(KeyBuilderBase):
     update_for_list = KeyBuilderBase.update_for_tuple
     update_for_set = KeyBuilderBase.update_for_frozenset
 
-    update_for_dict = KeyBuilderBase.update_for_immutabledict
-    update_for_defaultdict = KeyBuilderBase.update_for_immutabledict
+    update_for_dict = KeyBuilderBase.update_for_constantdict
+    update_for_defaultdict = KeyBuilderBase.update_for_constantdict
 
     def update_for_BasicSet(self, key_hash, key):  # noqa
         from islpy import Printer
@@ -80,15 +80,11 @@ class LoopyKeyBuilder(KeyBuilderBase):
         key_hash.update(prn.get_str().encode("utf8"))
 
     def update_for_Map(self, key_hash, key):  # noqa
-        if isinstance(key, Map):
-            self.update_for_dict(key_hash, key)
-        elif isinstance(key, isl.Map):
+        if isinstance(key, isl.Map):
             self.update_for_BasicSet(key_hash, key)
         else:
             raise AssertionError()
 
-    update_for_PMap = update_for_dict  # noqa: N815
-
 # }}}
 
 
@@ -800,7 +796,7 @@ def t_unit_to_python(t_unit, var_name="t_unit",
                                                                .callables_table))
                      for name, clbl in t_unit.callables_table.items()
                      if isinstance(clbl, CallableKernel)}
-    t_unit = t_unit.copy(callables_table=Map(new_callables))
+    t_unit = t_unit.copy(callables_table=constantdict(new_callables))
 
     knl_python_code_srcs = [_kernel_to_python(clbl.subkernel,
                                               name in t_unit.entrypoints,
@@ -815,7 +811,7 @@ def t_unit_to_python(t_unit, var_name="t_unit",
         "import loopy as lp",
         "import numpy as np",
         "from pymbolic.primitives import *",
-        "import immutables",
+        "from constantdict import constantdict",
         ])
     body_str = "\n".join([*knl_python_code_srcs, "\n", merge_stmt])
 
diff --git a/loopy/transform/buffer.py b/loopy/transform/buffer.py
index f113e453..be29eeb4 100644
--- a/loopy/transform/buffer.py
+++ b/loopy/transform/buffer.py
@@ -25,7 +25,7 @@ THE SOFTWARE.
 
 import logging
 
-from immutables import Map
+from constantdict import constantdict
 
 from pymbolic import var
 from pymbolic.mapper.substitutor import make_subst_func
@@ -128,7 +128,7 @@ class ArrayAccessReplacer(RuleAwareIdentityMapper):
         # Can't possibly be nested, but recurse anyway to
         # make sure substitution rules referenced below here
         # do not get thrown away.
-        self.rec(result, expn_state.copy(arg_context=Map()))
+        self.rec(result, expn_state.copy(arg_context=constantdict()))
 
         return result
 
@@ -540,7 +540,7 @@ def buffer_array(program, *args, **kwargs):
 
         new_callables[func_id] = clbl
 
-    return program.copy(callables_table=Map(new_callables))
+    return program.copy(callables_table=constantdict(new_callables))
 
 
 # vim: foldmethod=marker
diff --git a/loopy/transform/callable.py b/loopy/transform/callable.py
index 8669a4ab..ef55db5a 100644
--- a/loopy/transform/callable.py
+++ b/loopy/transform/callable.py
@@ -26,7 +26,7 @@ THE SOFTWARE.
 
 from typing import TYPE_CHECKING
 
-from immutables import Map
+from constantdict import constantdict
 
 import islpy as isl
 from pytools import UniqueNameGenerator
@@ -129,7 +129,7 @@ def merge(translation_units: Sequence[TranslationUnit]) -> TranslationUnit:
     return TranslationUnit(
             entrypoints=frozenset().union(*(
                 t.entrypoints or frozenset() for t in translation_units)),
-            callables_table=Map(callables_table),
+            callables_table=constantdict(callables_table),
             target=translation_units[0].target)
 
 
@@ -605,7 +605,7 @@ def rename_callable(
         new_entrypoints = ((new_entrypoints | frozenset([new_name]))
                            - frozenset([old_name]))
 
-    return t_unit.copy(callables_table=Map(new_callables_table),
+    return t_unit.copy(callables_table=constantdict(new_callables_table),
                         entrypoints=new_entrypoints)
 
 # }}}
diff --git a/loopy/transform/data.py b/loopy/transform/data.py
index 80a0c4a1..2b0606ec 100644
--- a/loopy/transform/data.py
+++ b/loopy/transform/data.py
@@ -28,7 +28,7 @@ from typing import TYPE_CHECKING, cast
 from warnings import warn
 
 import numpy as np
-from immutables import Map
+from constantdict import constantdict
 
 from islpy import dim_type
 from pytools import MovedFunctionDeprecationWrapper
@@ -431,7 +431,7 @@ def add_prefetch(t_unit,
 
         new_callables[func_id] = in_knl_callable
 
-    return t_unit.copy(callables_table=Map(new_callables))
+    return t_unit.copy(callables_table=constantdict(new_callables))
 
 # }}}
 
diff --git a/loopy/transform/fusion.py b/loopy/transform/fusion.py
index b16d837f..f11006af 100644
--- a/loopy/transform/fusion.py
+++ b/loopy/transform/fusion.py
@@ -24,7 +24,7 @@ THE SOFTWARE.
 """
 
 
-from immutables import Map
+from constantdict import constantdict
 
 import islpy as isl
 from islpy import dim_type
@@ -120,8 +120,8 @@ def _merge_dicts(item_name, dict_a, dict_b):
         else:
             result[k] = v
 
-    if isinstance(dict_a, Map):
-        return Map(result)
+    if isinstance(dict_a, constantdict):
+        return constantdict(result)
     else:
         return result
 
@@ -453,7 +453,7 @@ def fuse_kernels(kernels, suffixes=None, data_flow=None):
 
     new_callables[result.name] = CallableKernel(result)
 
-    return TranslationUnit(callables_table=Map(new_callables),
+    return TranslationUnit(callables_table=constantdict(new_callables),
                            target=result.target,
                            entrypoints=frozenset([result.name]))
 
diff --git a/loopy/transform/pack_and_unpack_args.py b/loopy/transform/pack_and_unpack_args.py
index 9dc5f9a9..50fb6e29 100644
--- a/loopy/transform/pack_and_unpack_args.py
+++ b/loopy/transform/pack_and_unpack_args.py
@@ -23,7 +23,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-from immutables import Map
+from constantdict import constantdict
 
 from loopy.diagnostic import LoopyError
 from loopy.kernel import LoopKernel
@@ -342,6 +342,6 @@ def pack_and_unpack_args_for_call(program, *args, **kwargs):
 
         new_callables[func_id] = in_knl_callable
 
-    return program.copy(callables_table=Map(new_callables))
+    return program.copy(callables_table=constantdict(new_callables))
 
 # vim: foldmethod=marker
diff --git a/loopy/transform/precompute.py b/loopy/transform/precompute.py
index 3988b1f5..40666412 100644
--- a/loopy/transform/precompute.py
+++ b/loopy/transform/precompute.py
@@ -28,7 +28,7 @@ from dataclasses import dataclass
 from typing import TYPE_CHECKING, Sequence, cast
 
 import numpy as np
-from immutables import Map
+from constantdict import constantdict
 
 import islpy as isl
 from pymbolic import ArithmeticExpression, var
@@ -1187,6 +1187,6 @@ def precompute(program, *args, **kwargs):
 
         new_callables[func_id] = clbl
 
-    return program.copy(callables_table=Map(new_callables))
+    return program.copy(callables_table=constantdict(new_callables))
 
 # vim: foldmethod=marker
diff --git a/loopy/transform/realize_reduction.py b/loopy/transform/realize_reduction.py
index 5f504e72..ceb824b4 100644
--- a/loopy/transform/realize_reduction.py
+++ b/loopy/transform/realize_reduction.py
@@ -34,7 +34,7 @@ from typing import TYPE_CHECKING, Callable, Sequence
 
 logger = logging.getLogger(__name__)
 
-from immutables import Map
+from constantdict import constantdict
 
 import islpy as isl
 from pytools import memoize_on_first_arg
@@ -2184,6 +2184,6 @@ def realize_reduction(t_unit, *args, **kwargs):
                 subkernel=new_knl)
         callables_table[knl.name] = in_knl_callable
 
-    return t_unit.copy(callables_table=Map(callables_table))
+    return t_unit.copy(callables_table=constantdict(callables_table))
 
 # vim: foldmethod=marker
diff --git a/loopy/transform/save.py b/loopy/transform/save.py
index e1dbfd99..fe3cf190 100644
--- a/loopy/transform/save.py
+++ b/loopy/transform/save.py
@@ -26,7 +26,7 @@ THE SOFTWARE.
 import logging
 from functools import cached_property
 
-from immutables import Map
+from constantdict import constantdict
 
 from pytools import Record, memoize_method
 
@@ -252,7 +252,7 @@ class TemporarySaver:
         from collections import defaultdict
         self.insns_to_insert = []
         self.insns_to_update = {}
-        self.updated_iname_objs = Map()
+        self.updated_iname_objs = constantdict()
         self.updated_temporary_variables = {}
 
         # temporary name -> save or reload insn ids
diff --git a/loopy/translation_unit.py b/loopy/translation_unit.py
index 670feeef..93ecbe9a 100644
--- a/loopy/translation_unit.py
+++ b/loopy/translation_unit.py
@@ -37,7 +37,7 @@ from typing import (
 )
 from warnings import warn
 
-from immutables import Map
+from constantdict import constantdict
 from typing_extensions import Concatenate, ParamSpec, Self
 
 from pymbolic.primitives import Call, Variable
@@ -175,7 +175,7 @@ class CallableResolver(RuleAwareIdentityMapper):
 # {{{ translation unit
 
 FunctionIdT = Union[str, ReductionOpFunction]
-ConcreteCallablesTable = Map[FunctionIdT, InKernelCallable]
+ConcreteCallablesTable = constantdict[FunctionIdT, InKernelCallable]
 CallablesTable = Mapping[FunctionIdT, InKernelCallable]
 
 
@@ -203,7 +203,7 @@ class TranslationUnit:
 
     .. attribute:: callables_table
 
-        An instance of :class:`pyrsistent.PMap` mapping the function
+        An instance of :class:`constantdict.constantdict` mapping the function
         identifiers in a kernel to their associated instances of
         :class:`~loopy.kernel.function_interface.InKernelCallable`.
 
@@ -239,7 +239,7 @@ class TranslationUnit:
 
     def __post_init__(self):
         assert isinstance(self.entrypoints, abc_Set)
-        assert isinstance(self.callables_table, Map)
+        assert isinstance(self.callables_table, constantdict)
 
     def copy(self, **kwargs: Any) -> Self:
         target = kwargs.pop("target", None)
@@ -267,7 +267,7 @@ class TranslationUnit:
                 new_callables[func_id] = clbl
 
             t_unit = replace(
-                    self, callables_table=Map(new_callables), target=target)
+                    self, callables_table=constantdict(new_callables), target=target)
 
         return t_unit
 
@@ -302,14 +302,14 @@ class TranslationUnit:
             # update the callable kernel
             new_in_knl_callable = self.callables_table[kernel.name].copy(
                     subkernel=kernel)
-            new_callables = self.callables_table.delete(kernel.name).set(
-                    kernel.name, new_in_knl_callable)
-            return self.copy(callables_table=new_callables)
+            return self.copy(
+                callables_table=self.callables_table.set(
+                                        kernel.name, new_in_knl_callable))
         else:
             # add a new callable kernel
             clbl = CallableKernel(kernel)
-            new_callables = self.callables_table.set(kernel.name, clbl)
-            return self.copy(callables_table=new_callables)
+            return self.copy(
+                callables_table=self.callables_table.set(kernel.name, clbl))
 
     def __getitem__(self, name) -> LoopKernel:
         """
@@ -720,7 +720,7 @@ class CallablesInferenceContext:
 
         # }}}
 
-        return program.copy(callables_table=Map(new_callables))
+        return program.copy(callables_table=constantdict(new_callables))
 
     def __getitem__(self, name):
         result = self.callables[name]
@@ -741,7 +741,7 @@ def make_program(kernel: LoopKernel) -> TranslationUnit:
     """
 
     return TranslationUnit(
-            callables_table=Map({
+            callables_table=constantdict({
                 kernel.name: CallableKernel(kernel)}),
             target=kernel.target,
             entrypoints=frozenset())
@@ -803,7 +803,7 @@ def for_each_kernel(
 
                 new_callables[func_id] = clbl
 
-            return t_unit.copy(callables_table=Map(new_callables))
+            return t_unit.copy(callables_table=constantdict(new_callables))
         elif isinstance(t_unit_or_kernel, LoopKernel):
             kernel = t_unit_or_kernel
             return transform(kernel, *args, **kwargs)
@@ -899,7 +899,7 @@ def resolve_callables(t_unit: TranslationUnit) -> TranslationUnit:
         else:
             raise NotImplementedError(f"{type(clbl)}")
 
-    t_unit = t_unit.copy(callables_table=Map(callables_table))
+    t_unit = t_unit.copy(callables_table=constantdict(callables_table))
 
     validate_kernel_call_sites(t_unit)
 
diff --git a/pyproject.toml b/pyproject.toml
index d61564a7..7916bea5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -41,9 +41,7 @@ dependencies = [
     "codepy>=2017.1",
     "colorama",
     "Mako",
-    "pyrsistent",
-    "immutables",
-    "immutabledict",
+    "constantdict",
 
     "typing-extensions>=4",
 ]
-- 
GitLab