From 2b68db4113a72fc69d85c7cdd2a19ec43d56990c Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 2 Aug 2023 23:57:40 -0500 Subject: [PATCH] Use attrs isntead of dataclasses in einsum_distributive_law --- pytato/transform/einsum_distributive_law.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pytato/transform/einsum_distributive_law.py b/pytato/transform/einsum_distributive_law.py index 5ef5b50..106c5a4 100644 --- a/pytato/transform/einsum_distributive_law.py +++ b/pytato/transform/einsum_distributive_law.py @@ -33,7 +33,7 @@ THE SOFTWARE. from typing import Callable, Dict, Tuple, Optional, FrozenSet -import dataclasses as dc +import attrs from pytato.transform import ArrayOrNames, Mapper, MappedT from pytato.array import (Array, AxesT, Einsum, IndexLambda, EinsumReductionAxis, @@ -54,7 +54,7 @@ class EinsumDistributiveLawDescriptor: """ -@dc.dataclass(frozen=True) +@attrs.frozen class DoNotDistribute(EinsumDistributiveLawDescriptor): """ Tells :func:`apply_distributive_property_to_einsums` to not apply @@ -62,7 +62,7 @@ class DoNotDistribute(EinsumDistributiveLawDescriptor): """ -@dc.dataclass(frozen=True) +@attrs.frozen class DoDistribute(EinsumDistributiveLawDescriptor): """ Tells :func:`apply_distributive_property_to_einsums` to apply distributive @@ -71,17 +71,17 @@ class DoDistribute(EinsumDistributiveLawDescriptor): ioperand: int -@dc.dataclass(frozen=True) +@attrs.frozen class _EinsumDistributiveLawMapperContext: access_descriptors: Tuple[Tuple[EinsumAxisDescriptor, ...], ...] surrounding_args: Map[int, Array] redn_axis_to_redn_descr: Map[EinsumReductionAxis, ReductionDescriptor] index_to_access_descr: Map[str, EinsumAxisDescriptor] - axes: AxesT = dc.field(kw_only=True) - tags: FrozenSet[Tag] = dc.field(kw_only=True) + axes: AxesT = attrs.field(kw_only=True) + tags: FrozenSet[Tag] = attrs.field(kw_only=True) - def __post_init__(self) -> None: + def __attrs_post_init__(self) -> None: # {{{ check that exactly one of the args is missing assert len(self.surrounding_args) == ( -- GitLab