diff --git a/src/__init__.py b/src/__init__.py index e6f0c4aeddbb6924e9d2147dd3b18c791584f2b4..5ded8accd3eeaedd1f138683935e8587dddf73cf 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -13,12 +13,12 @@ const = pymbolic.primitives.Constant sum = pymbolic.primitives.sum subscript = pymbolic.primitives.subscript product = pymbolic.primitives.product +quotient = pymbolic.primitives.quotient linear_combination = pymbolic.primitives.linear_combination parse = pymbolic.parser.parse evaluate = pymbolic.mapper.evaluator.evaluate compile = pymbolic.compiler.compile -stringify = pymbolic.mapper.stringifier.stringify is_constant = pymbolic.mapper.dependency.is_constant get_dependencies = pymbolic.mapper.dependency.get_dependencies substitute = pymbolic.mapper.substitutor.substitute diff --git a/src/mapper/__init__.py b/src/mapper/__init__.py index 7f3fb4821ce2fe5c755882df5cd84aa95bb66f05..be8e3c3dd7f5c301169b87c117b5aad174345d24 100644 --- a/src/mapper/__init__.py +++ b/src/mapper/__init__.py @@ -1,3 +1,5 @@ +class Mapper: + def __call__(self, *args, **kwargs): class CombineMapper: def combine(self, values): raise NotImplementedError diff --git a/src/mapper/stringifier.py b/src/mapper/stringifier.py index 32d17d36a8250acd194bc3a629b5523394f77f44..586a6e4756be16e3e29d37e6d3c23077ece2d1de 100644 --- a/src/mapper/stringifier.py +++ b/src/mapper/stringifier.py @@ -64,7 +64,7 @@ class StringifyMapper: else: return result - def map_rational(self, expr, enclosing_prec): + def map_quotient(self, expr, enclosing_prec): result = "%s/%s" % ( expr.numerator.invoke_mapper(self, PREC_PRODUCT), expr.denominator.invoke_mapper(self, PREC_PRODUCT) @@ -73,6 +73,7 @@ class StringifyMapper: return "(%s)" % result else: return result + map_rational = map_quotient def map_power(self, expr, enclosing_prec): result = "%s**%s" % ( @@ -89,9 +90,3 @@ class StringifyMapper: def map_list(self, expr, enclosing_prec): return "[%s]" % ", ".join([i.invoke_mapper(self) for i in expr.children]) - - - - -def stringify(expression): - return expression.invoke_mapper(StringifyMapper(), PREC_NONE) diff --git a/src/polynomial.py b/src/polynomial.py index 52c6e8763b9f8f42643e01cd58cc5b9d0102cbd5..e9e9d8ab4b6f80bb702ed333a3a496a6f438474f 100644 --- a/src/polynomial.py +++ b/src/polynomial.py @@ -1,4 +1,6 @@ from __future__ import division +import pymbolic +import pymbolic.primitives as primitives import pymbolic.algorithm as algorithm import pymbolic.traits as traits @@ -28,14 +30,18 @@ def _sort_uniq(data): -class Polynomial(object): - def __init__(self, base, data = ((1,1),)): +class Polynomial(primitives.Expression): + def __init__(self, base, data=None, unit=1): self.Base = base + self.Unit = unit # list of (exponent, coefficient tuples) # sorted in increasing order # one entry per degree - self.Data = data + if data is None: + self.Data = ((1, unit),) + else: + self.Data = tuple(data) # Remember the Zen, Luke: Sparse is better than dense. @@ -189,6 +195,10 @@ class Polynomial(object): return self.Base base = property(_base) + def _unit(self): + return self.Unit + unit = property(_unit) + def _degree(self): try: return self.Data[-1][0] @@ -204,10 +214,13 @@ class Polynomial(object): def __hash__(self): return hash(self.Base) ^ hash(self.Children) + def invoke_mapper(self, mapper, *args, **kwargs): + return mapper.map_polynomial(self, *args, **kwargs) -def derivative(poly): + +def differentiate(poly): return Polynomial( poly.base, tuple((exp-1, exp*coeff) @@ -216,6 +229,15 @@ def derivative(poly): +def integrate(poly): + return Polynomial( + poly.base, + tuple((exp+1, pymbolic.quotient(poly.unit, (exp+1))*coeff) + for exp, coeff in poly.data)) + + + + def leading_coefficient(poly): return poly.data[-1][1] @@ -236,19 +258,26 @@ class PolynomialTraits(traits.EuclideanRingTraits): if __name__ == "__main__": - x = Polynomial("x") - y = Polynomial("y") - xpoly = x**2 + 1 + + x = Polynomial(pymbolic.var("x"), unit=pymbolic.const(1)) + y = Polynomial(pymbolic.var("y"), unit=pymbolic.const(1)) + xpoly = x**2 + pymbolic.const(1) ypoly = -y**2*xpoly + xpoly - print xpoly - print ypoly - u = xpoly*ypoly - print u - print u**18 - print - - print 3*xpoly**3 + 1 - print xpoly - q,r = divmod(3*xpoly**3 + 1, xpoly) - print q, r + #print xpoly + #print ypoly + #u = xpoly*ypoly + #print u + #print u**3 + #print + + xp3 = xpoly**3 + print xp3 + print integrate(xp3) + + #print 3*xpoly**3 + 1 + #print xpoly + #q,r = divmod(3*xpoly**3 + 1, xpoly) + #print q, r + + diff --git a/src/primitives.py b/src/primitives.py index 197e1b3a675d8e6c39a9093acaf961b0d417fdbc..56fdc00e9140056d6d11958ee9b2d0852661dbae 100644 --- a/src/primitives.py +++ b/src/primitives.py @@ -1,5 +1,4 @@ import traits -import rational as rat import pymbolic.mapper.stringifier import pymbolic.mapper.hash_generator @@ -11,8 +10,6 @@ class Expression(object): return not self.__eq__(other) def __add__(self, other): - if not isinstance(other, Expression): - other = Constant(other) if not other: return self return Sum(self, other) @@ -22,13 +19,10 @@ class Expression(object): if not other: return self - else: - other = Constant(other) + return Sum(other, self) def __sub__(self, other): - if not isinstance(other, Expression): - other = Constant(other) if not other: return self return Sum(self, -other) @@ -38,19 +32,16 @@ class Expression(object): if not other: return Negation(self) - else: - other = Constant(other) + return Sum(other, -self) def __mul__(self, other): - if not isinstance(other, Expression): - other = Constant(other) if not (other - 1): return self elif not (other+1): return Negation(self) elif not other: - return Constant(0) + return 0 return Product(self, other) def __rmul__(self, other): @@ -59,23 +50,21 @@ class Expression(object): if not (other-1): return self elif not other: - return Constant(0) + return 0 elif not (other+1): return Negation(self) else: - return Product(Constant(other), self) + return Product(other, self) def __div__(self, other): - if not isinstance(other, Expression): - other = Constant(other) if not (other-1): return self - return make_quotient(self, other) + return quotient(self, other) __truediv__ = __div__ def __rdiv__(self, other): assert not isinstance(other, Expression) - return make_quotient(Constant(other), self) + return quotient(Constant(other), self) def __pow__(self, other): if not isinstance(other, Expression): @@ -115,133 +104,12 @@ class Expression(object): return self._HashValue def __str__(self): - return pymbolic.stringify(self) + from pymbolic.mapper.stringifier import StringifyMapper, PREC_NONE + return self.invoke_mapper(StringifyMapper(), PREC_NONE) def __repr__(self): return "%s%s" % (self.__class__.__name__, repr(self.__getinitargs__())) -class Constant(Expression): - def __init__(self, value): - self._Value = value - - def __getinitargs__(self): - return self._Value, - - def _value(self): - return self._Value - value = property(_value) - - def __lt__(self, other): - if isinstance(other, Constant): - return self._Value.__lt__(other._Value) - else: - return NotImplemented - - def __eq__(self, other): - return isinstance(other, Constant) and self._Value == other._Value - - def __add__(self, other): - if not isinstance(other, Expression): - return Constant(self._Value + other) - if isinstance(other, Constant): - return Constant(self._Value + other._Value) - if self._Value == 0: - return other - return Expression.__add__(self, other) - - def __radd__(self, other): - if not isinstance(other, Expression): - return Constant(other + self._Value) - if self._Value == 0: - return other - return Expression.__radd__(self, other) - - def __sub__(self, other): - if not isinstance(other, Expression): - return Constant(self._Value - other) - if isinstance(other, Constant): - return Constant(self._Value - other._Value) - if self._Value == 0: - return Negation(other) - return Expression.__sub__(self, other) - - def __rsub__(self, other): - if not isinstance(other, Expression): - return Constant(other - self._Value) - if self._Value == 0: - return other - return Expression.__rsub__(self, other) - - def __mul__(self, other): - if not isinstance(other, Expression): - return Constant(self._Value * other) - if isinstance(other, Constant): - return Constant(self._Value * other._Value) - if self._Value == 1: - return other - if self._Value == 0: - return self - return Expression.__mul__(self, other) - - def __rmul__(self, other): - if not isinstance(other, Expression): - return Constant(other * self._Value) - if self._Value == 1: - return other - if self._Value == 0: - return self - return Expression.__rmul__(self, other) - - def __div__(self, other): - if not isinstance(other, Expression): - return Constant(self._Value / other) - if isinstance(other, Constant): - return Constant(self._Value / other._Value) - if self._Value == 0: - return self - return Expression.__div__(self, other) - - def __rdiv__(self, other): - if not isinstance(other, Expression): - return Constant(other / self._Value) - if self._Value == 1: - return other - return Expression.__rdiv__(self, other) - - def __pow__(self, other): - if not isinstance(other, Expression): - return Constant(self._Value ** other) - if isinstance(other, Constant): - return Constant(self._Value ** other._Value) - if self._Value == 1: - return self - return Expression.__pow__(self, other) - - def __rpow__(self, other): - if not isinstance(other, Expression): - return Constant(other ** self._Value) - if self._Value == 0: - return Constant(1) - if self._Value == 1: - return other - return Expression.__rpow__(self, other) - - def __neg__(self): - return Constant(-self._Value) - - def __call__(self, *pars): - for par in pars: - if isinstance(par, Expression): - return Expression.__call__(self, *pars) - return self._Value(*pars) - - def __nonzero__(self): - return bool(self._Value) - - def invoke_mapper(self, mapper, *args, **kwargs): - return mapper.map_constant(self, *args, **kwargs) - - class Variable(Expression): def __init__(self, name): self._Name = name @@ -330,7 +198,7 @@ class ElementLookup(Expression): name = property(_name) def __eq__(self, other): - return isinstance(other, Subscript) \ + return isinstance(other, ElementLookup) \ and (self._Aggregate == other._Aggregate) \ and (self._Name == other._Name) @@ -467,7 +335,7 @@ class Quotient(Expression): denominator=property(_den) def __eq__(self, other): - return isinstance(other, Subscript) \ + return isinstance(other, Quotient) \ and (self._Numerator == other._Numerator) \ and (self._Denominator == other._Denominator) @@ -475,25 +343,7 @@ class Quotient(Expression): return bool(self._Numerator) def invoke_mapper(self, mapper, *args, **kwargs): - return mapper.map_rational(self, *args, **kwargs) - -class RationalExpression(Expression): - def __init__(self, rational): - self.Rational = rational - - def _num(self): - return self.Rational.numerator - numerator=property(_num) - - def _den(self): - return self.Rational.denominator - denominator=property(_den) - - def __nonzero__(self): - return bool(self.Rational) - - def invoke_mapper(self, mapper, *args, **kwargs): - return mapper.map_rational(self, *args, **kwargs) + return mapper.map_quotient(self, *args, **kwargs) class Power(Expression): def __init__(self, base, exponent): @@ -618,11 +468,17 @@ def polynomial_from_expression(expression): -def make_quotient(numerator, denominator): +def quotient(numerator, denominator): + if not isinstance(numerator, Expression): + numerator = Constant(numerator) + if not isinstance(denominator, Expression): + denominator = Constant(denominator) + try: + import pymbolic.rational as rat if isinstance(traits.common_traits(numerator, denominator), - EuclideanRingTraits): - return RationalExpression(numerator, denominator) + traits.EuclideanRingTraits): + return rat.Rational(numerator, denominator) except traits.NoCommonTraitsError: pass except traits.NoTraitsError: diff --git a/src/rational.py b/src/rational.py index 1723571d994d5bfe2e5181b0bb21f467410dc608..6b304bb8e0d81a010e7d7057a536a27d7d9feae7 100644 --- a/src/rational.py +++ b/src/rational.py @@ -1,9 +1,10 @@ +import pymbolic.primitives as prm import pymbolic.traits as traits -class Rational(object): +class Rational(prm.Expression): def __init__(self, numerator, denominator=1): d_unit = traits.traits(denominator).get_unit(denominator) numerator /= d_unit @@ -34,55 +35,42 @@ class Rational(object): def __add__(self, other): if not isinstance(other, Rational): - other = Rational(other) - - t = traits.common_traits(self.Denominator, other.Denominator) - newden = t.lcm(self.Denominator, other.Denominator) - newnum = self.Numerator * newden/self.Denominator + \ - other.Numerator * newden/other.Denominator - gcd = t.gcd(newden, newnum) - return Rational(newnum/gcd, newden/gcd) + newother = Rational(other) - def __radd__(self, other): - if not isinstance(other, Rational): - other = Rational(other) + try: + t = traits.common_traits(self.Denominator, newother.Denominator) + newden = t.lcm(self.Denominator, newother.Denominator) + newnum = self.Numerator * newden/self.Denominator + \ + newother.Numerator * newden/newother.Denominator + gcd = t.gcd(newden, newnum) + return Rational(newnum/gcd, newden/gcd) + except traits.NoCommonTraitsError: + return prm.Expression.__add__(self, other) - t = traits.common_traits(self.Denominator, other.Denominator) - newden = t.lcm(self.Denominator, other.Denominator) - newnum = other.Numerator * newden/other.Denominator + \ - self.Numerator * newden/self.Denominator - gcd = t.gcd(newden, newnum) - return Rational(newnum/gcd, newden/gcd) + __radd__ = __add__ def __sub__(self, other): - return self.__add__(other.__neg__()) + return self.__add__(-other) def __rsub__(self, other): - return self.__neg__().__radd__(other) + return (-self).__radd__(other) def __mul__(self, other): if not isinstance(other, Rational): other = Rational(other) - t = traits.common_traits(self.Numerator, other.Numerator, - self.Denominator, other. Denominator) - gcd_1 = t.gcd(self.Numerator, other.Denominator) - gcd_2 = t.gcd(other.Numerator, self.Denominator) - - return Rational(self.Numerator/gcd_1 * other.Numerator/gcd_2, - self.Denominator/gcd_2 * other.Denominator/gcd_1) - - def __rmul__(self, other): - if not isinstance(other, Rational): - other = Rational(other) + try: + t = traits.common_traits(self.Numerator, other.Numerator, + self.Denominator, other. Denominator) + gcd_1 = t.gcd(self.Numerator, other.Denominator) + gcd_2 = t.gcd(other.Numerator, self.Denominator) - t = traits.common_traits(self.Numerator, other.Numerator, - self.Denominator, other. Denominator) - gcd_1 = t.gcd(self.Numerator, other.Denominator) - gcd_2 = t.gcd(other.Numerator, self.Denominator) + return Rational(self.Numerator/gcd_1 * other.Numerator/gcd_2, + self.Denominator/gcd_2 * other.Denominator/gcd_1) + except traits.NoCommonTraitsError: + return prm.Expression.__mul__(self, other) - return Rational(other.Numerator/gcd_2 * self.Numerator/gcd_1, - other.Denominator/gcd_1 * self.Denominator/gcd_2) + __rmul__ = __mul__ def __div__(self, other): if not isinstance(other, Rational): @@ -99,19 +87,22 @@ class Rational(object): def __pow__(self, other): return Rational(self.Denominator**other, self.Numerator**other) - def __str__(self): - return "%s/%s" % (str(self.Numerator), str(self.Denominator)) - - def __repr__(self): - return "%s(%s, %s)" % (self.__class__.__name__, - repr(self.Numerator), repr(self.Denominator)) - def __float__(self): return float(self.Numerator) / flaot(self.Denominator) def __hash__(self): return 0xcafe ^ hash(self.Numerator) ^ hash(self.Denominator) + def invoke_mapper(self, mapper, *args, **kwargs): + return mapper.map_rational(self, *args, **kwargs) + + def __str__(self): + if isinstance(self.Numerator, primitives.Expression): + return primitives.Expression.__str__(self) + else: + return "%s/%s" % (str(self.Numerator), str(self.Denominator)) + +