From 3d3bc1ec24e5c58672deb3693687e6985a1ff103 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Wed, 6 Jul 2005 16:14:18 +0000 Subject: [PATCH] [pymbolic @ Arch-1:inform@tiker.net--iam-2005%pymbolic--mainline--1.0--patch-19] Many improvements --- setup.py | 4 +- src/__init__.py | 5 +- src/compiler.py | 11 +- src/mapper/__init__.py | 77 ++++++++ src/mapper/constant_detector.py | 27 --- src/mapper/dependency.py | 32 ++++ src/mapper/differentiator.py | 42 +++-- src/mapper/evaluator.py | 35 ++-- src/mapper/hash_generator.py | 40 +++++ src/mapper/mapper.py | 82 --------- src/mapper/stringifier.py | 27 +-- src/mapper/substitutor.py | 8 +- src/parser.py | 10 +- src/polynomial.py | 7 + src/primitives.py | 303 ++++++++++++++++++++------------ src/rational.py | 7 + 16 files changed, 435 insertions(+), 282 deletions(-) delete mode 100644 src/mapper/constant_detector.py create mode 100644 src/mapper/dependency.py create mode 100644 src/mapper/hash_generator.py delete mode 100644 src/mapper/mapper.py diff --git a/setup.py b/setup.py index b8cb339..beca570 100644 --- a/setup.py +++ b/setup.py @@ -13,6 +13,6 @@ setup(name="pymbolic", author_email="inform@tiker.net", license = "BSD, like Python itself", url="http://news.tiker.net/software/pymbolic", - packages=["pymbolic"], - package_dir={"pymbolic": "src"} + packages=["pymbolic", "pymbolic.mapper"], + package_dir={"pymbolic": "src", "pymbolic.mapper":"src/mapper"} ) diff --git a/src/__init__.py b/src/__init__.py index c660e67..9e5b200 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -3,7 +3,7 @@ import compiler import mapper.evaluator import mapper.stringifier -import mapper.constant_detector +import mapper.dependency import mapper.substitutor import mapper.differentiator @@ -11,7 +11,7 @@ parse = parser.parse evaluate = mapper.evaluator.evaluate compile = compiler.compile stringify = mapper.stringifier.stringify -is_constant = mapper.constant_detector.is_constant +is_constant = mapper.dependency.is_constant substitute = mapper.substitutor.substitute differentiate = mapper.differentiator.differentiate @@ -21,6 +21,7 @@ if __name__ == "__main__": print ex print evaluate(ex, {"alpha":5, "cos":math.cos, "x":-math.pi, "pi":math.pi}) + print hash(ex) print is_constant(ex) print substitute(ex, {"alpha": ex}) ex2 = parse("cos(x**2/x)") diff --git a/src/compiler.py b/src/compiler.py index 0ae13d5..d365d31 100644 --- a/src/compiler.py +++ b/src/compiler.py @@ -1,4 +1,5 @@ -import mapper.stringifier +import pymbolic.mapper.stringifier +import pymbolic.mapper.dependency @@ -16,14 +17,6 @@ class CompiledExpression: # FIXME used_variables = sets.Set() - def addVariable(var): - try: - var = self.VariableSubstitutions[var] - except: - pass - used_variables.add(var) - return var - pythonified = mapper.stringifier.stringify(self.Expression) used_variables = list(used_variables) diff --git a/src/mapper/__init__.py b/src/mapper/__init__.py index e69de29..65d485f 100644 --- a/src/mapper/__init__.py +++ b/src/mapper/__init__.py @@ -0,0 +1,77 @@ +class CombineMapper: + def combine(self, values): + raise NotImplementedError + + def map_call(self, expr): + return self.combine([expr.function.invoke_mapper(self)] + + [child.invoke_mapper(self) + for child in expr.parameters]) + + def map_subscript(self, expr): + return expr.__class__(expr.aggregate.invoke_mapper(self), + expr.index.invoke_mapper(self)) + + def map_negation(self, expr): + return expr.child.invoke_mapper(self) + + def map_sum(self, expr): + return self.combine(child.invoke_mapper(self) + for child in expr.children) + + def map_rational(self, expr): + return self.combine((expr.numerator.invoke_mapper(self), + expr.denominator.invoke_mapper(self))) + + map_product = map_sum + + def map_power(self, expr): + return self.combine((expr.base.invoke_mapper(self), + expr.exponent.invoke_mapper(self))) + + def map_polynomial(self, expr): + raise NotImplementedError + + map_list = map_sum + + + + + +class IdentityMapper: + def map_constant(self, expr): + return expr + + def map_variable(self, expr): + return expr + + def map_call(self, expr): + return expr.__class__(expr.function.invoke_mapper(self), + tuple(child.invoke_mapper(self) + for child in expr.parameters)) + + def map_subscript(self, expr): + return expr.__class__(expr.aggregate.invoke_mapper(self), + expr.index.invoke_mapper(self)) + + def map_negation(self, expr): + return expr.__class__(expr.child.invoke_mapper(self)) + + def map_sum(self, expr): + return expr.__class__(tuple(child.invoke_mapper(self) + for child in expr.children)) + + map_product = map_sum + + def map_rational(self, expr): + return expr.__class__(expr.numerator.invoke_mapper(self), + expr.denominator.invoke_mapper(self)) + + def map_power(self, expr): + return expr.__class__(expr.base.invoke_mapper(self), + expr.exponent.invoke_mapper(self)) + + def map_polynomial(self, expr): + raise NotImplementedError + + map_list = map_sum + diff --git a/src/mapper/constant_detector.py b/src/mapper/constant_detector.py deleted file mode 100644 index 28224d5..0000000 --- a/src/mapper/constant_detector.py +++ /dev/null @@ -1,27 +0,0 @@ -import mapper -import operator - - - - -class ConstantDetectionMapper(mapper.CombineMapper): - def __init__(self, with_respect_to=None): - self.WRT = with_respect_to - - def combine(self, values): - return reduce(operator.and_, values) - - def map_constant(self, expr): - return True - - def map_variable(self, expr): - if self.WRT: - return expr.Name not in self.WRT - else: - return False - - - - -def is_constant(expr, with_respect_to=None): - return expr.invoke_mapper(ConstantDetectionMapper(with_respect_to)) diff --git a/src/mapper/dependency.py b/src/mapper/dependency.py new file mode 100644 index 0000000..35ab523 --- /dev/null +++ b/src/mapper/dependency.py @@ -0,0 +1,32 @@ +import sets +import operator + +import pymbolic.mapper + + + + +class DependencyMapper(pymbolic.mapper.CombineMapper): + def combine(self, values): + return reduce(operator.or_, values) + + def map_constant(self, expr): + return sets.Set() + + def map_variable(self, expr): + return sets.Set([expr]) + + map_subscript = map_variable + + + + +def get_dependencies(expr): + return expr.invoke_mapper(DependencyMapper()) + + + + +def is_constant(expr, with_respect_to=None): + return sets.Set(with_respect_to) <= get_dependencies(expr) + diff --git a/src/mapper/differentiator.py b/src/mapper/differentiator.py index 1bf9dc6..6131e3b 100644 --- a/src/mapper/differentiator.py +++ b/src/mapper/differentiator.py @@ -1,16 +1,15 @@ -import constant_detector -import evaluator -import primitives -import mapper import math +import pymbolic +import pymbolic.primitives as primitives + def map_math_functions_by_name(i, func, pars): if not isinstance(func, primitives.Variable): raise RuntimeError, "No derivative of non-constant function "+str(func) - name = func.Name + name = func.name if name == "sin" and len(pars) == 1: return primitives.Constant("cos")(*pars) @@ -25,6 +24,9 @@ def map_math_functions_by_name(i, func, pars): else: return primitives.Constant(name+"'")(*pars) + + + class DifferentiationMapper: def __init__(self, variable, parameters, func_map): self.Variable = variable @@ -35,34 +37,36 @@ class DifferentiationMapper: return primitives.Constant(0) def map_variable(self, expr): - if expr.Name is self.Variable: + if expr.name == self.Variable: return primitives.Constant(1) - elif expr.Name in self.Parameters: + elif expr.name in self.parameters: return expr else: return primitives.Constant(0) def map_call(self, expr): return primitives.make_sum(tuple( - self.FunctionMap(i, expr.Function, expr.Parameters) + self.FunctionMap(i, expr.function, expr.parameters) * par.invoke_mapper(self) - for i, par in enumerate(expr.Parameters) + for i, par in enumerate(expr.parameters) if not self._isc(par))) + map_subscript = map_variable + + def map_neg(self, expr): + return -expr.child.invoke_mapper(self) + def map_sum(self, expr): return primitives.make_sum(tuple(child.invoke_mapper(self) for child in expr.Children if not self._isc(child))) - def map_neg(self, expr): - return -expr.Child.invoke_mapper(self) - def map_product(self, expr): return primitives.make_sum(tuple( - primitives.make_product(expr.Children[0:i] + + primitives.make_product(expr.children[0:i] + (child.invoke_mapper(self),) + - expr.Children[i+1:]) - for i, child in enumerate(expr.Children) + expr.children[i+1:]) + for i, child in enumerate(expr.children) if not self._isc(child))) def map_rational(self, expr): @@ -83,8 +87,8 @@ class DifferentiationMapper: return (f.invoke_mapper(self)*g-g.invoke_mapper(self)*f)/g**2 def map_power(self, expr): - f = expr.Child1 - g = expr.Child2 + f = expr.base + g = expr.exponent f_const = self._isc(f) g_const = self._isc(g) @@ -106,11 +110,11 @@ class DifferentiationMapper: raise NotImplementedError def _isc(self,subexp): - return constant_detector.is_constant(subexp, [self.Variable]) + return pymbolic.is_constant(subexp, [self.Variable]) def _eval(self,subexp): try: - return primitives.Constant(evaluator.evaluate(subexp)) + return primitives.Constant(pymbolic.evaluate(subexp)) except KeyError: return subexp diff --git a/src/mapper/evaluator.py b/src/mapper/evaluator.py index 068c7d4..13c59ab 100644 --- a/src/mapper/evaluator.py +++ b/src/mapper/evaluator.py @@ -3,28 +3,31 @@ class EvaluationMapper: self.Context = context def map_constant(self, expr): - return expr.Value + return expr.value def map_variable(self, expr): - return self.Context[expr.Name] + return self.Context[expr.name] def map_call(self, expr): - return expr.Function.invoke_mapper(self)( + return expr.function.invoke_mapper(self)( *[par.invoke_mapper(self) - for par in expr.Parameters]) + for par in expr.parameters]) - def map_sum(self, expr): - return sum(child.invoke_mapper(self) - for child in expr.Children) + def map_subscript(self, expr): + return expr.aggregate.invoke_mapper(self)[expr.index.invoke_mapper(self)] def map_negation(self, expr): - return -expr.Child.invoke_mapper(self) + return -expr.child.invoke_mapper(self) + + def map_sum(self, expr): + return sum(child.invoke_mapper(self) + for child in expr.children) def map_product(self, expr): - if len(expr.Children) == 0: + if len(expr.children) == 0: return 1 - result = expr.Children[0].invoke_mapper(self) - for child in expr.Children[1:]: + result = expr.children[0].invoke_mapper(self) + for child in expr.children[1:]: result *= child.invoke_mapper(self) return result @@ -32,16 +35,10 @@ class EvaluationMapper: return expr.numerator.invoke_mapper(self) / expr.denominator.invoke_mapper(self) def map_power(self, expr): - return expr.Child1.invoke_mapper(self) ** expr.Child2.invoke_mapper(self) + return expr.base.invoke_mapper(self) ** expr.exponent.invoke_mapper(self) def map_polynomial(self, expr): - if len(expr.Children) == 0: - return 0 - result = expr.Children[-1].invoke_mapper(self) - b_ev = expr.Base.invoke_mapper(self) - for child in expr.Children[-2::-1]: - result = result * b_ev + child.invoke_mapper(self) - return result + raise NotImplementedError def map_list(self, expr): return [child.invoke_mapper(self) for child in expr.Children] diff --git a/src/mapper/hash_generator.py b/src/mapper/hash_generator.py new file mode 100644 index 0000000..f7d1494 --- /dev/null +++ b/src/mapper/hash_generator.py @@ -0,0 +1,40 @@ +class HashMapper: + def map_constant(self, expr): + return 0x131 ^ hash(expr.value) + + def map_variable(self, expr): + return 0x111 ^ hash(expr.name) + + def map_call(self, expr): + return hash(expr.function) ^ hash(expr.parameters) + + def map_subscript(self, expr): + return 0x123 \ + ^ hash(expr.aggregate) \ + ^ hash(expr.index) + + def map_negation(self, expr): + return ~ hash(expr.child) + + def map_sum(self, expr): + return 0x456 ^ hash(expr.children) + + def map_product(self, expr): + return 0x789 ^ hash(expr.children) + + def map_rational(self, expr): + return 0xabc \ + ^ hash(expr.numerator) \ + ^ hash(expr.denominator) + + def map_power(self, expr): + return 0xdef \ + ^ hash(expr.base) \ + ^ hash(expr.exponent) + + def map_polynomial(self, expr): + raise NotImplementedError + + def map_product(self, expr): + return 0x124 ^ hash(expr.children) + diff --git a/src/mapper/mapper.py b/src/mapper/mapper.py deleted file mode 100644 index 7f17f05..0000000 --- a/src/mapper/mapper.py +++ /dev/null @@ -1,82 +0,0 @@ -class ByArityMapper: - def map_sum(self, expr): - return self.map_n_ary(expr) - - def map_product(self, expr): - return self.map_n_ary(expr) - - def map_negation(self, expr): - return self.map_unary(expr) - - def map_power(self, expr): - return self.map_binary(expr) - - def map_list(self, expr): - return self.map_n_ary(expr) - - - - -class CombineMapper(ByArityMapper): - def combine(self, values): - raise NotImplementedError - - def map_unary(self, expr): - return expr.Child.invoke_mapper(self) - - def map_binary(self, expr): - return self.combine((expr.Child1.invoke_mapper(self), - expr.Child2.invoke_mapper(self))) - - def map_rational(self, expr): - return self.combine((expr.numerator.invoke_mapper(self), - expr.denominator.invoke_mapper(self))) - - def map_n_ary(self, expr): - return self.combine(child.invoke_mapper(self) - for child in expr.Children) - - def map_polynomial(self, expr): - return self.combine([expr.Base.invoke_mapper(self)] + - [child.invoke_mapper(self) - for child in expr.Children]) - - def map_call(self, expr): - return self.combine([expr.Function.invoke_mapper(self)] + - [child.invoke_mapper(self) - for child in expr.Parameters]) - - - - -class IdentityMapper(ByArityMapper): - def map_unary(self, expr): - return expr.__class__(expr.Child.invoke_mapper(self)) - - def map_binary(self, expr): - return expr.__class__(expr.Child1.invoke_mapper(self), - expr.Child2.invoke_mapper(self)) - - def map_n_ary(self, expr): - return expr.__class__(tuple(child.invoke_mapper(self) - for child in expr.Children)) - - def map_quotient(self, expr): - return expr.__class__(expr.numerator.invoke_mapper(self), - expr.denominator.invoke_mapper(self)) - - def map_constant(self, expr): - return expr - - def map_variable(self, expr): - return expr - - def map_polynomial(self, expr): - return expr.__class__(expr.Base.invoke_mapper(self), - tuple(child.invoke_mapper(self) - for child in expr.Children)) - - def map_call(self, expr): - return expr.__class__(expr.Function.invoke_mapper(self), - tuple(child.invoke_mapper(self) - for child in expr.Parameters)) diff --git a/src/mapper/stringifier.py b/src/mapper/stringifier.py index d6ad121..0df6794 100644 --- a/src/mapper/stringifier.py +++ b/src/mapper/stringifier.py @@ -1,37 +1,42 @@ class StringifyMapper: def map_constant(self, expr): - return str(expr.Value) + return str(expr.value) def map_variable(self, expr): - return expr.Name + return expr.name def map_call(self, expr): return "%s(%s)" % \ - (expr.Function.invoke_mapper(self), - ", ".join(i.invoke_mapper(self) for i in expr.Parameters)) + (expr.function.invoke_mapper(self), + ", ".join(i.invoke_mapper(self) for i in expr.parameters)) - def map_sum(self, expr): - return "(%s)" % "+".join(i.invoke_mapper(self) for i in expr.Children) + def map_subscript(self, expr): + return "%s[%s]" % \ + (expr.aggregate.invoke_mapper(self), + expr.index.invoke_mapper(self)) def map_negation(self, expr): - return "-%s" % expr.Child.invoke_mapper(self) + return "-%s" % expr.child.invoke_mapper(self) + + def map_sum(self, expr): + return "(%s)" % "+".join(i.invoke_mapper(self) for i in expr.children) def map_product(self, expr): - return "(%s)" % "*".join(i.invoke_mapper(self) for i in expr.Children) + return "(%s)" % "*".join(i.invoke_mapper(self) for i in expr.children) def map_rational(self, expr): return "(%s/%s)" % (expr.numerator.invoke_mapper(self), expr.denominator.invoke_mapper(self)) def map_power(self, expr): - return "(%s**%s)" % (expr.Child1.invoke_mapper(self), - expr.Child2.invoke_mapper(self)) + return "(%s**%s)" % (expr.base.invoke_mapper(self), + expr.exponent.invoke_mapper(self)) def map_polynomial(self, expr): raise NotImplementedError def map_list(self, expr): - return "[%s]" % ", ".join([i.invoke_mapper(self) for i in expr.Children]) + return "[%s]" % ", ".join([i.invoke_mapper(self) for i in expr.children]) diff --git a/src/mapper/substitutor.py b/src/mapper/substitutor.py index c49b5cc..35dd447 100644 --- a/src/mapper/substitutor.py +++ b/src/mapper/substitutor.py @@ -7,9 +7,15 @@ class SubstitutionMapper(mapper.IdentityMapper): def __init__(self, variable_assignments): self.Assignments = variable_assignments + def map_subscript(self, expr): + try: + return self.Assignments[expr] + except KeyError: + return expr + def map_variable(self, expr): try: - return self.Assignments[expr.Name] + return self.Assignments[expr] except KeyError: return expr diff --git a/src/parser.py b/src/parser.py index 0e52790..163e00d 100644 --- a/src/parser.py +++ b/src/parser.py @@ -1,4 +1,4 @@ -import primitives +import pymbolic.primitives as primitives import pytools.lex _imaginary = intern("imaginary") @@ -97,13 +97,19 @@ def parse(expr_str): tuple(parse_expr_list(pstate))) pstate.expect(_closepar) did_something = True + elif next_tag is _openbracket and _PREC_CALL >= min_precedence: + pstate.advance() + pstate.expect_not_end() + left_exp = primitives.Subscript(left_exp, parse_expression(pstate)) + pstate.expect(_closebracket) + did_something = True elif next_tag is _plus and _PREC_PLUS >= min_precedence: pstate.advance() left_exp = parse_expression(pstate, _PREC_PLUS) did_something = True elif next_tag is _minus and _PREC_PLUS >= min_precedence: pstate.advance() - left_exp -= primitives.Negation(parse_expression(pstate, _PREC_PLUS)) + left_exp -= parse_expression(pstate, _PREC_PLUS) did_something = True elif next_tag is _times and _PREC_TIMES >= min_precedence: pstate.advance() diff --git a/src/polynomial.py b/src/polynomial.py index 574d662..52c6e87 100644 --- a/src/polynomial.py +++ b/src/polynomial.py @@ -48,6 +48,13 @@ class Polynomial(object): def __nonzero__(self): return len(self.Data) != 0 + def __eq__(self, other): + return isinstance(other, Polynomial) \ + and (self.Base == other.Base) \ + and (self.Data == other.Data) + def __ne__(self, other): + return not self.__eq__(other) + def __neg__(self): return Polynomial(self.Base, [(exp, -coeff) diff --git a/src/primitives.py b/src/primitives.py index ff8848b..6422186 100644 --- a/src/primitives.py +++ b/src/primitives.py @@ -1,10 +1,15 @@ +import traits import rational as rat -import mapper.stringifier +import pymbolic.mapper.stringifier +import pymbolic.mapper.hash_generator class Expression(object): + def __ne__(self, other): + return not self.__eq__(other) + def __add__(self, other): if not isinstance(other, Expression): other = Constant(other) @@ -97,142 +102,124 @@ class Expression(object): return Call(self, tuple(processed)) - def __str__(self): - return self.invoke_mapper(mapper.stringifier.StringifyMapper()) - -class UnaryExpression(Expression): - def __init__(self, child): - self.Child = child - def __hash__(self): - return ~hash(self.Child) - -class BinaryExpression(Expression): - def __init__(self, child1, child2): - self.Child1 = child1 - self.Child2 = child2 - - def __hash__(self): - return hash(self.Child1) ^ hash(self.Child2) - -class NAryExpression(Expression): - def __init__(self, children): - assert isinstance(children, tuple) - self.Children = children + try: + return self._HashValue + except AttributeError: + self._HashValue = self.invoke_mapper(pymbolic.mapper.hash_generator.HashMapper()) + return self._HashValue - def __hash__(self): - return hash(self.Children) + def __str__(self): + return self.invoke_mapper(pymbolic.mapper.stringifier.StringifyMapper()) class Constant(Expression): def __init__(self, value): - self.Value = value + self._Value = value - def __eq__(self, other): - return isinstance(other, Constant) and self.Value == other.Value + def _value(self): + return self._Value + value = property(_value) - def __ne__(self, other): - return not isinstance(other, Constant) or self.Value != other.Value + def __eq__(self, other): + return isinstance(other, Constant) and self._Value == other._Value def __add__(self, other): if not isinstance(other, Expression): - return Constant(self.Value + other) + return Constant(self._Value + other) if isinstance(other, Constant): - return Constant(self.Value + other.Value) - if self.Value == 0: + return Constant(self._Value + other._Value) + if self._Value == 0: return other return Expression.__add__(self, other) def __radd__(self, other): if not isinstance(other, Expression): - return Constant(other + self.Value) - if self.Value == 0: + return Constant(other + self._Value) + if self._Value == 0: return other return Expression.__radd__(self, other) def __sub__(self, other): if not isinstance(other, Expression): - return Constant(self.Value - other) + return Constant(self._Value - other) if isinstance(other, Constant): - return Constant(self.Value - other.Value) - if self.Value == 0: + return Constant(self._Value - other._Value) + if self._Value == 0: return Negation(other) return Expression.__sub__(self, other) def __rsub__(self, other): if not isinstance(other, Expression): - return Constant(other - self.Value) - if self.Value == 0: + return Constant(other - self._Value) + if self._Value == 0: return other return Expression.__rsub__(self, other) def __mul__(self, other): if not isinstance(other, Expression): - return Constant(self.Value * other) + return Constant(self._Value * other) if isinstance(other, Constant): - return Constant(self.Value * other.Value) - if self.Value == 1: + return Constant(self._Value * other._Value) + if self._Value == 1: return other - if self.Value == 0: + if self._Value == 0: return self return Expression.__mul__(self, other) def __rmul__(self, other): if not isinstance(other, Expression): - return Constant(other * self.Value) - if self.Value == 1: + return Constant(other * self._Value) + if self._Value == 1: return other - if self.Value == 0: + if self._Value == 0: return self return Expression.__rmul__(self, other) def __div__(self, other): if not isinstance(other, Expression): - return Constant(self.Value / other) + return Constant(self._Value / other) if isinstance(other, Constant): - return Constant(self.Value / other.Value) - if self.Value == 0: + return Constant(self._Value / other._Value) + if self._Value == 0: return self return Expression.__div__(self, other) def __rdiv__(self, other): if not isinstance(other, Expression): - return Constant(other / self.Value) - if self.Value == 1: + return Constant(other / self._Value) + if self._Value == 1: return other return Expression.__rdiv__(self, other) def __pow__(self, other): if not isinstance(other, Expression): - return Constant(self.Value ** other) + return Constant(self._Value ** other) if isinstance(other, Constant): - return Constant(self.Value ** other.Value) - if self.Value == 1: + return Constant(self._Value ** other._Value) + if self._Value == 1: return self return Expression.__pow__(self, other) def __rpow__(self, other): if not isinstance(other, Expression): - return Constant(other ** self.Value) - if self.Value == 0: + return Constant(other ** self._Value) + if self._Value == 0: return Constant(1) - if self.Value == 1: + if self._Value == 1: return other return Expression.__rpow__(self, other) def __neg__(self): - return Constant(-self.Value) + return Constant(-self._Value) def __call__(self, *pars): for par in pars: if isinstance(par, Expression): return Expression.__call__(self, *pars) - return self.Value(*pars) + return self._Value(*pars) def __nonzero__(self): - return bool(self.Value) - - def __hash__(self): - return hash(self.Value) + return bool(self._Value) def invoke_mapper(self, mapper): return mapper.map_constant(self) @@ -240,50 +227,110 @@ class Constant(Expression): class Variable(Expression): def __init__(self, name): - self.Name = name + self._Name = name - def __hash__(self): - return hash(self.Name) + def _name(self): + return self._Name + name = property(_name) + + def __eq__(self, other): + return isinstance(other, Variable) and self._Name == other._Name def invoke_mapper(self, mapper): return mapper.map_variable(self) class Call(Expression): - def __init__(self, func, parameters): - self.Function = func - self.Parameters = parameters + def __init__(self, function, parameters): + self._Function = function + self._Parameters = parameters + + def _function(self): + return self._Function + function = property(_function) + + def _parameters(self): + return self._Parameters + parameters = property(_parameters) + + def __eq__(self, other): + return isinstance(other, Call) \ + and (self._Function == other._Function) \ + and (self._Parameters == other._Parameters) def invoke_mapper(self, mapper): return mapper.map_call(self) - def __hash__(self): - return hash(self.Function) ^ hash(self.Parameters) +class Subscript(Expression): + def __init__(self, aggregate, index): + self._Aggregate = aggregate + self._Index = index + + def _aggregate(self): + return self._Aggregate + aggregate = property(_aggregate) + + def _index(self): + return self._Index + index = property(_index) + + def __eq__(self, other): + return isinstance(other, Subscript) \ + and (self._Aggregate == other._Aggregate) \ + and (self._Index == other._Index) + + def invoke_mapper(self, mapper): + return mapper.map_subscript(self) + +class Negation(Expression): + def __init__(self, child): + self._Child = child + + def _child(self): + return self._Child + child = property(_child) + + def __eq__(self, other): + return isinstance(other, Negation) and (self.Child == other.Child) + + def invoke_mapper(self, mapper): + return mapper.map_negation(self) + +class Sum(Expression): + def __init__(self, children): + assert isinstance(children, tuple) + self._Children = children + + def _children(self): + return self._Children + children = property(_children) + + def __eq__(self, other): + return isinstance(other, Sum) and (self._Children == other._Children) -class Sum(NAryExpression): def __add__(self, other): if not isinstance(other, Expression): other = Constant(other) elif isinstance(other, Sum): - return Sum(self.Children + other.Children) - return Sum(self.Children + (other,)) + return Sum(self._Children + other._Children) + return Sum(self._Children + (other,)) def __radd__(self, other): if not isinstance(other, Expression): other = Constant(other) elif isinstance(other, Sum): - return Sum(other.Children + self.Children) - return Sum(other, + self.Children) + return Sum(other._Children + self._Children) + return Sum(other, + self._Children) def __sub__(self, other): if not isinstance(other, Expression): other = Constant(other) - return Sum(self.Children + (-other,)) + return Sum(self._Children + (-other,)) def __nonzero__(self): - if len(self.Children) == 0: + if len(self._Children) == 0: return True - elif len(self.Children) == 1: - return bool(self.Children[0]) + elif len(self._Children) == 1: + return bool(self._Children[0]) else: # FIXME: Right semantics? return True @@ -291,27 +338,34 @@ class Sum(NAryExpression): def invoke_mapper(self, mapper): return mapper.map_sum(self) -class Negation(UnaryExpression): - def invoke_mapper(self, mapper): - return mapper.map_negation(self) +class Product(Expression): + def __init__(self, children): + assert isinstance(children, tuple) + self._Children = children + + def _children(self): + return self._Children + children = property(_children) + + def __eq__(self, other): + return isinstance(other, Product) and (self._Children == other._Children) -class Product(NAryExpression): def __mul__(self, other): if not isinstance(other, Expression): other = Constant(other) elif isinstance(other, Product): - return Product(self.Children + other.Children) - return Product(self.Children + (other,)) + return Product(self._Children + other._Children) + return Product(self._Children + (other,)) def __rmul__(self, other): if not isinstance(other, Expression): other = Constant(other) elif isinstance(other, Product): - return Product(other.Children + self.Children) - return Product(other, + self.Children) + return Product(other._Children + self._Children) + return Product(other, + self._Children) def __nonzero__(self): - for i in self.Children: + for i in self._Children: if not i: return False return True @@ -332,55 +386,87 @@ class QuotientExpression(Expression): return self.Denominator denominator=property(_den) + def __eq__(self, other): + return isinstance(other, Subscript) \ + and (self.Numerator == other.Numerator) \ + and (self.Denominator == other.Denominator) + def __nonzero__(self): return bool(self.Numerator) - def __hash__(self): - return 0xa0a0aa ^ hash(self.Numerator) ^ hash(self.Denominator) - def invoke_mapper(self, mapper): return mapper.map_rational(self) class RationalExpression(Expression): - def __init__(self, numerator, denominator=1): - self.Numerator = numerator - self.Denominator = denominator + def __init__(self, rational): + self.Rational = rational def _num(self): - return self.Numerator + return self.Rational.numerator numerator=property(_num) def _den(self): - return self.Denominator + return self.Rational.denominator denominator=property(_den) def __nonzero__(self): - return bool(self.Numerator) - - def __hash__(self): - return 0xa0a0aa ^ hash(self.Numerator) ^ hash(self.Denominator) + return bool(self.Rational) def invoke_mapper(self, mapper): return mapper.map_rational(self) -class Power(BinaryExpression): +class Power(Expression): + def __init__(self, base, exponent): + self._Base = base + self._Exponent = exponent + + def _base(self): + return self._Base + base = property(_base) + + def _exponent(self): + return self._Exponent + exponent = property(_exponent) + + def __eq__(self, other): + return isinstance(other, Power) \ + and (self._Base == other._Base) \ + and (self._Exponent == other._Exponent) + def invoke_mapper(self, mapper): return mapper.map_power(self) class PolynomialExpression(Expression): def __init__(self, base=None, data=None, polynomial=None): if polynomial: - self.Polynomial = polynomial + self._Polynomial = polynomial else: - self.Polynomial = polynomial.Polynomial(base, data) + self._Polynomial = polynomial.Polynomial(base, data) + + def _polynomial(self): + return self._Polynomial + polynomial = property(_polynomial) - def __hash__(self): - return hash(self.Polynomial) + def __eq__(self, other): + return isinstance(other, PolynomialExpression) \ + and (self._Polynomial == other._Polynomial) def invoke_mapper(self, mapper): return mapper.map_polynomial(self) -class List(NAryExpression): +class List(Expression): + def __init__(self, children): + assert isinstance(children, tuple) + self._Children = children + + def _children(self): + return self.Children + children = property(_children) + + def __eq__(self, other): + return isinstance(other, List) \ + and (self.Children == other.Children) + def invoke_mapper(self, mapper): return mapper.map_list(self) @@ -410,7 +496,8 @@ def polynomial_from_expression(expression): def make_quotient(numerator, denominator): try: - if isinstance(traits.traits(numerator, denominator), EuclideanRingTraits): + if isinstance(traits.common_traits(numerator, denominator), + EuclideanRingTraits): return RationalExpression(numerator, denominator) except traits.NoCommonTraitsError: pass diff --git a/src/rational.py b/src/rational.py index edb526b..1723571 100644 --- a/src/rational.py +++ b/src/rational.py @@ -25,6 +25,13 @@ class Rational(object): def __neg__(self): return Rational(-self.Numerator, self.Denominator) + def __eq__(self): + if not isinstance(other, Rational): + other = Rational(other) + + return self.Numerator == other.Numerator and \ + self.Denominator == other.Denominator + def __add__(self, other): if not isinstance(other, Rational): other = Rational(other) -- GitLab