diff --git a/src/primitives.py b/src/primitives.py index f4003b0b6b5affe8ac4a936f27cf1f17ca45ad05..2caec17f547af34ff4415e1f62a9d6ed7a7c3ebd 100644 --- a/src/primitives.py +++ b/src/primitives.py @@ -9,9 +9,6 @@ class Expression(object): Expression objects are immutable. """ - def __ne__(self, other): - return not self.__eq__(other) - def __add__(self, other): if not is_valid_operand(other): return NotImplemented @@ -136,13 +133,44 @@ class Expression(object): return "%s(%s)" % (self.__class__.__name__, initargs_str) + # hashable interface ------------------------------------------------------ + def __eq__(self, other): + """Provides equality testing with quick positive and negative paths + based on L{id} and L{__hash__}(). + + Subclasses should generally not override this method, but instead + provide an implementation of L{is_equal}. + """ + if id(self) == id(other): + return True + elif hash(self) != hash(other): + return False + else: + return self.is_equal(other) + + def __ne__(self, other): + return not self.__eq__(other) + def __hash__(self): + """Provides caching for hash values. + + Subclasses should generally not override this method, but instead + provide an implementation of L{get_hash}. + """ try: return self.hash_value except AttributeError: self.hash_value = self.get_hash() return self.hash_value + # hashable backend -------------------------------------------------------- + def is_equal(self, other): + return NotImplemented + + def get_hash(self, other): + raise NotImplementedError + + @@ -188,7 +216,7 @@ class Variable(Leaf): else: return NotImplemented - def __eq__(self, other): + def is_equal(self, other): return (isinstance(other, Variable) and self.name == other.name) @@ -206,7 +234,7 @@ class Call(AlgebraicLeaf): def __getinitargs__(self): return self.function, self.parameters - def __eq__(self, other): + def is_equal(self, other): return isinstance(other, Call) \ and (self.function == other.function) \ and (self.parameters == other.parameters) @@ -228,7 +256,7 @@ class Subscript(AlgebraicLeaf): def __getinitargs__(self): return self.aggregate, self.index - def __eq__(self, other): + def is_equal(self, other): return isinstance(other, Subscript) \ and (self.aggregate == other.aggregate) \ and (self.index == other.index) @@ -250,7 +278,7 @@ class Lookup(AlgebraicLeaf): def __getinitargs__(self): return self.aggregate, self.name - def __eq__(self, other): + def is_equal(self, other): return isinstance(other, Lookup) \ and (self.aggregate == other.aggregate) \ and (self.name == other.name) @@ -273,7 +301,7 @@ class Sum(Expression): def __getinitargs__(self): return self.children - def __eq__(self, other): + def is_equal(self, other): return (isinstance(other, Sum) and (set(self.children) == set(other.children))) @@ -331,7 +359,7 @@ class Product(Expression): def __getinitargs__(self): return self.children - def __eq__(self, other): + def is_equal(self, other): return (isinstance(other, Product) and (set(self.children) == set(other.children))) @@ -388,7 +416,7 @@ class Quotient(Expression): def den(self): return self.denominator - def __eq__(self, other): + def is_equal(self, other): from pymbolic.rational import Rational return isinstance(other, (Rational, Quotient)) \ and (self.numerator == other.numerator) \ @@ -414,7 +442,7 @@ class Power(Expression): def __getinitargs__(self): return self.base, self.exponent - def __eq__(self, other): + def is_equal(self, other): return isinstance(other, Power) \ and (self.base == other.base) \ and (self.exponent == other.exponent)