diff --git a/pytato/transform/einsum_distributive_law.py b/pytato/transform/einsum_distributive_law.py index 5ef5b504973c32f4b5ca37d4f8d1f0aff65ded44..106c5a4fe7c5aacd42957ca3a5d9453e29cb5e3e 100644 --- a/pytato/transform/einsum_distributive_law.py +++ b/pytato/transform/einsum_distributive_law.py @@ -33,7 +33,7 @@ THE SOFTWARE. from typing import Callable, Dict, Tuple, Optional, FrozenSet -import dataclasses as dc +import attrs from pytato.transform import ArrayOrNames, Mapper, MappedT from pytato.array import (Array, AxesT, Einsum, IndexLambda, EinsumReductionAxis, @@ -54,7 +54,7 @@ class EinsumDistributiveLawDescriptor: """ -@dc.dataclass(frozen=True) +@attrs.frozen class DoNotDistribute(EinsumDistributiveLawDescriptor): """ Tells :func:`apply_distributive_property_to_einsums` to not apply @@ -62,7 +62,7 @@ class DoNotDistribute(EinsumDistributiveLawDescriptor): """ -@dc.dataclass(frozen=True) +@attrs.frozen class DoDistribute(EinsumDistributiveLawDescriptor): """ Tells :func:`apply_distributive_property_to_einsums` to apply distributive @@ -71,17 +71,17 @@ class DoDistribute(EinsumDistributiveLawDescriptor): ioperand: int -@dc.dataclass(frozen=True) +@attrs.frozen class _EinsumDistributiveLawMapperContext: access_descriptors: Tuple[Tuple[EinsumAxisDescriptor, ...], ...] surrounding_args: Map[int, Array] redn_axis_to_redn_descr: Map[EinsumReductionAxis, ReductionDescriptor] index_to_access_descr: Map[str, EinsumAxisDescriptor] - axes: AxesT = dc.field(kw_only=True) - tags: FrozenSet[Tag] = dc.field(kw_only=True) + axes: AxesT = attrs.field(kw_only=True) + tags: FrozenSet[Tag] = attrs.field(kw_only=True) - def __post_init__(self) -> None: + def __attrs_post_init__(self) -> None: # {{{ check that exactly one of the args is missing assert len(self.surrounding_args) == (