From a3c7eead940d9ca29b6807f2785b3e4cd83e0430 Mon Sep 17 00:00:00 2001 From: Matthias Diener <mdiener@illinois.edu> Date: Thu, 30 Jan 2025 09:03:34 -0600 Subject: [PATCH] replace pyrsistent.PMap, immutables.Map, immutabledict with constantdict (#884) * replace pyrsistent.PMap, immutables.Map with immutabledict * go back to immutables.Map for Tree * ruff fixes * spelling * switch to constantdict * lint fixes * doc fix * work around doc failure * remove some spurious changes * clean up types a bit * simplify with_kernel --- doc/conf.py | 6 +----- loopy/codegen/__init__.py | 8 +++---- loopy/frontend/fortran/translator.py | 4 ++-- loopy/kernel/__init__.py | 4 ++-- loopy/kernel/data.py | 4 ++-- loopy/kernel/function_interface.py | 12 ++++++----- loopy/kernel/instruction.py | 16 +++++++------- loopy/kernel/tools.py | 4 ++-- loopy/preprocess.py | 12 +++++------ loopy/schedule/__init__.py | 4 ++-- loopy/schedule/tools.py | 4 ++-- loopy/schedule/tree.py | 2 +- loopy/symbolic.py | 8 +++---- loopy/target/c/c_execution.py | 4 ++-- loopy/target/execution.py | 28 ++++++++++++------------- loopy/target/pyopencl_execution.py | 4 ++-- loopy/tools.py | 16 ++++++-------- loopy/transform/buffer.py | 6 +++--- loopy/transform/callable.py | 6 +++--- loopy/transform/data.py | 4 ++-- loopy/transform/fusion.py | 8 +++---- loopy/transform/pack_and_unpack_args.py | 4 ++-- loopy/transform/precompute.py | 4 ++-- loopy/transform/realize_reduction.py | 4 ++-- loopy/transform/save.py | 4 ++-- loopy/translation_unit.py | 28 ++++++++++++------------- pyproject.toml | 4 +--- 27 files changed, 101 insertions(+), 111 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index b23ce311..d12eb17b 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -32,7 +32,7 @@ intersphinx_mapping = { "pyopencl": ("https://documen.tician.de/pyopencl", None), "cgen": ("https://documen.tician.de/cgen", None), "pymbolic": ("https://documen.tician.de/pymbolic", None), - "pyrsistent": ("https://pyrsistent.readthedocs.io/en/latest/", None), + "constantdict": ("https://matthiasdiener.github.io/constantdict/", None), } nitpicky = True @@ -43,10 +43,6 @@ nitpick_ignore_regex = [ ["py:class", r"numpy\.float[0-9]+"], ["py:class", r"numpy\.complex[0-9]+"], - # As of 2022-06-22, it doesn't look like there's sphinx documentation - # available. - ["py:class", r"immutables\.(.+)"], - # Reference not found from "<unknown>"? I'm not even sure where to look. ["py:class", r"ExpressionNode"], diff --git a/loopy/codegen/__init__.py b/loopy/codegen/__init__.py index 3c3b42f3..3de36e24 100644 --- a/loopy/codegen/__init__.py +++ b/loopy/codegen/__init__.py @@ -32,7 +32,7 @@ from typing import ( Sequence, ) -import immutables +import constantdict logger = logging.getLogger(__name__) @@ -168,7 +168,7 @@ class CodeGenerationState: seen_functions: set[SeenFunction] seen_atomic_dtypes: set[LoopyType] - var_subst_map: immutables.Map[str, Expression] + var_subst_map: constantdict.constantdict[str, Expression] allow_complex: bool callables_table: CallablesTable is_entrypoint: bool @@ -381,7 +381,7 @@ def generate_code_for_a_single_kernel(kernel, callables_table, target, seen_dtypes=seen_dtypes, seen_functions=seen_functions, seen_atomic_dtypes=seen_atomic_dtypes, - var_subst_map=immutables.Map(), + var_subst_map=constantdict.constantdict(), allow_complex=allow_complex, var_name_generator=kernel.get_var_name_generator(), is_generating_device_code=False, @@ -482,7 +482,7 @@ def diverge_callee_entrypoints(program): new_callables[name] = clbl - return program.copy(callables_table=immutables.Map(new_callables)) + return program.copy(callables_table=constantdict.constantdict(new_callables)) @dataclass(frozen=True) diff --git a/loopy/frontend/fortran/translator.py b/loopy/frontend/fortran/translator.py index 5000abf8..4c55b30f 100644 --- a/loopy/frontend/fortran/translator.py +++ b/loopy/frontend/fortran/translator.py @@ -29,7 +29,7 @@ from typing import ClassVar from warnings import warn import numpy as np -from immutables import Map +from constantdict import constantdict import islpy as isl from islpy import dim_type @@ -334,7 +334,7 @@ def specialize_fortran_division(t_unit): new_callables[name] = clbl - return t_unit.copy(callables_table=Map(new_callables)) + return t_unit.copy(callables_table=constantdict(new_callables)) # }}} diff --git a/loopy/kernel/__init__.py b/loopy/kernel/__init__.py index d612b5db..f487078c 100644 --- a/loopy/kernel/__init__.py +++ b/loopy/kernel/__init__.py @@ -48,7 +48,7 @@ from typing import ( from warnings import warn import numpy as np -from immutables import Map +from constantdict import constantdict import islpy # to help out Sphinx import islpy as isl @@ -183,7 +183,7 @@ class LoopKernel(Taggable): Callable[[LoopKernel, str], tuple[LoopyType, str] | None]] = () linearization: Sequence[ScheduleItem] | None = None iname_slab_increments: Mapping[InameStr, tuple[int, int]] = field( - default_factory=Map) + default_factory=constantdict) """ A mapping from inames to (lower_incr, upper_incr) tuples that will be separated out in the execution to generate diff --git a/loopy/kernel/data.py b/loopy/kernel/data.py index d10401e5..3dd1cf82 100644 --- a/loopy/kernel/data.py +++ b/loopy/kernel/data.py @@ -64,7 +64,7 @@ from loopy.typing import Expression, ShapeType, auto if TYPE_CHECKING: - from immutables import Map + from collections.abc import Mapping from pymbolic import ArithmeticExpression, Variable @@ -437,7 +437,7 @@ class _ArraySeparationInfo: should be used to realize this array. """ sep_axis_indices_set: frozenset[int] - subarray_names: Map[tuple[int, ...], str] + subarray_names: Mapping[tuple[int, ...], str] class ArrayArg(ArrayBase, KernelArgument): diff --git a/loopy/kernel/function_interface.py b/loopy/kernel/function_interface.py index 146d40f4..799e5d91 100644 --- a/loopy/kernel/function_interface.py +++ b/loopy/kernel/function_interface.py @@ -27,7 +27,7 @@ from dataclasses import dataclass, replace from typing import TYPE_CHECKING, Any, Callable, TypeVar from warnings import warn -from immutabledict import immutabledict +from constantdict import constantdict from typing_extensions import Self from loopy.diagnostic import LoopyError @@ -348,7 +348,8 @@ class InKernelCallable(ABC): try: hash(arg_id_to_dtype) except TypeError: - arg_id_to_dtype = immutabledict(arg_id_to_dtype) + assert arg_id_to_dtype is not None + arg_id_to_dtype = constantdict(arg_id_to_dtype) warn("arg_id_to_dtype passed to InKernelCallable was not hashable. " "This usage is deprecated and will stop working in 2026.", DeprecationWarning, stacklevel=3) @@ -356,7 +357,8 @@ class InKernelCallable(ABC): try: hash(arg_id_to_descr) except TypeError: - arg_id_to_descr = immutabledict(arg_id_to_descr) + assert arg_id_to_descr is not None + arg_id_to_descr = constantdict(arg_id_to_descr) warn("arg_id_to_descr passed to InKernelCallable was not hashable. " "This usage is deprecated and will stop working in 2026.", DeprecationWarning, stacklevel=3) @@ -773,7 +775,7 @@ class CallableKernel(InKernelCallable): # Return the kernel call with specialized subkernel and the corresponding # new arg_id_to_dtype return self.copy(subkernel=specialized_kernel, - arg_id_to_dtype=immutabledict(new_arg_id_to_dtype)), callables_table + arg_id_to_dtype=constantdict(new_arg_id_to_dtype)), callables_table def with_descrs(self, arg_id_to_descr, clbl_inf_ctx): @@ -848,7 +850,7 @@ class CallableKernel(InKernelCallable): # }}} return (self.copy(subkernel=subkernel, - arg_id_to_descr=immutabledict(arg_id_to_descr)), + arg_id_to_descr=constantdict(arg_id_to_descr)), clbl_inf_ctx) def with_added_arg(self, arg_dtype, arg_descr): diff --git a/loopy/kernel/instruction.py b/loopy/kernel/instruction.py index f882c09f..604b581e 100644 --- a/loopy/kernel/instruction.py +++ b/loopy/kernel/instruction.py @@ -292,7 +292,7 @@ class InstructionBase(ImmutableRecord, Taggable): *, depends_on: frozenset[str] | str | None = None, ) -> None: - from immutabledict import immutabledict + from constantdict import constantdict if predicates is None: predicates = frozenset() @@ -324,29 +324,29 @@ class InstructionBase(ImmutableRecord, Taggable): raise LoopyError("Setting depends_on_is_final to True requires " "actually specifying happens_after/depends_on") - if isinstance(happens_after, immutabledict): + if isinstance(happens_after, constantdict): pass elif happens_after is None: - happens_after = immutabledict() + happens_after = constantdict() elif isinstance(happens_after, str): warn("Passing a string for happens_after/depends_on is deprecated and " "will stop working in 2025. Instead, pass a full-fledged " "happens_after data structure.", DeprecationWarning, stacklevel=2) - happens_after = immutabledict({ + happens_after = constantdict({ after_id.strip(): HappensAfter( variable_name=None, instances_rel=None) for after_id in happens_after.split(",") if after_id.strip()}) elif isinstance(happens_after, frozenset): - happens_after = immutabledict({ + happens_after = constantdict({ after_id: HappensAfter( variable_name=None, instances_rel=None) for after_id in happens_after}) elif isinstance(happens_after, dict): - happens_after = immutabledict(happens_after) + happens_after = constantdict(happens_after) else: raise TypeError("'happens_after' has unexpected type: " f"{type(happens_after)}") @@ -569,13 +569,13 @@ class InstructionBase(ImmutableRecord, Taggable): def __setstate__(self, val): super().__setstate__(val) - from immutabledict import immutabledict + from constantdict import constantdict from loopy.tools import intern_frozenset_of_ids if self.id is not None: # pylint:disable=access-member-before-definition self.id = intern(self.id) - self.happens_after = immutabledict({ + self.happens_after = constantdict({ intern(after_id): ha for after_id, ha in self.happens_after.items()}) self.groups = intern_frozenset_of_ids(self.groups) diff --git a/loopy/kernel/tools.py b/loopy/kernel/tools.py index c48da4be..856ba19c 100644 --- a/loopy/kernel/tools.py +++ b/loopy/kernel/tools.py @@ -2089,7 +2089,7 @@ def get_call_graph(t_unit, only_kernel_callables=False): :arg t_unit: An instance of :class:`TranslationUnit`. """ - from pyrsistent import pmap + from constantdict import constantdict from loopy.kernel import KernelState @@ -2116,7 +2116,7 @@ def get_call_graph(t_unit, only_kernel_callables=False): call_graph[name] = clbl.get_called_callables(t_unit.callables_table, recursive=False) - return pmap(call_graph) + return constantdict(call_graph) # }}} diff --git a/loopy/preprocess.py b/loopy/preprocess.py index aee4044b..7600c97b 100644 --- a/loopy/preprocess.py +++ b/loopy/preprocess.py @@ -32,7 +32,7 @@ logger = logging.getLogger(__name__) from functools import partial import numpy as np -from immutables import Map +from constantdict import constantdict from pytools import ProcessLogger @@ -197,7 +197,7 @@ def make_arrays_for_sep_arrays(kernel: LoopKernel) -> LoopKernel: sep_info = _ArraySeparationInfo( sep_axis_indices_set=sep_axis_indices_set, - subarray_names=Map({ + subarray_names=constantdict({ ind: vng(f"{arg.name}_s{'_'.join(str(i) for i in ind)}") for ind in np.ndindex(*cast("List[int]", sep_shape))})) @@ -605,8 +605,6 @@ class ArgDescrInferenceMapper(RuleAwareIdentityMapper): raise NotImplementedError def __call__(self, expr, kernel, insn, assignees=None): - import immutables - from loopy.kernel.data import InstructionBase from loopy.symbolic import ExpansionState, UncachedIdentityMapper assert insn is None or isinstance(insn, InstructionBase) @@ -616,7 +614,7 @@ class ArgDescrInferenceMapper(RuleAwareIdentityMapper): kernel=kernel, instruction=insn, stack=(), - arg_context=immutables.Map()), assignees=assignees) + arg_context=constantdict()), assignees=assignees) def map_kernel(self, kernel): @@ -750,7 +748,7 @@ def filter_reachable_callables(t_unit): t_unit.entrypoints) new_callables = {name: clbl for name, clbl in t_unit.callables_table.items() if name in (reachable_function_ids | t_unit.entrypoints)} - return t_unit.copy(callables_table=Map(new_callables)) + return t_unit.copy(callables_table=constantdict(new_callables)) def _preprocess_single_kernel(kernel: LoopKernel, is_entrypoint: bool) -> LoopKernel: @@ -875,7 +873,7 @@ def preprocess_program(t_unit: TranslationUnit) -> TranslationUnit: new_callables[func_id] = in_knl_callable - t_unit = t_unit.copy(callables_table=Map(new_callables)) + t_unit = t_unit.copy(callables_table=constantdict(new_callables)) # }}} diff --git a/loopy/schedule/__init__.py b/loopy/schedule/__init__.py index a009fc9c..8de619fe 100644 --- a/loopy/schedule/__init__.py +++ b/loopy/schedule/__init__.py @@ -34,7 +34,7 @@ from typing import ( TypeVar, ) -from immutables import Map +from constantdict import constantdict import islpy as isl from pytools import ImmutableRecord, MinRecursionLimit, ProcessLogger @@ -2480,7 +2480,7 @@ def linearize(t_unit: TranslationUnit) -> TranslationUnit: else: raise NotImplementedError(type(clbl)) - return t_unit.copy(callables_table=Map(new_callables)) + return t_unit.copy(callables_table=constantdict(new_callables)) # vim: foldmethod=marker diff --git a/loopy/schedule/tools.py b/loopy/schedule/tools.py index b659ee7b..709e3705 100644 --- a/loopy/schedule/tools.py +++ b/loopy/schedule/tools.py @@ -65,7 +65,7 @@ from dataclasses import dataclass from functools import cached_property, reduce from typing import TYPE_CHECKING, AbstractSet, Sequence -from immutables import Map +from constantdict import constantdict from typing_extensions import TypeAlias import islpy as isl @@ -1062,7 +1062,7 @@ def _get_iname_to_tree_node_id_from_partial_loop_nest_tree( for iname in node: iname_to_tree_node_id[iname] = node - return Map(iname_to_tree_node_id) + return constantdict(iname_to_tree_node_id) def get_loop_tree(kernel: LoopKernel) -> LoopTree: diff --git a/loopy/schedule/tree.py b/loopy/schedule/tree.py index 3861aa75..fbd26477 100644 --- a/loopy/schedule/tree.py +++ b/loopy/schedule/tree.py @@ -40,7 +40,7 @@ from dataclasses import dataclass from functools import cached_property, reduce from typing import Generic, TypeVar -from immutables import Map +from constantdict import constantdict as Map # noqa: N812 from pytools import memoize_method diff --git a/loopy/symbolic.py b/loopy/symbolic.py index 7b4b9805..20ff55fe 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -42,8 +42,8 @@ from typing import ( ) from warnings import warn -import immutables import numpy as np +from constantdict import constantdict from typing_extensions import Self import islpy as isl @@ -1114,7 +1114,7 @@ class ExpansionState: kernel: LoopKernel instruction: InstructionBase stack: tuple[tuple[str, Tag], ...] - arg_context: immutables.Map[str, Expression] + arg_context: Mapping[str, Expression] def __post_init__(self) -> None: hash(self.arg_context) @@ -1352,7 +1352,7 @@ class RuleAwareIdentityMapper(IdentityMapper[Concatenate[ExpansionState, P]]): from pymbolic.mapper.substitutor import make_subst_func arg_subst_map = SubstitutionMapper(make_subst_func(arg_context)) - return immutables.Map({ + return constantdict({ formal_arg_name: arg_subst_map(arg_value) for formal_arg_name, arg_value in zip(arg_names, arguments)}) @@ -1398,7 +1398,7 @@ class RuleAwareIdentityMapper(IdentityMapper[Concatenate[ExpansionState, P]]): kernel=kernel, instruction=insn, stack=(), - arg_context=immutables.Map())) + arg_context=constantdict())) def map_instruction(self, kernel, insn): return insn diff --git a/loopy/target/c/c_execution.py b/loopy/target/c/c_execution.py index 270c3d0d..b1566d07 100644 --- a/loopy/target/c/c_execution.py +++ b/loopy/target/c/c_execution.py @@ -48,7 +48,7 @@ from loopy.types import LoopyType if TYPE_CHECKING: - from immutables import Map + from constantdict import constantdict from loopy.codegen.result import GeneratedProgram from loopy.kernel import LoopKernel @@ -500,7 +500,7 @@ class CExecutor(ExecutorBase): @memoize_method def translation_unit_info(self, - arg_to_dtype: Map[str, LoopyType] | None = None) -> _KernelInfo: + arg_to_dtype: constantdict[str, LoopyType] | None = None) -> _KernelInfo: t_unit = self.get_typed_and_scheduled_translation_unit(arg_to_dtype) from loopy.codegen import generate_code_v2 diff --git a/loopy/target/execution.py b/loopy/target/execution.py index cb737f95..2bde4ef9 100644 --- a/loopy/target/execution.py +++ b/loopy/target/execution.py @@ -36,7 +36,7 @@ from typing import ( cast, ) -from immutables import Map +from constantdict import constantdict from pymbolic import Variable, var from pytools.codegen import CodeGenerator, Indentation @@ -817,7 +817,7 @@ class ExecutorBase: "your argument.") def get_typed_and_scheduled_translation_unit_uncached( - self, arg_to_dtype: Map[str, LoopyType] | None + self, arg_to_dtype: constantdict[str, LoopyType] | None ) -> TranslationUnit: t_unit = self.t_unit @@ -827,15 +827,15 @@ class ExecutorBase: # FIXME: This is not so nice. This transfers types from the # subarrays of sep-tagged arrays to the 'main' array, because # type inference fails otherwise. - with arg_to_dtype.mutate() as mm: - for name, sep_info in self.sep_info.items(): - if entry_knl.arg_dict[name].dtype is None: - for sep_name in sep_info.subarray_names.values(): - if sep_name in arg_to_dtype: - mm.set(name, arg_to_dtype[sep_name]) - del mm[sep_name] + mm = arg_to_dtype.mutate() + for name, sep_info in self.sep_info.items(): + if entry_knl.arg_dict[name].dtype is None: + for sep_name in sep_info.subarray_names.values(): + if sep_name in arg_to_dtype: + mm[name] = arg_to_dtype[sep_name] + del mm[sep_name] - arg_to_dtype = mm.finish() + arg_to_dtype = mm.finish() from loopy.kernel.tools import add_dtypes t_unit = t_unit.with_kernel(add_dtypes(entry_knl, arg_to_dtype)) @@ -854,7 +854,7 @@ class ExecutorBase: return t_unit def get_typed_and_scheduled_translation_unit( - self, arg_to_dtype: Map[str, LoopyType] | None + self, arg_to_dtype: constantdict[str, LoopyType] | None ) -> TranslationUnit: from loopy import CACHING_ENABLED @@ -876,7 +876,7 @@ class ExecutorBase: return t_unit - def arg_to_dtype(self, kwargs) -> Map[str, LoopyType] | None: + def arg_to_dtype(self, kwargs) -> constantdict[str, LoopyType] | None: if not self.has_runtime_typed_args: return None @@ -893,7 +893,7 @@ class ExecutorBase: else: arg_to_dtype[arg_name] = NumpyType(dtype) - return Map(arg_to_dtype) + return constantdict(arg_to_dtype) # {{{ debugging aids @@ -904,7 +904,7 @@ class ExecutorBase: def get_code( self, entrypoint: str, - arg_to_dtype: Map[str, LoopyType] | None = None) -> str: + arg_to_dtype: constantdict[str, LoopyType] | None = None) -> str: kernel = self.get_typed_and_scheduled_translation_unit(arg_to_dtype) from loopy.codegen import generate_code_v2 diff --git a/loopy/target/pyopencl_execution.py b/loopy/target/pyopencl_execution.py index c9191e1d..8c368cf8 100644 --- a/loopy/target/pyopencl_execution.py +++ b/loopy/target/pyopencl_execution.py @@ -42,7 +42,7 @@ logger = logging.getLogger(__name__) if TYPE_CHECKING: - from immutables import Map + from constantdict import constantdict import pyopencl as cl @@ -311,7 +311,7 @@ class PyOpenCLExecutor(ExecutorBase): @memoize_method def translation_unit_info( self, - arg_to_dtype: Map[str, LoopyType] | None = None) -> _KernelInfo: + arg_to_dtype: constantdict[str, LoopyType] | None = None) -> _KernelInfo: t_unit = self.get_typed_and_scheduled_translation_unit(arg_to_dtype) # FIXME: now just need to add the types to the arguments diff --git a/loopy/tools.py b/loopy/tools.py index 6faaceb9..e9f9932b 100644 --- a/loopy/tools.py +++ b/loopy/tools.py @@ -29,7 +29,7 @@ from functools import cached_property from sys import intern import numpy as np -from immutables import Map +from constantdict import constantdict import islpy as isl from pytools import ProcessLogger, memoize_method @@ -70,8 +70,8 @@ class LoopyKeyBuilder(KeyBuilderBase): update_for_list = KeyBuilderBase.update_for_tuple update_for_set = KeyBuilderBase.update_for_frozenset - update_for_dict = KeyBuilderBase.update_for_immutabledict - update_for_defaultdict = KeyBuilderBase.update_for_immutabledict + update_for_dict = KeyBuilderBase.update_for_constantdict + update_for_defaultdict = KeyBuilderBase.update_for_constantdict def update_for_BasicSet(self, key_hash, key): # noqa from islpy import Printer @@ -80,15 +80,11 @@ class LoopyKeyBuilder(KeyBuilderBase): key_hash.update(prn.get_str().encode("utf8")) def update_for_Map(self, key_hash, key): # noqa - if isinstance(key, Map): - self.update_for_dict(key_hash, key) - elif isinstance(key, isl.Map): + if isinstance(key, isl.Map): self.update_for_BasicSet(key_hash, key) else: raise AssertionError() - update_for_PMap = update_for_dict # noqa: N815 - # }}} @@ -800,7 +796,7 @@ def t_unit_to_python(t_unit, var_name="t_unit", .callables_table)) for name, clbl in t_unit.callables_table.items() if isinstance(clbl, CallableKernel)} - t_unit = t_unit.copy(callables_table=Map(new_callables)) + t_unit = t_unit.copy(callables_table=constantdict(new_callables)) knl_python_code_srcs = [_kernel_to_python(clbl.subkernel, name in t_unit.entrypoints, @@ -815,7 +811,7 @@ def t_unit_to_python(t_unit, var_name="t_unit", "import loopy as lp", "import numpy as np", "from pymbolic.primitives import *", - "import immutables", + "from constantdict import constantdict", ]) body_str = "\n".join([*knl_python_code_srcs, "\n", merge_stmt]) diff --git a/loopy/transform/buffer.py b/loopy/transform/buffer.py index f113e453..be29eeb4 100644 --- a/loopy/transform/buffer.py +++ b/loopy/transform/buffer.py @@ -25,7 +25,7 @@ THE SOFTWARE. import logging -from immutables import Map +from constantdict import constantdict from pymbolic import var from pymbolic.mapper.substitutor import make_subst_func @@ -128,7 +128,7 @@ class ArrayAccessReplacer(RuleAwareIdentityMapper): # Can't possibly be nested, but recurse anyway to # make sure substitution rules referenced below here # do not get thrown away. - self.rec(result, expn_state.copy(arg_context=Map())) + self.rec(result, expn_state.copy(arg_context=constantdict())) return result @@ -540,7 +540,7 @@ def buffer_array(program, *args, **kwargs): new_callables[func_id] = clbl - return program.copy(callables_table=Map(new_callables)) + return program.copy(callables_table=constantdict(new_callables)) # vim: foldmethod=marker diff --git a/loopy/transform/callable.py b/loopy/transform/callable.py index 8669a4ab..ef55db5a 100644 --- a/loopy/transform/callable.py +++ b/loopy/transform/callable.py @@ -26,7 +26,7 @@ THE SOFTWARE. from typing import TYPE_CHECKING -from immutables import Map +from constantdict import constantdict import islpy as isl from pytools import UniqueNameGenerator @@ -129,7 +129,7 @@ def merge(translation_units: Sequence[TranslationUnit]) -> TranslationUnit: return TranslationUnit( entrypoints=frozenset().union(*( t.entrypoints or frozenset() for t in translation_units)), - callables_table=Map(callables_table), + callables_table=constantdict(callables_table), target=translation_units[0].target) @@ -605,7 +605,7 @@ def rename_callable( new_entrypoints = ((new_entrypoints | frozenset([new_name])) - frozenset([old_name])) - return t_unit.copy(callables_table=Map(new_callables_table), + return t_unit.copy(callables_table=constantdict(new_callables_table), entrypoints=new_entrypoints) # }}} diff --git a/loopy/transform/data.py b/loopy/transform/data.py index 80a0c4a1..2b0606ec 100644 --- a/loopy/transform/data.py +++ b/loopy/transform/data.py @@ -28,7 +28,7 @@ from typing import TYPE_CHECKING, cast from warnings import warn import numpy as np -from immutables import Map +from constantdict import constantdict from islpy import dim_type from pytools import MovedFunctionDeprecationWrapper @@ -431,7 +431,7 @@ def add_prefetch(t_unit, new_callables[func_id] = in_knl_callable - return t_unit.copy(callables_table=Map(new_callables)) + return t_unit.copy(callables_table=constantdict(new_callables)) # }}} diff --git a/loopy/transform/fusion.py b/loopy/transform/fusion.py index b16d837f..f11006af 100644 --- a/loopy/transform/fusion.py +++ b/loopy/transform/fusion.py @@ -24,7 +24,7 @@ THE SOFTWARE. """ -from immutables import Map +from constantdict import constantdict import islpy as isl from islpy import dim_type @@ -120,8 +120,8 @@ def _merge_dicts(item_name, dict_a, dict_b): else: result[k] = v - if isinstance(dict_a, Map): - return Map(result) + if isinstance(dict_a, constantdict): + return constantdict(result) else: return result @@ -453,7 +453,7 @@ def fuse_kernels(kernels, suffixes=None, data_flow=None): new_callables[result.name] = CallableKernel(result) - return TranslationUnit(callables_table=Map(new_callables), + return TranslationUnit(callables_table=constantdict(new_callables), target=result.target, entrypoints=frozenset([result.name])) diff --git a/loopy/transform/pack_and_unpack_args.py b/loopy/transform/pack_and_unpack_args.py index 9dc5f9a9..50fb6e29 100644 --- a/loopy/transform/pack_and_unpack_args.py +++ b/loopy/transform/pack_and_unpack_args.py @@ -23,7 +23,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from immutables import Map +from constantdict import constantdict from loopy.diagnostic import LoopyError from loopy.kernel import LoopKernel @@ -342,6 +342,6 @@ def pack_and_unpack_args_for_call(program, *args, **kwargs): new_callables[func_id] = in_knl_callable - return program.copy(callables_table=Map(new_callables)) + return program.copy(callables_table=constantdict(new_callables)) # vim: foldmethod=marker diff --git a/loopy/transform/precompute.py b/loopy/transform/precompute.py index 3988b1f5..40666412 100644 --- a/loopy/transform/precompute.py +++ b/loopy/transform/precompute.py @@ -28,7 +28,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Sequence, cast import numpy as np -from immutables import Map +from constantdict import constantdict import islpy as isl from pymbolic import ArithmeticExpression, var @@ -1187,6 +1187,6 @@ def precompute(program, *args, **kwargs): new_callables[func_id] = clbl - return program.copy(callables_table=Map(new_callables)) + return program.copy(callables_table=constantdict(new_callables)) # vim: foldmethod=marker diff --git a/loopy/transform/realize_reduction.py b/loopy/transform/realize_reduction.py index 5f504e72..ceb824b4 100644 --- a/loopy/transform/realize_reduction.py +++ b/loopy/transform/realize_reduction.py @@ -34,7 +34,7 @@ from typing import TYPE_CHECKING, Callable, Sequence logger = logging.getLogger(__name__) -from immutables import Map +from constantdict import constantdict import islpy as isl from pytools import memoize_on_first_arg @@ -2184,6 +2184,6 @@ def realize_reduction(t_unit, *args, **kwargs): subkernel=new_knl) callables_table[knl.name] = in_knl_callable - return t_unit.copy(callables_table=Map(callables_table)) + return t_unit.copy(callables_table=constantdict(callables_table)) # vim: foldmethod=marker diff --git a/loopy/transform/save.py b/loopy/transform/save.py index e1dbfd99..fe3cf190 100644 --- a/loopy/transform/save.py +++ b/loopy/transform/save.py @@ -26,7 +26,7 @@ THE SOFTWARE. import logging from functools import cached_property -from immutables import Map +from constantdict import constantdict from pytools import Record, memoize_method @@ -252,7 +252,7 @@ class TemporarySaver: from collections import defaultdict self.insns_to_insert = [] self.insns_to_update = {} - self.updated_iname_objs = Map() + self.updated_iname_objs = constantdict() self.updated_temporary_variables = {} # temporary name -> save or reload insn ids diff --git a/loopy/translation_unit.py b/loopy/translation_unit.py index 670feeef..93ecbe9a 100644 --- a/loopy/translation_unit.py +++ b/loopy/translation_unit.py @@ -37,7 +37,7 @@ from typing import ( ) from warnings import warn -from immutables import Map +from constantdict import constantdict from typing_extensions import Concatenate, ParamSpec, Self from pymbolic.primitives import Call, Variable @@ -175,7 +175,7 @@ class CallableResolver(RuleAwareIdentityMapper): # {{{ translation unit FunctionIdT = Union[str, ReductionOpFunction] -ConcreteCallablesTable = Map[FunctionIdT, InKernelCallable] +ConcreteCallablesTable = constantdict[FunctionIdT, InKernelCallable] CallablesTable = Mapping[FunctionIdT, InKernelCallable] @@ -203,7 +203,7 @@ class TranslationUnit: .. attribute:: callables_table - An instance of :class:`pyrsistent.PMap` mapping the function + An instance of :class:`constantdict.constantdict` mapping the function identifiers in a kernel to their associated instances of :class:`~loopy.kernel.function_interface.InKernelCallable`. @@ -239,7 +239,7 @@ class TranslationUnit: def __post_init__(self): assert isinstance(self.entrypoints, abc_Set) - assert isinstance(self.callables_table, Map) + assert isinstance(self.callables_table, constantdict) def copy(self, **kwargs: Any) -> Self: target = kwargs.pop("target", None) @@ -267,7 +267,7 @@ class TranslationUnit: new_callables[func_id] = clbl t_unit = replace( - self, callables_table=Map(new_callables), target=target) + self, callables_table=constantdict(new_callables), target=target) return t_unit @@ -302,14 +302,14 @@ class TranslationUnit: # update the callable kernel new_in_knl_callable = self.callables_table[kernel.name].copy( subkernel=kernel) - new_callables = self.callables_table.delete(kernel.name).set( - kernel.name, new_in_knl_callable) - return self.copy(callables_table=new_callables) + return self.copy( + callables_table=self.callables_table.set( + kernel.name, new_in_knl_callable)) else: # add a new callable kernel clbl = CallableKernel(kernel) - new_callables = self.callables_table.set(kernel.name, clbl) - return self.copy(callables_table=new_callables) + return self.copy( + callables_table=self.callables_table.set(kernel.name, clbl)) def __getitem__(self, name) -> LoopKernel: """ @@ -720,7 +720,7 @@ class CallablesInferenceContext: # }}} - return program.copy(callables_table=Map(new_callables)) + return program.copy(callables_table=constantdict(new_callables)) def __getitem__(self, name): result = self.callables[name] @@ -741,7 +741,7 @@ def make_program(kernel: LoopKernel) -> TranslationUnit: """ return TranslationUnit( - callables_table=Map({ + callables_table=constantdict({ kernel.name: CallableKernel(kernel)}), target=kernel.target, entrypoints=frozenset()) @@ -803,7 +803,7 @@ def for_each_kernel( new_callables[func_id] = clbl - return t_unit.copy(callables_table=Map(new_callables)) + return t_unit.copy(callables_table=constantdict(new_callables)) elif isinstance(t_unit_or_kernel, LoopKernel): kernel = t_unit_or_kernel return transform(kernel, *args, **kwargs) @@ -899,7 +899,7 @@ def resolve_callables(t_unit: TranslationUnit) -> TranslationUnit: else: raise NotImplementedError(f"{type(clbl)}") - t_unit = t_unit.copy(callables_table=Map(callables_table)) + t_unit = t_unit.copy(callables_table=constantdict(callables_table)) validate_kernel_call_sites(t_unit) diff --git a/pyproject.toml b/pyproject.toml index d61564a7..7916bea5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,9 +41,7 @@ dependencies = [ "codepy>=2017.1", "colorama", "Mako", - "pyrsistent", - "immutables", - "immutabledict", + "constantdict", "typing-extensions>=4", ] -- GitLab