From 6adf9ae0a40f28d2864c7766fabdbf1aa1784344 Mon Sep 17 00:00:00 2001
From: Matthias Diener <mdiener@illinois.edu>
Date: Thu, 5 Oct 2023 23:02:44 -0500
Subject: [PATCH] replace immutables.Map with immutabledict (#461)

* replace immutables.Map with immutabledict

* Use Mapping rather than any concrete type in public type annotations

---------

Co-authored-by: Andreas Kloeckner <inform@tiker.net>
---
 doc/conf.py                                 |  4 +-
 pytato/array.py                             | 99 +++++++++++----------
 pytato/cmath.py                             |  6 +-
 pytato/codegen.py                           | 12 +--
 pytato/distributed/partition.py             |  6 +-
 pytato/function.py                          | 15 ++--
 pytato/loopy.py                             | 27 +++---
 pytato/raising.py                           |  6 +-
 pytato/reductions.py                        |  6 +-
 pytato/scalar_expr.py                       |  4 +-
 pytato/stringifier.py                       |  4 +-
 pytato/target/loopy/__init__.py             |  9 +-
 pytato/target/python/numpy_like.py          |  4 +-
 pytato/transform/__init__.py                | 26 +++---
 pytato/transform/calls.py                   |  5 +-
 pytato/transform/einsum_distributive_law.py | 18 ++--
 pytato/transform/lower_to_index_lambda.py   | 40 ++++-----
 pytato/transform/metadata.py                |  4 +-
 pytato/utils.py                             |  6 +-
 setup.py                                    |  2 +-
 test/test_pytato.py                         |  6 +-
 21 files changed, 163 insertions(+), 146 deletions(-)

diff --git a/doc/conf.py b/doc/conf.py
index b5f851e..081642f 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -46,9 +46,9 @@ sys._BUILDING_SPHINX_DOCS = True
 nitpick_ignore_regex = [
     ["py:class", r"numpy.(u?)int[\d]+"],
     ["py:class", r"typing_extensions(.+)"],
-    # As of 2022-10-20, it doesn't look like there's sphinx documentation
+    # As of 2023-10-05, it doesn't look like there's sphinx documentation
     # available.
-    ["py:class", r"immutables\.(.+)"],
+    ["py:class", r"immutabledict(.*)"],
     # https://github.com/python-attrs/attrs/issues/1073
     ["py:mod", "attrs"],
 ]
diff --git a/pytato/array.py b/pytato/array.py
index 1271631..237d3be 100644
--- a/pytato/array.py
+++ b/pytato/array.py
@@ -178,7 +178,7 @@ from pytato.scalar_expr import (ScalarType, SCALAR_CLASSES,
                                 ScalarExpression, IntegralT,
                                 INT_CLASSES, get_reduction_induction_variables)
 import re
-from immutables import Map
+from immutabledict import immutabledict
 
 # {{{ get a type variable that represents the type of '...'
 
@@ -556,7 +556,7 @@ class Array(Taggable):
             indices = tuple(var(f"_{i}") for i in range(self.ndim))
             expr = op(var("_in0")[indices])
 
-        bindings = Map({"_in0": self})
+        bindings: Mapping[str, Array] = immutabledict({"_in0": self})
         return IndexLambda(
                 expr=expr,
                 shape=self.shape,
@@ -564,7 +564,7 @@ class Array(Taggable):
                 bindings=bindings,
                 tags=_get_default_tags(),
                 axes=_get_default_axes(self.ndim),
-                var_to_reduction_descr=Map())
+                var_to_reduction_descr=immutabledict())
 
     __mul__ = partialmethod(_binary_op, operator.mul)
     __rmul__ = partialmethod(_binary_op, operator.mul, reverse=True)
@@ -895,17 +895,18 @@ class IndexLambda(_SuppliedShapeAndDtypeMixin, Array):
                 f" '{self.var_to_reduction_descr.keys()}',"
                 f" got '{reduction_variable}'.")
 
-        assert isinstance(self.var_to_reduction_descr, Map)
-        new_var_to_redn_descr = self.var_to_reduction_descr.set(
-            reduction_variable,
-            self.var_to_reduction_descr[reduction_variable].tagged(tag))
+        assert isinstance(self.var_to_reduction_descr, immutabledict)
+        new_var_to_redn_descr = dict(self.var_to_reduction_descr)
+        new_var_to_redn_descr[reduction_variable] = \
+            self.var_to_reduction_descr[reduction_variable].tagged(tag)
 
         return type(self)(expr=self.expr,
                           shape=self.shape,
                           dtype=self.dtype,
                           bindings=self.bindings,
                           axes=self.axes,
-                          var_to_reduction_descr=new_var_to_redn_descr,
+                          var_to_reduction_descr=immutabledict
+                            (new_var_to_redn_descr),
                           tags=self.tags)
 
 # }}}
@@ -1006,7 +1007,7 @@ class Einsum(Array):
                 else:
                     descr_to_axis_len[descr] = arg_axis_len
 
-        return Map(descr_to_axis_len)
+        return immutabledict(descr_to_axis_len)
 
     @cached_property
     def shape(self) -> ShapeType:
@@ -1063,14 +1064,16 @@ class Einsum(Array):
 
         # }}}
 
-        assert isinstance(self.redn_axis_to_redn_descr, Map)
-        new_redn_axis_to_redn_descr = self.redn_axis_to_redn_descr.set(
-            redn_axis, self.redn_axis_to_redn_descr[redn_axis].tagged(tag))
+        assert isinstance(self.redn_axis_to_redn_descr, immutabledict)
+        new_redn_axis_to_redn_descr = dict(self.redn_axis_to_redn_descr)
+        new_redn_axis_to_redn_descr[redn_axis] = \
+            self.redn_axis_to_redn_descr[redn_axis].tagged(tag)
 
         return type(self)(access_descriptors=self.access_descriptors,
                           args=self.args,
                           axes=self.axes,
-                          redn_axis_to_redn_descr=new_redn_axis_to_redn_descr,
+                          redn_axis_to_redn_descr=immutabledict
+                            (new_redn_axis_to_redn_descr),
                           tags=self.tags,
                           index_to_access_descr=self.index_to_access_descr,
                           )
@@ -1079,7 +1082,7 @@ class Einsum(Array):
 EINSUM_FIRST_INDEX = re.compile(r"^\s*((?P<alpha>[a-zA-Z])|(?P<ellipsis>\.\.\.))\s*")
 
 
