From 0aea67b6248de24bf5c20988f84fdf8ba68e066d Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Fri, 10 Jun 2005 16:28:28 +0000 Subject: [PATCH] [pymbolic @ Arch-1:inform@tiker.net--iam-2005%pymbolic--mainline--1.0--patch-4] Primitives fixed. --- src/primitives.py | 175 ++++++++++++++++++++++++++++++++++------------ 1 file changed, 131 insertions(+), 44 deletions(-) diff --git a/src/primitives.py b/src/primitives.py index d5511ce..15d025f 100644 --- a/src/primitives.py +++ b/src/primitives.py @@ -1,3 +1,4 @@ +import tests import stringifier @@ -7,44 +8,53 @@ class Expression: def __add__(self, other): if not isinstance(other, Expression): other = Constant(other) - if isinstance(other, Constant) and other.Value == 0: + if tests.is_zero(other): return self return Sum((self, other)) def __radd__(self, other): - if not isinstance(other, Expression): - other = Constant(other) - if isinstance(other, Constant) and other.Value == 0: + assert not isinstance(other, Expression) + + if other == 0: return self + else: + other = Constant(other) return Sum((other, self)) def __sub__(self, other): if not isinstance(other, Expression): - other = Constant(other) - if isinstance(other, Constant) and other.Value == 0: + other = Constant(other) + if tests.is_zero(other): return self return Sum((self, -other)) def __rsub__(self, other): - if not isinstance(other, Expression): - other = Constant(other) - if isinstance(other, Constant) and other.Value == 0: + assert not isinstance(other, Expression) + + if other == 0: return Negation(self) + else: + other = Constant(other) return Sum((other, -self)) def __mul__(self, other): if not isinstance(other, Expression): other = Constant(other) - if isinstance(other, Constant) and other.Value == 1: + if test.is_one(other): return self + if tests.is_zero(other): + return Constant(0) return Product((self, other)) def __rmul__(self, other): - if not isinstance(other, Expression): - other = Constant(other) - if isinstance(other, Constant) and other.Value == 1: + assert not isinstance(other, Expression) + + if other == 1: return self - return Product((other, self)) + elif other == 0: + return Constant(0) + else: + return Product((other, self)) def __div__(self, other): if not isinstance(other, Expression): @@ -54,23 +64,24 @@ class Expression: return Quotient(self, other) def __rdiv__(self, other): - if not isinstance(other, Expression): - other = Constant(other) - return Quotient(other, self) + assert not isinstance(other, Expression) + return Quotient(Constant(other), self) def __pow__(self, other): if not isinstance(other, Expression): other = Constant(other) - if isinstance(other, Constant): - if other.Value == 0: - return Constant(1) - elif other.Value == 1: - return self + if tests.is_zero(other): # exponent zero + return Constant(1) + elif tests.is_one(other): # exponent one + return self return Power(self, other) def __rpow__(self, other): - if not isinstance(other, Expression): - other = Constant(other) + assert not isinstance(other, Expression) + if tests.is_zero(other): # base zero + return Constant(0) + elif tests.is_one(other): # base one + return Constant(1) return Power(other, self) def __neg__(self): @@ -116,85 +127,105 @@ class Constant(Expression): def __init__(self, value): self.Value = value + def __eq__(self, other): + return isinstance(other, Constant) and self.Value == other.Value + + def __ne__(self, other): + return not isinstance(other, Constant) or self.Value != other.Value + def __add__(self, other): if not isinstance(other, Expression): return Constant(self.Value + other) + if isinstance(other, Constant): + return Constant(self.Value + other.Value) if self.Value == 0: return other - return Sum((self, other)) + return Expression.__add__(self, other) def __radd__(self, other): if not isinstance(other, Expression): return Constant(other + self.Value) if self.Value == 0: return other - return Sum((other, self)) + return Expression.__radd__(self, other) def __sub__(self, other): if not isinstance(other, Expression): return Constant(self.Value - other) + if isinstance(other, Constant): + return Constant(self.Value - other.Value) if self.Value == 0: return Negation(other) - return Sum((self, -other)) + return Expression.__sub__(self, other) def __rsub__(self, other): if not isinstance(other, Expression): return Constant(other - self.Value) if self.Value == 0: return other - return Sum((other, -self)) + return Expression.__rsub__(self, other) def __mul__(self, other): if not isinstance(other, Expression): return Constant(self.Value * other) + if isinstance(other, Constant): + return Constant(self.Value * other.Value) if self.Value == 1: return other - return Product((self, other)) + if self.Value == 0: + return self + return Expression.__mul__(self, other) def __rmul__(self, other): if not isinstance(other, Expression): return Constant(other * self.Value) if self.Value == 1: return other - return Product((other, self)) + if self.Value == 0: + return self + return Expression.__rmul__(self, other) def __div__(self, other): if not isinstance(other, Expression): return Constant(self.Value / other) - if self.Value == 1: - return other - return Quotient(self, other) + if isinstance(other, Constant): + return Constant(self.Value / other.Value) + if self.Value == 0: + return self + return Expression.__div__(self, other) def __rdiv__(self, other): if not isinstance(other, Expression): return Constant(other / self.Value) - return Quotient(other, self) + if self.Value == 1: + return other + return Expression.__rdiv__(self, other) def __pow__(self, other): if not isinstance(other, Expression): return Constant(self.Value ** other) - if self.Value == 0: - return self + if isinstance(other, Constant): + return Constant(self.Value ** other.Value) if self.Value == 1: return self - return Power(self, other) + return Expression.__pow__(self, other) def __rpow__(self, other): if not isinstance(other, Expression): return Constant(other ** self.Value) - if self.Value == 1: - return other if self.Value == 0: return Constant(1) - return Power(other, self) + if self.Value == 1: + return other + return Expression.__rpow__(self, other) def __neg__(self): return Constant(-self.Value) - def __call__(self, pars): + def __call__(self, *pars): for par in pars: if isinstance(par, Expression): - return Expression.__call__(self, pars) + return Expression.__call__(self, *pars) return self.Value(*pars) def __hash__(self): @@ -279,9 +310,62 @@ class Power(BinaryExpression): return mapper.map_power(self) class Polynomial(Expression): - def __init__(self, base, children): + def __init__(self, base, coeff): self.Base = base - self.Children = children + + # list of (exponent, coefficient tuples) + # sorted in increasing order + # one entry per degree + self.Data = children + + # Remember the Zen, Luke: Sparse is better than dense. + + def __neg__(self): + return Polynomial(self.Base, + [(exp, -coeff) + for (exp, coeff) in self.Data]) + + def __add__(self, other): + if not isinstance(other, Polynomial) or other.Base != self.Base: + return Expression.__add__(self, other) + + iself = 0 + iother = 0 + + result = [] + while iself < len(self.Data) and iother < len(other.Data): + exp_self = self.Data[iself][0] + exp_other = other.Data[iother][0] + if exp_self == exp_other: + coeff = self.Data[iself][1] + other.Data[iother][1] + if coeff != Constant(0): + result.append((exp_self, coeff)) + iself += 1 + iother += 1 + elif exp_self > exp_other: + result.append((exp_other, other.Data[iother][1])) + iother += 1 + elif exp_self < exp_other: + result.append((exp_self, self.Data[iself][1])) + iself += 1 + + # we have exhausted at least one list, exhaust the other + while iself < len(self.Data): + exp_self = self.Data[iself][0] + result.append((exp_self, self.Data[iself][1])) + iself += 1 + + while iother < len(other.Data): + exp_other = other.Data[iother][0] + result.append((exp_other, other.Data[iother][1])) + iother += 1 + + return Polynomial(self.Base, result) + + def __mul__(self, other): + if not isinstance(other, Polynomial) or other.Base != self.Base: + return Expression.__mul__(self, other) + raise NotImplementedError def __hash__(self): return hash(self.Base) ^ hash(self.Children) @@ -314,3 +398,6 @@ def make_product(components): else: return Product(components) +def polynomial_from_expression(expression): + pass + -- GitLab