diff --git a/pytato/array.py b/pytato/array.py index bbf4ae739fdb8c096ff76bc02859882b1f85698a..067c1a08de172df6dc614664d72025a2f27c5de8 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -745,10 +745,11 @@ class AbstractResultWithNamedArrays(Mapping[str, NamedArray], Taggable, ABC): def _is_eq_valid(self) -> bool: return self.__class__.__eq__ is AbstractResultWithNamedArrays.__eq__ - def __attrs_post_init__(self) -> None: - # ensure that a developer does not uses dataclass' "__eq__" - # or "__hash__" implementation as they have exponential complexity. - assert self._is_eq_valid() + if __debug__: + def __attrs_post_init__(self) -> None: + # ensure that a developer does not uses dataclass' "__eq__" + # or "__hash__" implementation as they have exponential complexity. + assert self._is_eq_valid() @abstractmethod def __contains__(self, name: object) -> bool: @@ -1450,10 +1451,11 @@ class Reshape(IndexRemappingBase): _mapper_method: ClassVar[str] = "map_reshape" - def __attrs_post_init__(self) -> None: - # FIXME: Get rid of this restriction - assert self.order == "C" - super().__attrs_post_init__() + if __debug__: + def __attrs_post_init__(self) -> None: + # FIXME: Get rid of this restriction + assert self.order == "C" + super().__attrs_post_init__() @property def shape(self) -> ShapeType: diff --git a/pytato/function.py b/pytato/function.py index 6e5d044d2731fd77abe246f6b67c824938780cf3..b053831a04c2382a665247135c40b7c006a8b6fe 100644 --- a/pytato/function.py +++ b/pytato/function.py @@ -276,11 +276,12 @@ class Call(AbstractResultWithNamedArrays): copy = attrs.evolve - def __attrs_post_init__(self) -> None: - # check that the invocation parameters and the function definition - # parameters agree with each other. - assert frozenset(self.bindings) == self.function.parameters - super().__attrs_post_init__() + if __debug__: + def __attrs_post_init__(self) -> None: + # check that the invocation parameters and the function definition + # parameters agree with each other. + assert frozenset(self.bindings) == self.function.parameters + super().__attrs_post_init__() def __contains__(self, name: object) -> bool: return name in self.function.returns