-def _normalize_einsum_out_subscript(subscript: str) -> Map[str,
+def _normalize_einsum_out_subscript(subscript: str) -> immutabledict[str,
                                                             EinsumAxisDescriptor]:
     """
     Normalizes the output subscript of an einsum (provided in the explicit
@@ -1119,19 +1122,20 @@ def _normalize_einsum_out_subscript(subscript: str) -> Map[str,
         raise ValueError("Used an input more than once to refer to the"
                          f" output axis in '{subscript}")
 
-    return Map({idx: EinsumElementwiseAxis(i)
+    return immutabledict({idx: EinsumElementwiseAxis(i)
                  for i, idx in enumerate(normalized_indices)})
 
 
 def _normalize_einsum_in_subscript(subscript: str,
                                    in_operand: Array,
-                                   index_to_descr: Map[str,
+                                   index_to_descr: Mapping[str,
                                                         EinsumAxisDescriptor],
-                                   index_to_axis_length: Map[str,
+                                   index_to_axis_length: Mapping[str,
                                                                ShapeComponent],
                                    ) -> Tuple[Tuple[EinsumAxisDescriptor, ...],
-                                              Map[str, EinsumAxisDescriptor],
-                                              Map[str, ShapeComponent]]:
+                                              immutabledict
+                                                [str, EinsumAxisDescriptor],
+                                              immutabledict[str, ShapeComponent]]:
     """
     Normalizes the subscript for an input operand in an einsum. Returns
     ``(access_descrs, updated_index_to_descr, updated_to_index_to_axis_length)``,
@@ -1174,12 +1178,14 @@ def _normalize_einsum_in_subscript(subscript: str,
                          f"of corresponding operand ({in_operand.ndim}).")
 
     in_operand_axis_descrs = []
+    index_to_axis_length_dict = dict(index_to_axis_length)
+    index_to_descr_dict = dict(index_to_descr)
 
     for iaxis, index_char in enumerate(normalized_indices):
         in_axis_len = in_operand.shape[iaxis]
-        if index_char in index_to_descr:
-            if index_char in index_to_axis_length:
-                seen_axis_len = index_to_axis_length[index_char]
+        if index_char in index_to_descr_dict:
+            if index_char in index_to_axis_length_dict:
+                seen_axis_len = index_to_axis_length_dict[index_char]
                 if not are_shape_components_equal(in_axis_len,
                                                   seen_axis_len):
                     if are_shape_components_equal(in_axis_len, 1):
@@ -1187,24 +1193,24 @@ def _normalize_einsum_in_subscript(subscript: str,
                         pass
                     elif are_shape_components_equal(seen_axis_len, 1):
                         # Broadcast to the length of the current axis
-                        index_to_axis_length = (index_to_axis_length
-                                                .set(index_char, in_axis_len))
+                        index_to_axis_length_dict[index_char] = in_axis_len
                     else:
                         raise ValueError("Got conflicting lengths for"
                                          f" '{index_char}' -- {in_axis_len},"
                                          f" {seen_axis_len}.")
             else:
-                index_to_axis_length = index_to_axis_length.set(index_char,
-                                                                in_axis_len)
+                index_to_axis_length_dict[index_char] = in_axis_len
         else:
-            redn_sr_no = len([descr for descr in index_to_descr.values()
+            redn_sr_no = len([descr for descr in index_to_descr_dict.values()
                               if isinstance(descr, EinsumReductionAxis)])
             redn_axis_descr = EinsumReductionAxis(redn_sr_no)
-            index_to_descr = index_to_descr.set(index_char, redn_axis_descr)
-            index_to_axis_length = index_to_axis_length.set(index_char,
-                                                             in_axis_len)
+            index_to_descr_dict[index_char] = redn_axis_descr
+            index_to_axis_length_dict[index_char] = in_axis_len
 
-        in_operand_axis_descrs.append(index_to_descr[index_char])
+        in_operand_axis_descrs.append(index_to_descr_dict[index_char])
+
+    index_to_axis_length = immutabledict(index_to_axis_length_dict)
+    index_to_descr = immutabledict(index_to_descr_dict)
 
     return (tuple(in_operand_axis_descrs), index_to_descr, index_to_axis_length)
 
@@ -1239,7 +1245,7 @@ def einsum(subscripts: str, *operands: Array,
         )
 
     index_to_descr = _normalize_einsum_out_subscript(out_spec)
-    index_to_axis_length: Map[str, ShapeComponent] = Map()
+    index_to_axis_length: Mapping[str, ShapeComponent] = immutabledict()
     access_descriptors = []
 
     for in_spec, in_operand in zip(in_specs, operands):
@@ -1274,7 +1280,7 @@ def einsum(subscripts: str, *operands: Array,
                                               if isinstance(descr,
                                                             EinsumElementwiseAxis)})
                                          ),
-                  redn_axis_to_redn_descr=Map(redn_axis_to_redn_descr),
+                  redn_axis_to_redn_descr=immutabledict(redn_axis_to_redn_descr),
                   index_to_access_descr=index_to_descr,
                   )
 
@@ -2088,10 +2094,10 @@ def full(shape: ConvertibleToShape, fill_value: ScalarType,
         fill_value = dtype.type(fill_value)
 
     return IndexLambda(expr=fill_value, shape=shape, dtype=dtype,
-                       bindings=Map(),
+                       bindings=immutabledict(),
                        tags=_get_default_tags(),
                        axes=_get_default_axes(len(shape)),
-                       var_to_reduction_descr=Map())
+                       var_to_reduction_descr=immutabledict())
 
 
 def zeros(shape: ConvertibleToShape, dtype: Any = float,
@@ -2134,10 +2140,10 @@ def eye(N: int, M: Optional[int] = None, k: int = 0,  # noqa: N803
         raise ValueError(f"k must be int, got {type(k)}.")
 
     return IndexLambda(expr=parse(f"1 if ((_1 - _0) == {k}) else 0"),
-                       shape=(N, M), dtype=dtype, bindings=Map({}),
+                       shape=(N, M), dtype=dtype, bindings=immutabledict({}),
                        tags=_get_default_tags(),
                        axes=_get_default_axes(2),
-                       var_to_reduction_descr=Map())
+                       var_to_reduction_descr=immutabledict())
 
 # }}}
 
@@ -2229,10 +2235,10 @@ def arange(*args: Any, **kwargs: Any) -> Array:
 
     from pymbolic.primitives import Variable
     return IndexLambda(expr=start + Variable("_0") * step,
-                       shape=(size,), dtype=dtype, bindings=Map(),
+                       shape=(size,), dtype=dtype, bindings=immutabledict(),
                        tags=_get_default_tags(),
                        axes=_get_default_axes(1),
-                       var_to_reduction_descr=Map())
+                       var_to_reduction_descr=immutabledict())
 
 # }}}
 
@@ -2343,7 +2349,7 @@ def logical_not(x: ArrayOrScalar) -> Union[Array, bool]:
                        bindings={"_in0": x},
                        tags=_get_default_tags(),
                        axes=_get_default_axes(len(x.shape)),
-                       var_to_reduction_descr=Map())
+                       var_to_reduction_descr=immutabledict())
 
 # }}}
 
