From 20330bb84a721ba844a12111513c8d2f052025a2 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Fri, 11 Nov 2011 04:34:56 -0500
Subject: [PATCH] Use __getinitargs__ everywhere, introduce comparisons and
 Boolean primitives.

---
 pymbolic/primitives.py | 268 +++++++++++++++++++----------------------
 1 file changed, 126 insertions(+), 142 deletions(-)

diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py
index d658d6b..07e9467 100644
--- a/pymbolic/primitives.py
+++ b/pymbolic/primitives.py
@@ -9,6 +9,8 @@ class Expression(object):
     Expression objects are immutable.
     """
 
+    # {{{ arithmetic
+
     def __add__(self, other):
         if not is_valid_operand(other):
             return NotImplemented
@@ -140,6 +142,8 @@ class Expression(object):
             return 1
         return Power(other, self)
 
+    # }}}
+
     def __neg__(self):
         return -1*self
 
@@ -169,7 +173,8 @@ class Expression(object):
 
         return "%s(%s)" % (self.__class__.__name__, initargs_str)
 
-    # hashable interface ------------------------------------------------------
+    # {{{ hashable interface
+
     def __eq__(self, other):
         """Provides equality testing with quick positive and negative paths
         based on L{id} and L{__hash__}().
@@ -199,19 +204,30 @@ class Expression(object):
             self.hash_value = self.get_hash()
             return self.hash_value
 
-    # hashable backend --------------------------------------------------------
+    # }}}
+
+    # {{{ hashable backend
+
     def is_equal(self, other):
-        raise NotImplementedError("is_equal() in "+str(type(self)))
+        return (type(other) == type(self)
+                and self.__getinitargs__() == other.__getinitargs__())
 
     def get_hash(self):
-        raise NotImplementedError("get_hash() in "+str(type(self)))
+        return hash((type(self),)+ self.__getinitargs__())
+
+    # }}}
+
+    # {{{ comparison interface
+
+    # /!\ Don't be tempted to resolve these to ComparisonOperator.
 
-    # comparison interface ----------------------------------------------------
     def __le__(self, other): raise TypeError("expressions don't have an order")
     def __lt__(self, other): raise TypeError("expressions don't have an order")
     def __ge__(self, other): raise TypeError("expressions don't have an order")
     def __gt__(self, other): raise TypeError("expressions don't have an order")
 
+    # }}}
+
 
 
 
@@ -247,25 +263,14 @@ class Variable(Leaf):
         else:
             return NotImplemented
 
-    def is_equal(self, other):
-        return (other.__class__ == self.__class__
-                and self.name == other.name)
-
-    def get_hash(self):
-        return hash((self.__class__, self.name))
-
     mapper_method = intern("map_variable")
 
 
 
 
 class Wildcard(Leaf):
-    def is_equal(self, other):
-        return (other.__class__ == self.__class__
-                and self.name == other.name)
-
-    def get_hash(self):
-        return hash((self.__class__, self.name))
+    def __getinitargs__(self):
+        return ()
 
     mapper_method = intern("map_wildcard")
 
@@ -280,18 +285,14 @@ class FunctionSymbol(AlgebraicLeaf):
     def __getinitargs__(self):
         return ()
 
-    def is_equal(self, other):
-        return self.__class__ == other.__class
-
-    def get_hash(self):
-        return hash(self.__class__)
-
     mapper_method = intern("map_function_symbol")
 
 
 
 
 
+# {{{ structural primitives
+
 class Call(AlgebraicLeaf):
     def __init__(self, function, parameters):
         self.function = function
@@ -310,14 +311,6 @@ class Call(AlgebraicLeaf):
     def __getinitargs__(self):
         return self.function, self.parameters
 
-    def is_equal(self, other):
-        return isinstance(other, Call) \
-               and (self.function == other.function) \
-               and (self.parameters == other.parameters)
-
-    def get_hash(self):
-        return hash((self.__class__, self.function, self.parameters))
-
     mapper_method = intern("map_call")
 
 
@@ -331,14 +324,6 @@ class Subscript(AlgebraicLeaf):
     def __getinitargs__(self):
         return self.aggregate, self.index
 
