From 20330bb84a721ba844a12111513c8d2f052025a2 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Fri, 11 Nov 2011 04:34:56 -0500 Subject: [PATCH] Use __getinitargs__ everywhere, introduce comparisons and Boolean primitives. --- pymbolic/primitives.py | 268 +++++++++++++++++++---------------------- 1 file changed, 126 insertions(+), 142 deletions(-) diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py index d658d6b..07e9467 100644 --- a/pymbolic/primitives.py +++ b/pymbolic/primitives.py @@ -9,6 +9,8 @@ class Expression(object): Expression objects are immutable. """ + # {{{ arithmetic + def __add__(self, other): if not is_valid_operand(other): return NotImplemented @@ -140,6 +142,8 @@ class Expression(object): return 1 return Power(other, self) + # }}} + def __neg__(self): return -1*self @@ -169,7 +173,8 @@ class Expression(object): return "%s(%s)" % (self.__class__.__name__, initargs_str) - # hashable interface ------------------------------------------------------ + # {{{ hashable interface + def __eq__(self, other): """Provides equality testing with quick positive and negative paths based on L{id} and L{__hash__}(). @@ -199,19 +204,30 @@ class Expression(object): self.hash_value = self.get_hash() return self.hash_value - # hashable backend -------------------------------------------------------- + # }}} + + # {{{ hashable backend + def is_equal(self, other): - raise NotImplementedError("is_equal() in "+str(type(self))) + return (type(other) == type(self) + and self.__getinitargs__() == other.__getinitargs__()) def get_hash(self): - raise NotImplementedError("get_hash() in "+str(type(self))) + return hash((type(self),)+ self.__getinitargs__()) + + # }}} + + # {{{ comparison interface + + # /!\ Don't be tempted to resolve these to ComparisonOperator. - # comparison interface ---------------------------------------------------- def __le__(self, other): raise TypeError("expressions don't have an order") def __lt__(self, other): raise TypeError("expressions don't have an order") def __ge__(self, other): raise TypeError("expressions don't have an order") def __gt__(self, other): raise TypeError("expressions don't have an order") + # }}} + @@ -247,25 +263,14 @@ class Variable(Leaf): else: return NotImplemented - def is_equal(self, other): - return (other.__class__ == self.__class__ - and self.name == other.name) - - def get_hash(self): - return hash((self.__class__, self.name)) - mapper_method = intern("map_variable") class Wildcard(Leaf): - def is_equal(self, other): - return (other.__class__ == self.__class__ - and self.name == other.name) - - def get_hash(self): - return hash((self.__class__, self.name)) + def __getinitargs__(self): + return () mapper_method = intern("map_wildcard") @@ -280,18 +285,14 @@ class FunctionSymbol(AlgebraicLeaf): def __getinitargs__(self): return () - def is_equal(self, other): - return self.__class__ == other.__class - - def get_hash(self): - return hash(self.__class__) - mapper_method = intern("map_function_symbol") +# {{{ structural primitives + class Call(AlgebraicLeaf): def __init__(self, function, parameters): self.function = function @@ -310,14 +311,6 @@ class Call(AlgebraicLeaf): def __getinitargs__(self): return self.function, self.parameters - def is_equal(self, other): - return isinstance(other, Call) \ - and (self.function == other.function) \ - and (self.parameters == other.parameters) - - def get_hash(self): - return hash((self.__class__, self.function, self.parameters)) - mapper_method = intern("map_call") @@ -331,14 +324,6 @@ class Subscript(AlgebraicLeaf): def __getinitargs__(self): return self.aggregate, self.index - def is_equal(self, other): - return isinstance(other, Subscript) \ - and (self.aggregate == other.aggregate) \ - and (self.index == other.index) - - def get_hash(self): - return hash((self.__class__, self.aggregate, self.index)) - mapper_method = intern("map_subscript") @@ -352,18 +337,11 @@ class Lookup(AlgebraicLeaf): def __getinitargs__(self): return self.aggregate, self.name - def is_equal(self, other): - return isinstance(other, Lookup) \ - and (self.aggregate == other.aggregate) \ - and (self.name == other.name) - - def get_hash(self): - return hash((self.__class__, self.aggregate, self.name)) - mapper_method = intern("map_lookup") +# }}} - +# {{{ arithmetic primitives class Sum(Expression): def __init__(self, children): @@ -374,10 +352,6 @@ class Sum(Expression): def __getinitargs__(self): return self.children - def is_equal(self, other): - return (isinstance(other, Sum) - and (set(self.children) == set(other.children))) - def __add__(self, other): if not is_valid_operand(other): return NotImplemented @@ -415,9 +389,6 @@ class Sum(Expression): # FIXME: Right semantics? return True - def get_hash(self): - return hash((self.__class__, self.children)) - mapper_method = intern("map_sum") @@ -431,10 +402,6 @@ class Product(Expression): def __getinitargs__(self): return self.children - def is_equal(self, other): - return (isinstance(other, Product) - and (set(self.children) == set(other.children))) - def __mul__(self, other): if not is_valid_operand(other): return NotImplemented @@ -463,9 +430,6 @@ class Product(Expression): return False return True - def get_hash(self): - return hash((self.__class__, self.children)) - mapper_method = intern("map_product") @@ -490,9 +454,6 @@ class QuotientBase(Expression): def __nonzero__(self): return bool(self.numerator) - def get_hash(self): - return hash((self.__class__, self.numerator, self.denominator)) - @@ -509,22 +470,12 @@ class Quotient(QuotientBase): class FloorDiv(QuotientBase): - def is_equal(self, other): - return isinstance(other, type(self)) \ - and (self.numerator == other.numerator) \ - and (self.denominator == other.denominator) - mapper_method = intern("map_floor_div") class Remainder(QuotientBase): - def is_equal(self, other): - return self.__class__ == other.__class__ \ - and (self.numerator == other.numerator) \ - and (self.denominator == other.denominator) - mapper_method = intern("map_remainder") @@ -538,19 +489,105 @@ class Power(Expression): def __getinitargs__(self): return self.base, self.exponent - def is_equal(self, other): - return isinstance(other, Power) \ - and (self.base == other.base) \ - and (self.exponent == other.exponent) + mapper_method = intern("map_power") - def get_hash(self): - return hash((self.__class__, self.base, self.exponent)) +# }}} - mapper_method = intern("map_power") +# {{{ comparisons, logic, conditionals + +class ComparisonOperator(Expression): + """Note: comparisons are not implicitly constructed by comparing + Expression objects. + """ + + def __init__(self, left, operator, right): + self.left = left + self.right = right + if not operator in [">", ">=", "==", "<", "<="]: + raise RuntimeError("invalid operator") + + def __getinitargs__(self): + return self.left, self.operator, self.right + + mapper_method = intern("map_comparison") + + + + +class BooleanExpression(Expression): + pass + +class LogcialNot(BooleanExpression): + def __init__(self, child): + self.child = child + + def __getinitargs__(self): + return (self.child, self.prefix) + + mapper_method = intern("map_logical_not") + + + + +class LogicalOr(BooleanExpression): + def __init__(self, children): + assert isinstance(children, tuple) + + self.children = children + + def __getinitargs__(self): + return self.children + + mapper_method = intern("map_logical_or") +class LogicalAnd(BooleanExpression): + def __init__(self, children): + assert isinstance(children, tuple) + + self.children = children + + def __getinitargs__(self): + return self.children + + mapper_method = intern("map_logical_and") + + + + +class IfPositive(Expression): + def __init__(self, criterion, then, else_): + self.criterion = criterion + self.then = then + self.else_ = else_ + + def __getinitargs__(self): + return self.criterion, self.then, self.else_ + + mapper_method = intern("map_if_positive") + + + + +class _MinMaxBase(Expression): + def __init__(self, children): + self.children = children + + def __getinitargs__(self): + return self.children + +class Min(_MinMaxBase): + mapper_method = intern("map_min") + +class Max(_MinMaxBase): + mapper_method = intern("map_max") + +# }}} + +# {{{ + class Vector(Expression): """An immutable sequence that you can compute with.""" @@ -603,9 +640,11 @@ class Vector(Expression): return Vector(tuple(other*x for x in self)) def __div__(self, other): + import operator return Vector(tuple(operator.div(x, other) for x in self)) def __truediv__(self, other): + import operator return Vector(tuple(operator.truediv(x, other) for x in self)) def __floordiv__(self, other): @@ -614,9 +653,6 @@ class Vector(Expression): def __getinitargs__(self): return self.children - def get_hash(self): - return hash((self.__class__, self.children)) - mapper_method = intern("map_vector") @@ -630,13 +666,6 @@ class CommonSubexpression(Expression): def __getinitargs__(self): return (self.child, self.prefix) - def get_hash(self): - return hash((self.__class__, self.child)) - - def is_equal(self, other): - return (other.__class__ == self.__class__ - and other.child == self.child) - def get_extra_properties(self): return {} @@ -646,55 +675,6 @@ class CommonSubexpression(Expression): -class IfPositive(Expression): - def __init__(self, criterion, then, else_): - self.criterion = criterion - self.then = then - self.else_ = else_ - - def __getinitargs__(self): - return self.criterion, self.then, self.else_ - - def is_equal(self, other): - return (isinstance(other, IfPositive) - and self.criterion == other.criterion - and self.then == other.then - and self.else_ == other.else_) - - def get_hash(self): - return hash(( - self.__class__, - self.criterion, - self.then, - self.else_)) - - mapper_method = intern("map_if_positive") - - - - -class _MinMaxBase(Expression): - def __init__(self, children): - self.children = children - - def __getinitargs__(self): - return self.children - - def is_equal(self, other): - return (isinstance(other, type(self)) - and self.children == other.children) - - def get_hash(self): - return hash((type(self), self.children)) - -class Min(_MinMaxBase): - mapper_method = intern("map_min") - -class Max(_MinMaxBase): - mapper_method = intern("map_max") - - - # intelligent makers --------------------------------------------------------- def make_variable(var_or_string): @@ -891,3 +871,7 @@ def make_sym_vector(name, components): vfld = Variable(name) return join_fields(*[vfld[i] for i in components]) + + + +# vim: foldmethod=marker -- GitLab