@@ -2396,10 +2402,10 @@ def where(condition: ArrayOrScalar,
             expr=prim.If(expr1, expr2, expr3),
             shape=result_shape,
             dtype=dtype,
-            bindings=Map(bindings),
+            bindings=immutabledict(bindings),
             tags=_get_default_tags(),
             axes=_get_default_axes(len(result_shape)),
-            var_to_reduction_descr=Map())
+            var_to_reduction_descr=immutabledict())
 
 # }}}
 
@@ -2492,12 +2498,13 @@ def make_index_lambda(
     # }}}
 
     return IndexLambda(expr=expression,
-                       bindings=Map(bindings),
+                       bindings=immutabledict(bindings),
                        shape=shape,
                        dtype=dtype,
                        tags=_get_default_tags(),
                        axes=_get_default_axes(len(shape)),
-                       var_to_reduction_descr=Map(processed_var_to_reduction_descr))
+                       var_to_reduction_descr=immutabledict
+                        (processed_var_to_reduction_descr))
 
 # }}}
 
@@ -2578,10 +2585,10 @@ def broadcast_to(array: Array, shape: ShapeType) -> Array:
                                                                    shape)),
                        shape=shape,
                        dtype=array.dtype,
-                       bindings=Map({"in": array}),
+                       bindings=immutabledict({"in": array}),
                        tags=_get_default_tags(),
                        axes=_get_default_axes(len(shape)),
-                       var_to_reduction_descr=Map())
+                       var_to_reduction_descr=immutabledict())
 
 # }}}
 
diff --git a/pytato/cmath.py b/pytato/cmath.py
index 88fca8f..38c520c 100644
--- a/pytato/cmath.py
+++ b/pytato/cmath.py
@@ -62,7 +62,7 @@ from pytato.array import (Array, ArrayOrScalar, IndexLambda, _dtype_any,
                           _get_default_axes, _get_default_tags)
 from pytato.scalar_expr import SCALAR_CLASSES
 from pymbolic import var