-    def is_equal(self, other):
-        return isinstance(other, Subscript) \
-               and (self.aggregate == other.aggregate) \
-               and (self.index == other.index)
-
-    def get_hash(self):
-        return hash((self.__class__, self.aggregate, self.index))
-
     mapper_method = intern("map_subscript")
 
 
@@ -352,18 +337,11 @@ class Lookup(AlgebraicLeaf):
     def __getinitargs__(self):
         return self.aggregate, self.name
 
-    def is_equal(self, other):
-        return isinstance(other, Lookup) \
-               and (self.aggregate == other.aggregate) \
-               and (self.name == other.name)
-
-    def get_hash(self):
-        return hash((self.__class__, self.aggregate, self.name))
-
     mapper_method = intern("map_lookup")
 
+# }}}
 
-
+# {{{ arithmetic primitives
 
 class Sum(Expression):
     def __init__(self, children):
@@ -374,10 +352,6 @@ class Sum(Expression):
     def __getinitargs__(self):
         return self.children
 
-    def is_equal(self, other):
-        return (isinstance(other, Sum) 
-                and (set(self.children) == set(other.children)))
-
     def __add__(self, other):
         if not is_valid_operand(other):
             return NotImplemented
@@ -415,9 +389,6 @@ class Sum(Expression):
             # FIXME: Right semantics?
             return True
 
-    def get_hash(self):
-        return hash((self.__class__, self.children))
-
     mapper_method = intern("map_sum")
 
 
@@ -431,10 +402,6 @@ class Product(Expression):
     def __getinitargs__(self):
         return self.children
 
-    def is_equal(self, other):
-        return (isinstance(other, Product) 
-                and (set(self.children) == set(other.children)))
-
     def __mul__(self, other):
         if not is_valid_operand(other):
             return NotImplemented
@@ -463,9 +430,6 @@ class Product(Expression):
                 return False
         return True
 
-    def get_hash(self):
-        return hash((self.__class__, self.children))
-
     mapper_method = intern("map_product")
 
 
@@ -490,9 +454,6 @@ class QuotientBase(Expression):
     def __nonzero__(self):
         return bool(self.numerator)
 
-    def get_hash(self):
-        return hash((self.__class__, self.numerator, self.denominator))
-
 
 
 
@@ -509,22 +470,12 @@ class Quotient(QuotientBase):
 
 
 class FloorDiv(QuotientBase):
-    def is_equal(self, other):
-        return isinstance(other, type(self)) \
-               and (self.numerator == other.numerator) \
-               and (self.denominator == other.denominator)
-
     mapper_method = intern("map_floor_div")
 
 
 
 
 class Remainder(QuotientBase):
-    def is_equal(self, other):
-        return self.__class__ == other.__class__ \
-               and (self.numerator == other.numerator) \
-               and (self.denominator == other.denominator)
-
     mapper_method = intern("map_remainder")
 
 
@@ -538,19 +489,105 @@ class Power(Expression):
     def __getinitargs__(self):
         return self.base, self.exponent
 
-    def is_equal(self, other):
-        return isinstance(other, Power) \
-               and (self.base == other.base) \
-               and (self.exponent == other.exponent)
+    mapper_method = intern("map_power")
 
-    def get_hash(self):
-        return hash((self.__class__, self.base, self.exponent))
+# }}}
 
