diff --git a/pytato/function.py b/pytato/function.py index c79dfcfe3349f2fe6b8d940f812b0ab3860a56b0..66a30fc1d6bd3c349a9ca1ee8628a08eb1c3b264 100644 --- a/pytato/function.py +++ b/pytato/function.py @@ -59,6 +59,7 @@ import enum import re from functools import cached_property from typing import ( + Any, Callable, ClassVar, Hashable, @@ -236,6 +237,15 @@ class FunctionDefinition(Taggable): else: raise NotImplementedError(self.return_type) + def __eq__(self, other: Any) -> bool: + if self is other: + return True + if not isinstance(other, FunctionDefinition): + return False + + from pytato.equality import EqualityComparer + return EqualityComparer().map_function_definition(self, other) + @attrs.frozen(eq=False, repr=False, hash=True, cache_hash=True) class NamedCallResult(NamedArray):