-from immutables import Map
+from immutabledict import immutabledict
 
 
 def _apply_elem_wise_func(inputs: Tuple[ArrayOrScalar, ...],
@@ -113,10 +113,10 @@ def _apply_elem_wise_func(inputs: Tuple[ArrayOrScalar, ...],
     return IndexLambda(
         expr=prim.Call(var(f"pytato.c99.{func_name}"),
                   tuple(sym_args)),
-        shape=shape, dtype=ret_dtype, bindings=Map(bindings),
+        shape=shape, dtype=ret_dtype, bindings=immutabledict(bindings),
         tags=_get_default_tags(),
         axes=_get_default_axes(len(shape)),
-        var_to_reduction_descr=Map(),
+        var_to_reduction_descr=immutabledict(),
     )
 
 
diff --git a/pytato/codegen.py b/pytato/codegen.py
index 0bc85d6..68f4645 100644
--- a/pytato/codegen.py
+++ b/pytato/codegen.py
@@ -23,7 +23,8 @@ THE SOFTWARE.
 """
 
 import dataclasses
-from typing import Union, Dict, Tuple, List, Any, Optional
+from typing import Union, Dict, Tuple, List, Any, Optional, Mapping
+from immutabledict import immutabledict
 
 from pytato.array import (Array, DictOfNamedArrays, DataWrapper, Placeholder,
                           DataInterface, SizeParam, InputArgumentBase,
@@ -173,9 +174,10 @@ class CodeGenPreprocessor(ToIndexLambdaMixin, CopyMapper):  # type: ignore[misc]
 
         # }}}
 
-        bindings = {name: (self.rec(subexpr) if isinstance(subexpr, Array)
+        bindings: Mapping[str, Any] = immutabledict(
+                    {name: (self.rec(subexpr) if isinstance(subexpr, Array)
                            else subexpr)
-                    for name, subexpr in sorted(expr.bindings.items())}
+                    for name, subexpr in sorted(expr.bindings.items())})
 
         return LoopyCall(translation_unit=translation_unit,
                          bindings=bindings,
@@ -282,8 +284,8 @@ def preprocess(outputs: DictOfNamedArrays, target: Target) -> PreprocessResult:
                                                 for out in outputs.values()))
 
     # only look for dependencies between the outputs
-    deps = {name: get_deps(output.expr)
-            for name, output in outputs.items()}
+    deps: Mapping[str, Any] = immutabledict({name: get_deps(output.expr)
+            for name, output in outputs.items()})
 
     # represent deps in terms of output names
     output_expr_to_name = {output.expr: name for name, output in outputs.items()}
diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py
index c8ab048..426a8cf 100644
--- a/pytato/distributed/partition.py
+++ b/pytato/distributed/partition.py
@@ -68,7 +68,7 @@ from typing import (
         List, AbstractSet, TypeVar, TYPE_CHECKING, Hashable, Optional)
 
 import attrs
-from immutables import Map
+from immutabledict import immutabledict
 
 from pytools.graph import CycleError
 from pytools import memoize_method
@@ -405,11 +405,11 @@ def _make_distributed_partition(
                 partition_input_names=frozenset(
                     comm_replacer.partition_input_name_to_placeholder.keys()),
                 output_names=frozenset(name_to_ouput.keys()),
-                name_to_recv_node=Map({
+                name_to_recv_node=immutabledict({
                     recvd_ary_to_name[local_recv_id_to_recv_node[recv_id]]:
                     local_recv_id_to_recv_node[recv_id]
                     for recv_id in comm_ids.recv_ids}),
-                name_to_send_nodes=Map(name_to_send_nodes))
+                name_to_send_nodes=immutabledict(name_to_send_nodes))
 
     result = DistributedGraphPartition(
             parts=parts,
diff --git a/pytato/function.py b/pytato/function.py
index b053831..14e0622 100644
--- a/pytato/function.py
+++ b/pytato/function.py
@@ -49,7 +49,7 @@ import enum
 
 from typing import (Callable, Dict, FrozenSet, Tuple, Union, TypeVar, Optional,
                     Hashable, Sequence, ClassVar, Iterator, Iterable, Mapping)
-from immutables import Map
+from immutabledict import immutabledict
 from functools import cached_property
 from pytato.array import (Array,  AbstractResultWithNamedArrays,
                           Placeholder, NamedArray, ShapeType, _dtype_any,
@@ -126,7 +126,7 @@ class FunctionDefinition(Taggable):
     """
     parameters: FrozenSet[str]
     return_type: ReturnType
-    returns: Map[str, Array]
+    returns: Mapping[str, Array]
     tags: FrozenSet[Tag] = attrs.field(kw_only=True)
 
     @cached_property
@@ -142,7 +142,7 @@ class FunctionDefinition(Taggable):
             frozenset()
         )
 
-        return Map({input_arg.name: input_arg
+        return immutabledict({input_arg.name: input_arg
                     for input_arg in all_input_args
                     if isinstance(input_arg, Placeholder)})
 
@@ -188,7 +188,8 @@ class FunctionDefinition(Taggable):
 
         # }}}
 
-        call_site = Call(self, bindings=Map(kwargs), tags=_get_default_tags())
+        call_site = Call(self, bindings=immutabledict(kwargs),
+                         tags=_get_default_tags())
 
         if self.return_type == ReturnType.ARRAY:
             return call_site["_"]
@@ -253,7 +254,7 @@ class NamedCallResult(NamedArray):
         return self._container.function.returns[self.name].dtype
 
 
-# eq=False to avoid equality comparison without EqualityMaper
+# eq=False to avoid equality comparison without EqualityMapper
 @attrs.define(frozen=True, eq=False, hash=True, cache_hash=True, repr=False)
 class Call(AbstractResultWithNamedArrays):
     """
@@ -270,7 +271,7 @@ class Call(AbstractResultWithNamedArrays):
 
     """
     function: FunctionDefinition
-    bindings: Map[str, Array]
+    bindings: Mapping[str, Array]
 
     _mapper_method: ClassVar[str] = "map_call"
 
@@ -371,7 +372,7 @@ def trace_call(f: Callable[..., ReturnT],
     function = FunctionDefinition(
         frozenset(pl_arg.name for pl_arg in pl_args) | frozenset(pl_kwargs),
         return_type,
-        Map(returns),
+        immutabledict(returns),
         tags=_get_default_tags() | (frozenset([FunctionIdentifier(identifier)])
                                     if identifier
                                     else frozenset())
diff --git a/pytato/loopy.py b/pytato/loopy.py
index 3d1ee15..f4bb1fb 100644
--- a/pytato/loopy.py
+++ b/pytato/loopy.py
@@ -37,7 +37,7 @@ from pytato.array import (AbstractResultWithNamedArrays, Array, ShapeType,
 from pytato.scalar_expr import (SubstitutionMapper, ScalarExpression,
                                 EvaluationMapper, IntegralT)
 from pytools import memoize_method
-from immutables import Map
+from immutabledict import immutabledict
 import islpy as isl
 
 __doc__ = r"""
@@ -78,7 +78,7 @@ class LoopyCall(AbstractResultWithNamedArrays):
     :mod:`loopy` translation unit.
     """
     translation_unit: "lp.TranslationUnit"
-    bindings: Dict[str, ArrayOrScalar]
+    bindings: Mapping[str, ArrayOrScalar]
     entrypoint: str
 
     _mapper_method: ClassVar[str] = "map_loopy_call"
@@ -212,18 +212,19 @@ def call_loopy(translation_unit: "lp.TranslationUnit",
 
     # {{{ perform shape inference here
 
-    bindings = extend_bindings_with_shape_inference(translation_unit[entrypoint],
-                                                    Map(bindings))
+    bindings_new = extend_bindings_with_shape_inference(translation_unit[entrypoint],
+                                                    immutabledict(bindings))
+    del bindings
 
     # }}}
 
     for arg in translation_unit[entrypoint].args:
         if arg.is_input:
-            if arg.name not in bindings:
+            if arg.name not in bindings_new:
                 raise ValueError(f"Kernel '{entrypoint}' expects an input"
                         f" '{arg.name}'")
 
-            arg_binding = bindings[arg.name]
+            arg_binding = bindings_new[arg.name]
 
             if isinstance(arg, (lp.ArrayArg, lp.ConstantArg)):
                 if not isinstance(arg_binding, Array):
@@ -242,7 +243,7 @@ def call_loopy(translation_unit: "lp.TranslationUnit",
 
     # {{{ infer types of the translation_unit
 
-    for name, ary in bindings.items():
+    for name, ary in bindings_new.items():
         if translation_unit[entrypoint].arg_dict[name].dtype not in [lp.auto,
                                                                      None]:
             continue
@@ -265,7 +266,7 @@ def call_loopy(translation_unit: "lp.TranslationUnit",
 
     translation_unit = translation_unit.with_entrypoints(frozenset())
 
-    return LoopyCall(translation_unit, bindings, entrypoint,
+    return LoopyCall(translation_unit, bindings_new, entrypoint,
                      tags=_get_default_tags())
 
 
@@ -379,8 +380,8 @@ def _get_pt_dim_expr(dim: Union[IntegralT, Array]) -> ScalarExpression:
 
 
 def extend_bindings_with_shape_inference(knl: lp.LoopKernel,
-                                         bindings: Map[str, ArrayOrScalar]
-                                         ) -> Dict[str, ArrayOrScalar]:
+                                         bindings: Mapping[str, ArrayOrScalar]
+                                         ) -> immutabledict[str, ArrayOrScalar]:
     from functools import reduce
     from loopy.symbolic import get_dependencies as lpy_get_deps
     from loopy.kernel.array import ArrayBase
@@ -478,6 +479,8 @@ def extend_bindings_with_shape_inference(knl: lp.LoopKernel,
     as_pt_size_param = EvaluationMapper({_pt_var_to_global_namespace(arg.name): arg
                                          for arg in pt_size_params})
 
+    bindings_dict = dict(bindings)
+
     for var, val in solutions.items():
         # map the pymbolic expression back into an expression in terms of
         # pt.SizeParams
@@ -494,9 +497,9 @@ def extend_bindings_with_shape_inference(knl: lp.LoopKernel,
 
         # }}}
 
-        bindings = bindings.set(var, val)
+        bindings_dict[var] = val
 
-    return dict(bindings)
+    return immutabledict(bindings_dict)
 
 # }}}
 
diff --git a/pytato/raising.py b/pytato/raising.py
index 3cdf77e..d5fb286 100644
--- a/pytato/raising.py
+++ b/pytato/raising.py
@@ -11,7 +11,7 @@ from pytato.utils import (get_indexing_expression,
 from pytato.scalar_expr import ScalarType, ScalarExpression, Reduce, SCALAR_CLASSES
 from pytato.reductions import ReductionOperation
 from dataclasses import dataclass
-from immutables import Map
+from immutabledict import immutabledict
 
 
 __doc__ = """
@@ -101,7 +101,7 @@ class ReduceOp(HighLevelOp):
     """
     op: ReductionOperation
     x: Array
-    axes: Map[int, str]
+    axes: Mapping[int, str]
 
 # }}}
 
@@ -319,7 +319,7 @@ def index_lambda_to_high_level_op(expr: IndexLambda) -> HighLevelOp:
                                       .expr
                                       .inner_expr
                                       .aggregate.name],
-                        axes=Map({i: idx.name
+                        axes=immutabledict({i: idx.name
                                   for i, idx in enumerate(expr
                                                           .expr
                                                           .inner_expr
diff --git a/pytato/reductions.py b/pytato/reductions.py
index cd4f509..6497e1d 100644
--- a/pytato/reductions.py
+++ b/pytato/reductions.py
@@ -34,7 +34,7 @@ import numpy as np
 
 from pytato.array import ShapeType, Array, make_index_lambda, ReductionDescriptor
 from pytato.scalar_expr import ScalarExpression, Reduce, INT_CLASSES
-from immutables import Map
+from immutabledict import immutabledict
 import pymbolic.primitives as prim
 
 # {{{ docs
@@ -213,7 +213,7 @@ def _get_reduction_indices_bounds(shape: ShapeType,
             indices.append(prim.Variable(f"_{n_out_dims}"))
             n_out_dims += 1
 
-    return indices, Map(redn_bounds)
+    return indices, immutabledict(redn_bounds)
 
 
 def _get_var_to_redn_descr(shape: ShapeType,
@@ -258,7 +258,7 @@ def _get_var_to_redn_descr(shape: ShapeType,
             var_to_redn_descr[idx] = redn_descr
             n_redn_dims += 1
 
-    return Map(var_to_redn_descr)
+    return immutabledict(var_to_redn_descr)
 
 
 def _make_reduction_lambda(
diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py
index 98724ad..0e21c28 100644
--- a/pytato/scalar_expr.py
+++ b/pytato/scalar_expr.py
@@ -42,7 +42,7 @@ from pymbolic.mapper.stringifier import (StringifyMapper as
         StringifyMapperBase)
 from pymbolic.mapper import CombineMapper as CombineMapperBase
 from pymbolic.mapper.collector import TermCollector as TermCollectorBase
-from immutables import Map
+from immutabledict import immutabledict
 import pymbolic.primitives as prim
 import numpy as np
 import re
@@ -113,7 +113,7 @@ class SubstitutionMapper(SubstitutionMapperBase):
     def map_reduce(self, expr: Reduce) -> ScalarExpression:
         return Reduce(self.rec(expr.inner_expr),
                       op=expr.op,
-                      bounds=Map(
+                      bounds=immutabledict(
                           {name: self.rec(bound)
                            for name, bound in expr.bounds.items()}))
 
diff --git a/pytato/stringifier.py b/pytato/stringifier.py
index e370279..8aac8d3 100644
--- a/pytato/stringifier.py
+++ b/pytato/stringifier.py
@@ -31,7 +31,7 @@ from pytato.transform import Mapper
 from pytato.array import (Array, DataWrapper, DictOfNamedArrays, Axis,
                           IndexLambda, ReductionDescriptor)
 from pytato.loopy import LoopyCall
-from immutables import Map
+from immutabledict import immutabledict
 import attrs
 
 
@@ -77,7 +77,7 @@ class Reprifier(Mapper):
     def map_foreign(self, expr: Any, depth: int) -> str:  # type: ignore[override]
         if isinstance(expr, tuple):
             return "(" + ", ".join(self.rec(el, depth) for el in expr) + ")"
-        elif isinstance(expr, (dict, Map)):
+        elif isinstance(expr, (dict, immutabledict)):
             return ("{"
                     + ", ".join(f"{key!r}: {self.rec(val, depth)}"
                                 for key, val
diff --git a/pytato/target/loopy/__init__.py b/pytato/target/loopy/__init__.py
index 68d051b..34278db 100644
--- a/pytato/target/loopy/__init__.py
+++ b/pytato/target/loopy/__init__.py
@@ -54,7 +54,7 @@ from dataclasses import dataclass, field
 from functools import cached_property
 
 from typing import Any, Mapping, Optional, Callable, Dict, TYPE_CHECKING
-from immutables import Map
+from immutabledict import immutabledict
 
 from pytato.target import Target, BoundProgram
 from pytato.tags import ImplementationStrategy
@@ -137,7 +137,8 @@ class BoundPyOpenCLProgram(BoundProgram):
     """
     program: loopy.TranslationUnit
     _processed_bound_args_cache: Dict[pyopencl.Context,
-                                      Map[str, Any]] = field(default_factory=dict)
+                                      Mapping[str, Any]] = \
+                                        field(default_factory=dict)
 
     def copy(self, *,
              program: Optional[loopy.TranslationUnit] = None,
@@ -169,7 +170,7 @@ class BoundPyOpenCLProgram(BoundProgram):
                                        queue: pyopencl.CommandQueue,
                                        allocator: Optional[Callable[
                                            [int], pyopencl.MemoryObject]],
-                                       ) -> Map[str, Any]:
+                                       ) -> Mapping[str, Any]:
         import pyopencl.array as cla
 
         cache_key = queue.context
@@ -193,7 +194,7 @@ class BoundPyOpenCLProgram(BoundProgram):
                                     " numpy array, pyopencl array or scalar."
                                     f" Got {type(bnd_arg).__name__} for '{name}'.")
 
-            result = Map(proc_bnd_args)
+            result: Mapping[str, Any] = immutabledict(proc_bnd_args)
             assert set(result.keys()) == set(self.bound_arguments.keys())
             self._processed_bound_args_cache[cache_key] = result
             return result
diff --git a/pytato/target/python/numpy_like.py b/pytato/target/python/numpy_like.py
index 10c5fe4..998208b 100644
--- a/pytato/target/python/numpy_like.py
+++ b/pytato/target/python/numpy_like.py
@@ -38,7 +38,7 @@ from pytato.array import (Stack, Concatenate, IndexLambda, DataWrapper,
                           Reshape, Array, DictOfNamedArrays, IndexBase,
                           DataInterface, NormalizedSlice, ShapeComponent,
                           IndexExpr, ArrayOrScalar, NamedArray)
-from immutables import Map
+from immutabledict import immutabledict
 from pytato.scalar_expr import SCALAR_CLASSES
 from pytato.utils import are_shape_components_equal
 from pytato.raising import BinaryOpType, C99CallOp
@@ -601,4 +601,4 @@ def generate_numpy_like(expr: Union[Array, Mapping[str, Array], DictOfNamedArray
         program,
         function_name,
         expected_arguments=frozenset(cgen_mapper.arg_names),
-        bound_arguments=Map(cgen_mapper.bound_arguments))
+        bound_arguments=immutabledict(cgen_mapper.bound_arguments))
diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py
index b7ac02b..5dfb382 100644
--- a/pytato/transform/__init__.py
+++ b/pytato/transform/__init__.py
@@ -30,7 +30,7 @@ THE SOFTWARE.
 
 import logging
 import numpy as np
-from immutables import Map
+from immutabledict import immutabledict
 from typing import (Any, Callable, Dict, FrozenSet, Union, TypeVar, Set, Generic,
                     List, Mapping, Iterable, Tuple, Optional, TYPE_CHECKING,
                     Hashable)
@@ -261,7 +261,7 @@ class CopyMapper(CachedMapper[ArrayOrNames]):
                      for s in situp)
 
     def map_index_lambda(self, expr: IndexLambda) -> Array:
-        bindings: Mapping[str, Array] = Map({
+        bindings: Mapping[str, Array] = immutabledict({
                 name: self.rec(subexpr)
                 for name, subexpr in sorted(expr.bindings.items())})
         return IndexLambda(expr=expr.expr,
@@ -354,9 +354,10 @@ class CopyMapper(CachedMapper[ArrayOrNames]):
                                  )
 
     def map_loopy_call(self, expr: LoopyCall) -> LoopyCall:
-        bindings = {name: (self.rec(subexpr) if isinstance(subexpr, Array)
+        bindings: Mapping[Any, Any] = immutabledict(
+                    {name: (self.rec(subexpr) if isinstance(subexpr, Array)
                            else subexpr)
-                    for name, subexpr in sorted(expr.bindings.items())}
+                    for name, subexpr in sorted(expr.bindings.items())})
 
         return LoopyCall(translation_unit=expr.translation_unit,
                          bindings=bindings,
@@ -406,13 +407,13 @@ class CopyMapper(CachedMapper[ArrayOrNames]):
                        for name, ret in expr.returns.items()}
         return FunctionDefinition(expr.parameters,
                                   expr.return_type,
-                                  Map(new_returns),
+                                  immutabledict(new_returns),
                                   tags=expr.tags
                                   )
 
     def map_call(self, expr: Call) -> AbstractResultWithNamedArrays:
         return Call(self.map_function_definition(expr.function),
-                    Map({name: self.rec(bnd)
+                    immutabledict({name: self.rec(bnd)
                          for name, bnd in expr.bindings.items()}),
                     tags=expr.tags,
                     )
@@ -578,10 +579,11 @@ class CopyMapperWithExtraArgs(CachedMapper[ArrayOrNames]):
 
     def map_loopy_call(self, expr: LoopyCall,
                        *args: Any, **kwargs: Any) -> LoopyCall:
-        bindings = {name: (self.rec(subexpr, *args, **kwargs)
+        bindings: Mapping[Any, Any] = immutabledict(
+                    {name: (self.rec(subexpr, *args, **kwargs)
                            if isinstance(subexpr, Array)
                            else subexpr)
-                    for name, subexpr in sorted(expr.bindings.items())}
+                    for name, subexpr in sorted(expr.bindings.items())})
 
         return LoopyCall(translation_unit=expr.translation_unit,
                          bindings=bindings,
@@ -634,7 +636,7 @@ class CopyMapperWithExtraArgs(CachedMapper[ArrayOrNames]):
     def map_call(self, expr: Call,
                  *args: Any, **kwargs: Any) -> AbstractResultWithNamedArrays:
         return Call(self.map_function_definition(expr.function, *args, **kwargs),
-                    Map({name: self.rec(bnd, *args, **kwargs)
+                    immutabledict({name: self.rec(bnd, *args, **kwargs)
                          for name, bnd in expr.bindings.items()}),
                     tags=expr.tags,
                     )
@@ -1312,7 +1314,7 @@ class MPMSMaterializer(Mapper):
         new_expr = IndexLambda(expr=expr.expr,
                                shape=expr.shape,
                                dtype=expr.dtype,
-                               bindings=Map({bnd_name: bnd.expr
+                               bindings=immutabledict({bnd_name: bnd.expr
                                 for bnd_name, bnd in sorted(children_rec.items())}),
                                axes=expr.axes,
                                var_to_reduction_descr=expr.var_to_reduction_descr,
@@ -1441,13 +1443,13 @@ class MPMSMaterializer(Mapper):
         new_returns = {name: new_mapper(ret) for name, ret in expr.returns.items()}
         return FunctionDefinition(expr.parameters,
                                   expr.return_type,
-                                  Map(new_returns),
+                                  immutabledict(new_returns),
                                   tags=expr.tags)
 
     @memoize_method
     def map_call(self, expr: Call) -> Call:
         return Call(self.map_function_definition(expr.function),
-                        Map({name: self.rec(bnd).expr
+                        immutabledict({name: self.rec(bnd).expr
                              for name, bnd in expr.bindings.items()}),
                         tags=expr.tags)
 
diff --git a/pytato/transform/calls.py b/pytato/transform/calls.py
index a25dd83..a9d8d06 100644
--- a/pytato/transform/calls.py
+++ b/pytato/transform/calls.py
@@ -26,7 +26,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-from immutables import Map
+from typing import Mapping
 from pytato.transform import (ArrayOrNames, CopyMapper)
 from pytato.array import (AbstractResultWithNamedArrays, Array,
                           DictOfNamedArrays, Placeholder)
@@ -44,7 +44,8 @@ class PlaceholderSubstitutor(CopyMapper):
         A mapping from the placeholder name to the array that it is to be
         substituted with.
     """
-    def __init__(self, substitutions: Map[str, Array]) -> None:
+
+    def __init__(self, substitutions: Mapping[str, Array]) -> None:
         super().__init__()
         self.substitutions = substitutions
 
diff --git a/pytato/transform/einsum_distributive_law.py b/pytato/transform/einsum_distributive_law.py
index 106c5a4..9ce210a 100644
--- a/pytato/transform/einsum_distributive_law.py
+++ b/pytato/transform/einsum_distributive_law.py
@@ -32,7 +32,7 @@ THE SOFTWARE.
 """
 
 
-from typing import Callable, Dict, Tuple, Optional, FrozenSet
+from typing import Callable, Dict, Tuple, Optional, FrozenSet, Mapping
 import attrs
 from pytato.transform import ArrayOrNames, Mapper, MappedT
 from pytato.array import (Array, AxesT, Einsum, IndexLambda,
@@ -41,7 +41,7 @@ from pytato.array import (Array, AxesT, Einsum, IndexLambda,
                           Stack, Concatenate, Roll, AxisPermutation,
                           IndexBase, Reshape, InputArgumentBase)
 from pytato.raising import HighLevelOp
-from immutables import Map
+from immutabledict import immutabledict
 from pytools.tag import Tag
 from pytato.utils import are_shapes_equal
 import numpy as np
@@ -74,10 +74,10 @@ class DoDistribute(EinsumDistributiveLawDescriptor):
 @attrs.frozen
 class _EinsumDistributiveLawMapperContext:
     access_descriptors: Tuple[Tuple[EinsumAxisDescriptor, ...], ...]
-    surrounding_args: Map[int, Array]
-    redn_axis_to_redn_descr: Map[EinsumReductionAxis,
+    surrounding_args: Mapping[int, Array]
+    redn_axis_to_redn_descr: Mapping[EinsumReductionAxis,
                                  ReductionDescriptor]
-    index_to_access_descr: Map[str, EinsumAxisDescriptor]
+    index_to_access_descr: Mapping[str, EinsumAxisDescriptor]
     axes: AxesT = attrs.field(kw_only=True)
     tags: FrozenSet[Tag] = attrs.field(kw_only=True)
 
@@ -223,7 +223,7 @@ class EinsumDistributiveLawMapper(Mapper):
                 expr=expr.expr,
                 shape=expr.shape,
                 dtype=expr.dtype,
-                bindings=Map({name: self.rec(bnd, None)
+                bindings=immutabledict({name: self.rec(bnd, None)
                               for name, bnd in sorted(expr.bindings.items())}),
                 var_to_reduction_descr=expr.var_to_reduction_descr,
                 tags=expr.tags,
@@ -243,11 +243,11 @@ class EinsumDistributiveLawMapper(Mapper):
             else:
                 ctx = _EinsumDistributiveLawMapperContext(
                     expr.access_descriptors,
-                    Map({iarg: arg
+                    immutabledict({iarg: arg
                          for iarg, arg in enumerate(expr.args)
                          if iarg != distributive_law_descr.ioperand}),
-                    Map(expr.redn_axis_to_redn_descr),
-                    Map(expr.index_to_access_descr),
+                    immutabledict(expr.redn_axis_to_redn_descr),
+                    immutabledict(expr.index_to_access_descr),
                     tags=expr.tags,
                     axes=expr.axes,
                 )
diff --git a/pytato/transform/lower_to_index_lambda.py b/pytato/transform/lower_to_index_lambda.py
index 4aa1d4c..ad16fac 100644
--- a/pytato/transform/lower_to_index_lambda.py
+++ b/pytato/transform/lower_to_index_lambda.py
@@ -29,7 +29,7 @@ THE SOFTWARE.
 import pymbolic.primitives as prim
 
 from typing import List, Any, Dict, Tuple, TypeVar, TYPE_CHECKING
-from immutables import Map
+from immutabledict import immutabledict
 from pytools import UniqueNameGenerator
 from pytato.array import (Array, IndexLambda, Stack, Concatenate,
                           Einsum, Reshape, Roll, AxisPermutation,
@@ -96,7 +96,7 @@ class ToIndexLambdaMixin:
         return IndexLambda(expr=expr.expr,
                            shape=self._rec_shape(expr.shape),
                            dtype=expr.dtype,
-                           bindings=Map({name: self.rec(bnd)
+                           bindings=immutabledict({name: self.rec(bnd)
                                          for name, bnd
                                          in sorted(expr.bindings.items())}),
                            axes=expr.axes,
@@ -133,8 +133,8 @@ class ToIndexLambdaMixin:
                            shape=self._rec_shape(expr.shape),
                            dtype=expr.dtype,
                            axes=expr.axes,
-                           bindings=Map(bindings),
-                           var_to_reduction_descr=Map(),
+                           bindings=immutabledict(bindings),
+                           var_to_reduction_descr=immutabledict(),
                            tags=expr.tags)
 
     def map_concatenate(self, expr: Concatenate) -> IndexLambda:
@@ -180,9 +180,9 @@ class ToIndexLambdaMixin:
         return IndexLambda(expr=concat_expr,
                            shape=self._rec_shape(expr.shape),
                            dtype=expr.dtype,
-                           bindings=Map(bindings),
+                           bindings=immutabledict(bindings),
                            axes=expr.axes,
-                           var_to_reduction_descr=Map(),
+                           var_to_reduction_descr=immutabledict(),
                            tags=expr.tags)
 
     def map_einsum(self, expr: Einsum) -> IndexLambda:
@@ -249,9 +249,9 @@ class ToIndexLambdaMixin:
         return IndexLambda(expr=inner_expr,
                            shape=self._rec_shape(expr.shape),
                            dtype=expr.dtype,
-                           bindings=Map(bindings),
+                           bindings=immutabledict(bindings),
                            axes=expr.axes,
-                           var_to_reduction_descr=Map(var_to_redn_descr),
+                           var_to_reduction_descr=immutabledict(var_to_redn_descr),
                            tags=expr.tags)
 
     def map_roll(self, expr: Roll) -> IndexLambda:
@@ -275,10 +275,10 @@ class ToIndexLambdaMixin:
         return IndexLambda(expr=index_expr,
                            shape=self._rec_shape(expr.shape),
                            dtype=expr.dtype,
-                           bindings=Map({name: self.rec(bnd)
+                           bindings=immutabledict({name: self.rec(bnd)
                                      for name, bnd in bindings.items()}),
                            axes=expr.axes,
-                           var_to_reduction_descr=Map(),
+                           var_to_reduction_descr=immutabledict(),
                            tags=expr.tags)
 
     def map_contiguous_advanced_index(self,
@@ -338,11 +338,11 @@ class ToIndexLambdaMixin:
 
         return IndexLambda(expr=prim.Subscript(prim.Variable(in_ary),
                                                tuple(indices)),
-                           bindings=Map(bindings),
+                           bindings=immutabledict(bindings),
                            shape=self._rec_shape(expr.shape),
                            dtype=expr.dtype,
                            axes=expr.axes,
-                           var_to_reduction_descr=Map(),
+                           var_to_reduction_descr=immutabledict(),
                            tags=expr.tags,
                            )
 
@@ -400,11 +400,11 @@ class ToIndexLambdaMixin:
 
         return IndexLambda(expr=prim.Subscript(prim.Variable(in_ary),
                                                tuple(indices)),
-                           bindings=Map(bindings),
+                           bindings=immutabledict(bindings),
                            shape=self._rec_shape(expr.shape),
                            dtype=expr.dtype,
                            axes=expr.axes,
-                           var_to_reduction_descr=Map(),
+                           var_to_reduction_descr=immutabledict(),
                            tags=expr.tags,
                            )
 
@@ -433,11 +433,11 @@ class ToIndexLambdaMixin:
 
         return IndexLambda(expr=prim.Subscript(prim.Variable(in_ary),
                                                tuple(indices)),
-                           bindings=Map(bindings),
+                           bindings=immutabledict(bindings),
                            shape=self._rec_shape(expr.shape),
                            dtype=expr.dtype,
                            axes=expr.axes,
-                           var_to_reduction_descr=Map(),
+                           var_to_reduction_descr=immutabledict(),
                            tags=expr.tags,
                            )
 
@@ -447,9 +447,9 @@ class ToIndexLambdaMixin:
         return IndexLambda(expr=index_expr,
                            shape=self._rec_shape(expr.shape),
                            dtype=expr.dtype,
-                           bindings=Map({"_in0": self.rec(expr.array)}),
+                           bindings=immutabledict({"_in0": self.rec(expr.array)}),
                            axes=expr.axes,
-                           var_to_reduction_descr=Map(),
+                           var_to_reduction_descr=immutabledict(),
                            tags=expr.tags)
 
     def map_axis_permutation(self, expr: AxisPermutation) -> IndexLambda:
@@ -462,9 +462,9 @@ class ToIndexLambdaMixin:
         return IndexLambda(expr=index_expr,
                            shape=self._rec_shape(expr.shape),
                            dtype=expr.dtype,
-                           bindings=Map({"_in0": self.rec(expr.array)}),
+                           bindings=immutabledict({"_in0": self.rec(expr.array)}),
                            axes=expr.axes,
-                           var_to_reduction_descr=Map(),
+                           var_to_reduction_descr=immutabledict(),
                            tags=expr.tags)
 
 
diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py
index 591e705..ff3cd2d 100644
--- a/pytato/transform/metadata.py
+++ b/pytato/transform/metadata.py
@@ -553,14 +553,14 @@ class AxesTagsEquationCollector(Mapper):
 
 def _get_propagation_graph_from_constraints(
         equations: List[Tuple[str, str]]) -> Mapping[str, FrozenSet[str]]:
-    import immutables
+    from immutabledict import immutabledict
     propagation_graph: Dict[str, Set[str]] = {}
     for lhs, rhs in equations:
         assert lhs != rhs
         propagation_graph.setdefault(lhs, set()).add(rhs)
         propagation_graph.setdefault(rhs, set()).add(lhs)
 
-    return immutables.Map({k: frozenset(v)
+    return immutabledict({k: frozenset(v)
                            for k, v in propagation_graph.items()})
 
 
diff --git a/pytato/utils.py b/pytato/utils.py
index 58bcb58..4c96e37 100644
--- a/pytato/utils.py
+++ b/pytato/utils.py
@@ -38,7 +38,7 @@ from pytato.scalar_expr import (ScalarExpression, IntegralScalarExpression,
                                 SCALAR_CLASSES, INT_CLASSES, BoolT, ScalarType)
 from pytools import UniqueNameGenerator
 from pytato.transform import Mapper
-from immutables import Map
+from immutabledict import immutabledict
 
 
 __doc__ = """
@@ -205,9 +205,9 @@ def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar,
     return IndexLambda(expr=op(expr1, expr2),
                        shape=result_shape,
                        dtype=result_dtype,
-                       bindings=Map(bindings),
+                       bindings=immutabledict(bindings),
                        tags=_get_default_tags(),
-                       var_to_reduction_descr=Map(),
+                       var_to_reduction_descr=immutabledict(),
                        axes=_get_default_axes(len(result_shape)))
 
 
diff --git a/setup.py b/setup.py
index e00ab52..098aa4a 100644
--- a/setup.py
+++ b/setup.py
@@ -36,7 +36,7 @@ setup(
     install_requires=[
         "loopy>=2020.2",
         "pytools>=2022.1.13",
-        "immutables",
+        "immutabledict",
         "attrs",
         "bidict",
         ],
diff --git a/test/test_pytato.py b/test/test_pytato.py
index 98393b9..19eade0 100644
--- a/test/test_pytato.py
+++ b/test/test_pytato.py
@@ -691,7 +691,7 @@ def test_basic_index_equality_traverses_underlying_arrays():
 
 def test_idx_lambda_to_hlo():
     from pytato.raising import index_lambda_to_high_level_op
-    from immutables import Map
+    from immutabledict import immutabledict
     from pytato.raising import (BinaryOp, BinaryOpType, FullOp, ReduceOp,
                                 C99CallOp, BroadcastOp)
 
@@ -734,11 +734,11 @@ def test_idx_lambda_to_hlo():
     assert (index_lambda_to_high_level_op(pt.sum(b, axis=1))
             == ReduceOp(SumReductionOperation(),
                         b,
-                        Map({1: "_r0"})))
+                        immutabledict({1: "_r0"})))
     assert (index_lambda_to_high_level_op(pt.prod(a))
             == ReduceOp(ProductReductionOperation(),
                         a,
-                        Map({0: "_r0",
+                        immutabledict({0: "_r0",
                              1: "_r1"})))
     assert index_lambda_to_high_level_op(pt.sinh(a)) == C99CallOp("sinh", (a,))
     assert index_lambda_to_high_level_op(pt.arctan2(b, a)) == C99CallOp("atan2",
-- 
GitLab