diff --git a/src/mapper/differentiator.py b/src/mapper/differentiator.py index b377dce4251e4ab049eb254cd3132f79fc5e4d92..1bf9dc639485c93facd09dfa6d2bc6ca6f8bf5a2 100644 --- a/src/mapper/differentiator.py +++ b/src/mapper/differentiator.py @@ -65,9 +65,9 @@ class DifferentiationMapper: for i, child in enumerate(expr.Children) if not self._isc(child))) - def map_quotient(self, expr): - f = expr.Child1 - g = expr.Child2 + def map_rational(self, expr): + f = expr.numerator + g = expr.denominator f_const = self._isc(f) g_const = self._isc(g) diff --git a/src/mapper/mapper.py b/src/mapper/mapper.py index 3175a725a3b6078bb5fda60948e267a56e2b5101..7f17f05cb9193c6f21c91f2f17acb89a0fb4cdb2 100644 --- a/src/mapper/mapper.py +++ b/src/mapper/mapper.py @@ -37,14 +37,14 @@ class CombineMapper(ByArityMapper): for child in expr.Children) def map_polynomial(self, expr): - return self.combine((expr.Base.invoke_mapper(self)) - + (child.invoke_mapper(self) - for child in expr.Children)) + return self.combine([expr.Base.invoke_mapper(self)] + + [child.invoke_mapper(self) + for child in expr.Children]) def map_call(self, expr): - return self.combine((expr.Function.invoke_mapper(self)) - + (child.invoke_mapper(self) - for child in expr.Parameters)) + return self.combine([expr.Function.invoke_mapper(self)] + + [child.invoke_mapper(self) + for child in expr.Parameters]) diff --git a/src/mapper/stringifier.py b/src/mapper/stringifier.py index 71c9283ecad572f897c6e1b0e5ae48fb755193a7..d6ad121a2b0522661d174aee94f50c6f15a9f47a 100644 --- a/src/mapper/stringifier.py +++ b/src/mapper/stringifier.py @@ -19,9 +19,9 @@ class StringifyMapper: def map_product(self, expr): return "(%s)" % "*".join(i.invoke_mapper(self) for i in expr.Children) - def map_quotient(self, expr): - return "(%s/%s)" % (expr.Child1.invoke_mapper(self), - expr.Child2.invoke_mapper(self)) + def map_rational(self, expr): + return "(%s/%s)" % (expr.numerator.invoke_mapper(self), + expr.denominator.invoke_mapper(self)) def map_power(self, expr): return "(%s**%s)" % (expr.Child1.invoke_mapper(self), diff --git a/src/primitives.py b/src/primitives.py index 613f36ca59d6f7f389caacf52ea93b4fad88be97..ff8848b1eaf3671a140384f3f52fae5c97c4949f 100644 --- a/src/primitives.py +++ b/src/primitives.py @@ -1,3 +1,4 @@ +import rational as rat import mapper.stringifier @@ -60,11 +61,11 @@ class Expression(object): other = Constant(other) if not (other-1): return self - return RationalExpression(self, other) + return make_quotient(self, other) def __rdiv__(self, other): assert not isinstance(other, Expression) - return RationalExpression(Constant(other), self) + return make_quotient(Constant(other), self) def __pow__(self, other): if not isinstance(other, Expression): @@ -285,7 +286,7 @@ class Sum(NAryExpression): return bool(self.Children[0]) else: # FIXME: Right semantics? - return False + return True def invoke_mapper(self, mapper): return mapper.map_sum(self) @@ -318,30 +319,49 @@ class Product(NAryExpression): def invoke_mapper(self, mapper): return mapper.map_product(self) -class RationalExpression(Expression): - def __init__(self, numerator=None, denominator=1, rational=None): - if rational: - self.Rational = rational - else: - self.Rational = rat.Rational(numerator, denominator) +class QuotientExpression(Expression): + def __init__(self, numerator, denominator=1): + self.Numerator = numerator + self.Denominator = denominator def _num(self): - return self.Rational.numerator + return self.Numerator numerator=property(_num) def _den(self): - return self.Rational.denominator + return self.Denominator denominator=property(_den) def __nonzero__(self): - return bool(self.Rational) + return bool(self.Numerator) def __hash__(self): - return hash(self.Rational) + return 0xa0a0aa ^ hash(self.Numerator) ^ hash(self.Denominator) def invoke_mapper(self, mapper): return mapper.map_rational(self) +class RationalExpression(Expression): + def __init__(self, numerator, denominator=1): + self.Numerator = numerator + self.Denominator = denominator + + def _num(self): + return self.Numerator + numerator=property(_num) + + def _den(self): + return self.Denominator + denominator=property(_den) + + def __nonzero__(self): + return bool(self.Numerator) + + def __hash__(self): + return 0xa0a0aa ^ hash(self.Numerator) ^ hash(self.Denominator) + + def invoke_mapper(self, mapper): + return mapper.map_rational(self) class Power(BinaryExpression): def invoke_mapper(self, mapper): @@ -388,3 +408,15 @@ def make_product(components): def polynomial_from_expression(expression): pass +def make_quotient(numerator, denominator): + try: + if isinstance(traits.traits(numerator, denominator), EuclideanRingTraits): + return RationalExpression(numerator, denominator) + except traits.NoCommonTraitsError: + pass + except traits.NoTraitsError: + pass + + return QuotientExpression(numerator, denominator) + +# FIXME: add traits types to expressions diff --git a/src/traits.py b/src/traits.py index 60d910c464a71ab8ee28a3db888e16fbc0997492..b77c972c9e6e545f0cde2ed1e964b7797bdeff1d 100644 --- a/src/traits.py +++ b/src/traits.py @@ -3,13 +3,22 @@ import algorithm +class NoTraitsError(Exception): + pass + +class NoCommonTraitsError(Exception): + pass + + + + def traits(x): try: return x.traits() except AttributeError: if isinstance(x, (complex, float)): return FieldTraits if isinstance(x, int): return IntegerTraits - raise NotImplementedError + raise NoTraitsError @@ -21,7 +30,8 @@ def common_traits(*args): elif isinstance(t_x, t_y.__class__): return t_x else: - raise RuntimeError, "No common traits type between '%s' and '%s'" % \ + raise NoCommonTraitsError, \ + "No common traits type between '%s' and '%s'" % \ (t_x.__class__.__name__, t_y.__class__.__name__) return reduce(common_traits_two, (traits(arg) for arg in args))