-    mapper_method = intern("map_power")
+# {{{ comparisons, logic, conditionals
+
+class ComparisonOperator(Expression):
+    """Note: comparisons are not implicitly constructed by comparing
+    Expression objects.
+    """
+
+    def __init__(self, left, operator, right):
+        self.left = left
+        self.right = right
+        if not operator in [">", ">=", "==", "<", "<="]:
+            raise RuntimeError("invalid operator")
+
+    def __getinitargs__(self):
+        return self.left, self.operator, self.right
+
+    mapper_method = intern("map_comparison")
+
+
+
+
+class BooleanExpression(Expression):
+    pass
+
+class LogcialNot(BooleanExpression):
+    def __init__(self, child):
+        self.child = child
+
+    def __getinitargs__(self):
+        return (self.child, self.prefix)
+
+    mapper_method = intern("map_logical_not")
+
+
+
+
+class LogicalOr(BooleanExpression):
+    def __init__(self, children):
+        assert isinstance(children, tuple)
+
+        self.children = children
+
+    def __getinitargs__(self):
+        return self.children
+
+    mapper_method = intern("map_logical_or")
 
 
 
 
+class LogicalAnd(BooleanExpression):
+    def __init__(self, children):
+        assert isinstance(children, tuple)
+
+        self.children = children
+
+    def __getinitargs__(self):
+        return self.children
+
+    mapper_method = intern("map_logical_and")
+
+
+
+
+class IfPositive(Expression):
+    def __init__(self, criterion, then, else_):
+        self.criterion = criterion
+        self.then = then
+        self.else_ = else_
+
+    def __getinitargs__(self):
+        return self.criterion, self.then, self.else_
+
+    mapper_method = intern("map_if_positive")
+
+
+
+
+class _MinMaxBase(Expression):
+    def __init__(self, children):
+        self.children = children
+
+    def __getinitargs__(self):
+        return self.children
+
+class Min(_MinMaxBase):
+    mapper_method = intern("map_min")
+
+class Max(_MinMaxBase):
+    mapper_method = intern("map_max")
+
+# }}}
+
+# {{{
+
 class Vector(Expression):
     """An immutable sequence that you can compute with."""
 
@@ -603,9 +640,11 @@ class Vector(Expression):
         return Vector(tuple(other*x for x in self))
 
     def __div__(self, other):
+        import operator
         return Vector(tuple(operator.div(x, other) for x in self))
 
     def __truediv__(self, other):
+        import operator
         return Vector(tuple(operator.truediv(x, other) for x in self))
 
     def __floordiv__(self, other):
@@ -614,9 +653,6 @@ class Vector(Expression):
     def __getinitargs__(self):
         return self.children
 
-    def get_hash(self):
-        return hash((self.__class__, self.children))
-
     mapper_method = intern("map_vector")
 
 
@@ -630,13 +666,6 @@ class CommonSubexpression(Expression):
     def __getinitargs__(self):
         return (self.child, self.prefix)
 
-    def get_hash(self):
-        return hash((self.__class__, self.child))
-
-    def is_equal(self, other):
-        return (other.__class__ == self.__class__
-                and other.child == self.child)
-
     def get_extra_properties(self):
         return {}
 
@@ -646,55 +675,6 @@ class CommonSubexpression(Expression):
 
 
 
-class IfPositive(Expression):
-    def __init__(self, criterion, then, else_):
-        self.criterion = criterion
-        self.then = then
-        self.else_ = else_
-
-    def __getinitargs__(self):
-        return self.criterion, self.then, self.else_
-
-    def is_equal(self, other):
-        return (isinstance(other, IfPositive)
-                and self.criterion == other.criterion
-                and self.then == other.then
-                and self.else_ == other.else_)
-
-    def get_hash(self):
-        return hash((
-                self.__class__,
-                self.criterion,
-                self.then,
-                self.else_))
-
-    mapper_method = intern("map_if_positive")
-
-
-
-
-class _MinMaxBase(Expression):
-    def __init__(self, children):
-        self.children = children
-
-    def __getinitargs__(self):
-        return self.children
-
-    def is_equal(self, other):
-        return (isinstance(other, type(self))
-                and self.children == other.children)
-
-    def get_hash(self):
-        return hash((type(self), self.children))
-
-class Min(_MinMaxBase):
-    mapper_method = intern("map_min")
-
-class Max(_MinMaxBase):
-    mapper_method = intern("map_max")
-
-
-
 
 # intelligent makers ---------------------------------------------------------
 def make_variable(var_or_string):
@@ -891,3 +871,7 @@ def make_sym_vector(name, components):
     vfld = Variable(name)
     return join_fields(*[vfld[i] for i in components])
 
+
+
+
+# vim: foldmethod=marker
-- 
GitLab