From 3b2d2a5ac32bddb3965dc6ffb3ebaa6495c20146 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Thu, 9 Jun 2005 23:21:41 +0000
Subject: [PATCH] [pymbolic @
 Arch-1:inform@tiker.net--iam-2005%pymbolic--mainline--1.0--patch-2] Finished
 Differentiator.

---
 src/__init__.py       |  9 ++++--
 src/differentiator.py | 60 +++++++++++++++++--------------------
 src/parser.py         | 14 ++++-----
 src/primitives.py     | 70 +++++++++++++++++++++++++++++++++++++++++--
 4 files changed, 106 insertions(+), 47 deletions(-)

diff --git a/src/__init__.py b/src/__init__.py
index ac86a95..7b488a0 100644
--- a/src/__init__.py
+++ b/src/__init__.py
@@ -4,6 +4,7 @@ import compiler
 import stringifier
 import constant_detector
 import substitutor
+import differentiator
 
 parse = parser.parse
 evaluate = evaluator.evaluate
@@ -11,13 +12,15 @@ compile = compiler.compile
 stringify = stringifier.stringify
 is_constant = constant_detector.is_constant
 substitute = substitutor.substitute
+differentiate = differentiator.differentiate
 
 if __name__ == "__main__":
     import math
     ex = parse("0 + 4.3e3j * alpha * cos(x+pi)") + 5
 
     print ex
+    #print evaluate(ex, {"alpha":5, "cos":math.cos, "x":-math.pi, "pi":math.pi})
+    #print is_constant(ex)
+    #print substitute(ex, {"alpha": ex})
+    print differentiate(ex, "x")
 
-    print evaluate(ex, {"alpha":5, "cos":math.cos, "x":-math.pi, "pi":math.pi})
-    print is_constant(ex)
-    print substitute(ex, {"alpha": ex})
diff --git a/src/differentiator.py b/src/differentiator.py
index bcd8b96..f7ec8db 100644
--- a/src/differentiator.py
+++ b/src/differentiator.py
@@ -8,7 +8,7 @@ import math
 
 
 def map_math_functions_by_name(i, func, pars):
-    if not isinstance(func, primitives.Constant):
+    if not isinstance(func, primitives.Variable):
         raise RuntimeError, "No derivative of non-constant function "+str(func)
     name = func.Name
 
@@ -22,12 +22,11 @@ def map_math_functions_by_name(i, func, pars):
         return primitives.Constant(1)/pars[0]
     elif name == "exp" and len(pars) == 1:
         return primitives.Constant("exp")(pars)
-    else
+    else:
         return primitives.Constant(name+"'")(pars)
 
-
-class DifferentiationMapper(mapper.IdentityMapper):
-    def __init__(self, variable, parameters=[], func_map):
+class DifferentiationMapper:
+    def __init__(self, variable, parameters, func_map):
         self.Variable = variable
         self.Parameters = parameters
         self.FunctionMap = func_map
@@ -43,36 +42,28 @@ class DifferentiationMapper(mapper.IdentityMapper):
         else:
             return primitives.Constant(0)
 
-    def Call(self, expr):
-        result = tuple(
+    def map_call(self, expr):
+        return primitives.make_sum(tuple(
             self.FunctionMap(i, expr.Function, expr.Parameters)
             * par.invoke_mapper(self)
             for i, par in enumerate(expr.Parameters)
-            if not self._isc(par))
+            if not self._isc(par)))
 
-        if len(result) == 0:
-            return primitives.Constant(0)
-        elif len(result) == 1:
-            return result[0]
-        else:
-            return primitives.Sum(result)
+    def map_sum(self, expr):
+        return primitives.make_sum(tuple(child.invoke_mapper(self)
+                                           for child in expr.Children
+                                           if not self._isc(child)))
+
+    def map_neg(self, expr):
+        return -expr.Child.invoke_mapper(self)
 
     def map_product(self, expr):
-        result = tuple(
-            Primitives.Product(
-            expr.Children[0:i] +
-            child.invoke_mapper(+
-            expr.Children[i+1:]
+        return primitives.make_sum(tuple(
+            primitives.make_product(expr.Children[0:i] +
+                                    (child.invoke_mapper(self),) +
+                                    expr.Children[i+1:])
             for i, child in enumerate(expr.Children)
-            if not self._isc(child))
-
-        if len(result) == 0:
-            return primitives.Constant(0)
-        elif len(result) == 1:
-            return result[0]
-        else:
-            return primitives.Sum(result)
-
+            if not self._isc(child)))
 
     def map_quotient(self, expr):
         f = expr.Child1
@@ -114,10 +105,10 @@ class DifferentiationMapper(mapper.IdentityMapper):
     def map_polynomial(self, expr):
         raise NotImplementedError
     
-    def _isc(subexp):
+    def _isc(self,subexp):
         return constant_detector.is_constant(subexp, [self.Variable])
 
-    def _eval(subexp):
+    def _eval(self,subexp):
         try:
             return primitives.Constant(evaluator.evaluate(subexp))
         except KeyError:
@@ -126,5 +117,10 @@ class DifferentiationMapper(mapper.IdentityMapper):
 
 
 
-def differentiate(expression, variable, func_mapper=map_math_functions_by_name):
-    return expression.invoke_mapper(DifferentiationMapper(variable))
+def differentiate(expression, 
+                  variable, 
+                  parameters=[],
+                  func_mapper=map_math_functions_by_name):
+    return expression.invoke_mapper(DifferentiationMapper(variable,
+                                                          parameters,
+                                                          func_mapper))
diff --git a/src/parser.py b/src/parser.py
index ff0e244..aa0bb63 100644
--- a/src/parser.py
+++ b/src/parser.py
@@ -97,23 +97,19 @@ def parse(expr_str):
 
         if next_tag is _plus and _PREC_PLUS >= min_precedence:
             pstate.advance()
-            return primitives.Sum((left_exp, parse_expression(pstate, _PREC_PLUS)))
+            return left_exp+parse_expression(pstate, _PREC_PLUS)
         elif next_tag is _minus and _PREC_PLUS >= min_precedence:
             pstate.advance()
-            return primitives.Sum((
-                left_exp, primitives.Negation(parse_expression(pstate, _PREC_PLUS))))
+            return left_exp-primitives.Negation(parse_expression(pstate, _PREC_PLUS))
         elif next_tag is _times and _PREC_TIMES >= min_precedence:
             pstate.advance()
-            return primitives.Product(
-                (left_exp, parse_expression(pstate, _PREC_TIMES)))
+            return left_exp*parse_expression(pstate, _PREC_TIMES)
         elif next_tag is _over and _PREC_TIMES >= min_precedence:
             pstate.advance()
-            return primitives.Quotient(
-                left_exp, parse_expression(pstate, _PREC_TIMES))
+            return left_exp/parse_expression(pstate, _PREC_TIMES)
         elif next_tag is _power and _PREC_POWER >= min_precedence:
             pstate.advance()
-            return primitives.Power(
-                left_exp, parse_expression(pstate, _PREC_TIMES))
+            return left_exp**parse_expression(pstate, _PREC_TIMES)
         else:
             return left_exp
 
diff --git a/src/primitives.py b/src/primitives.py
index b68b513..d5511ce 100644
--- a/src/primitives.py
+++ b/src/primitives.py
@@ -7,36 +7,50 @@ class Expression:
     def __add__(self, other):
         if not isinstance(other, Expression):
             other = Constant(other)
+        if isinstance(other, Constant) and other.Value == 0:
+            return self
         return Sum((self, other))
 
     def __radd__(self, other):
         if not isinstance(other, Expression):
             other = Constant(other)
+        if isinstance(other, Constant) and other.Value == 0:
+            return self
         return Sum((other, self))
 
     def __sub__(self, other):
         if not isinstance(other, Expression):
             other = Constant(other)
+        if isinstance(other, Constant) and other.Value == 0:
+            return self
         return Sum((self, -other))
 
     def __rsub__(self, other):
         if not isinstance(other, Expression):
             other = Constant(other)
+        if isinstance(other, Constant) and other.Value == 0:
+            return Negation(self)
         return Sum((other, -self))
 
     def __mul__(self, other):
         if not isinstance(other, Expression):
             other = Constant(other)
+        if isinstance(other, Constant) and other.Value == 1:
+            return self
         return Product((self, other))
 
     def __rmul__(self, other):
         if not isinstance(other, Expression):
             other = Constant(other)
+        if isinstance(other, Constant) and other.Value == 1:
+            return self
         return Product((other, self))
 
     def __div__(self, other):
         if not isinstance(other, Expression):
             other = Constant(other)
+        if isinstance(other, Constant) and other.Value == 1:
+            return self
         return Quotient(self, other)
 
     def __rdiv__(self, other):
@@ -47,6 +61,11 @@ class Expression:
     def __pow__(self, other):
         if not isinstance(other, Expression):
             other = Constant(other)
+        if isinstance(other, Constant):
+            if other.Value == 0:
+                return Constant(1)
+            elif other.Value == 1:
+                return self
         return Power(self, other)
 
     def __rpow__(self, other):
@@ -54,8 +73,8 @@ class Expression:
             other = Constant(other)
         return Power(other, self)
 
-    def __neg__(self, other):
-        return Negation(self, other)
+    def __neg__(self):
+        return Negation(self)
 
     def __call__(self, other):
         processed = []
@@ -87,6 +106,7 @@ class BinaryExpression(Expression):
 
 class NAryExpression(Expression):
     def __init__(self, children):
+        assert isinstance(children, tuple)
         self.Children = children
 
     def __hash__(self):
@@ -99,36 +119,50 @@ class Constant(Expression):
     def __add__(self, other):
         if not isinstance(other, Expression):
             return Constant(self.Value + other)
+        if self.Value == 0:
+            return other
         return Sum((self, other))
 
     def __radd__(self, other):
         if not isinstance(other, Expression):
             return Constant(other + self.Value)
+        if self.Value == 0:
+            return other
         return Sum((other, self))
 
     def __sub__(self, other):
         if not isinstance(other, Expression):
             return Constant(self.Value - other)
+        if self.Value == 0:
+            return Negation(other)
         return Sum((self, -other))
 
     def __rsub__(self, other):
         if not isinstance(other, Expression):
             return Constant(other - self.Value)
+        if self.Value == 0:
+            return other
         return Sum((other, -self))
 
     def __mul__(self, other):
         if not isinstance(other, Expression):
             return Constant(self.Value * other)
+        if self.Value == 1:
+            return other
         return Product((self, other))
 
     def __rmul__(self, other):
         if not isinstance(other, Expression):
             return Constant(other * self.Value)
+        if self.Value == 1:
+            return other
         return Product((other, self))
 
     def __div__(self, other):
         if not isinstance(other, Expression):
             return Constant(self.Value / other)
+        if self.Value == 1:
+            return other
         return Quotient(self, other)
 
     def __rdiv__(self, other):
@@ -139,14 +173,22 @@ class Constant(Expression):
     def __pow__(self, other):
         if not isinstance(other, Expression):
             return Constant(self.Value ** other)
+        if self.Value == 0:
+            return self
+        if self.Value == 1:
+            return self
         return Power(self, other)
 
     def __rpow__(self, other):
         if not isinstance(other, Expression):
             return Constant(other ** self.Value)
+        if self.Value == 1:
+            return other
+        if self.Value == 0:
+            return Constant(1)
         return Power(other, self)
 
-    def __neg__(self, other):
+    def __neg__(self):
         return Constant(-self.Value)
 
     def __call__(self, pars):
@@ -250,3 +292,25 @@ class Polynomial(Expression):
 class List(NAryExpression):
     def invoke_mapper(self, mapper):
         return mapper.map_list(self)
+
+
+
+
+
+# intelligent makers ---------------------------------------------------------
+def make_sum(components):
+    if len(components) == 0:
+        return primitives.Constant(0)
+    elif len(components) == 1:
+        return components[0]
+    else:
+        return Sum(components)
+
+def make_product(components):
+    if len(components) == 0:
+        return primitives.Constant(1)
+    elif len(components) == 1:
+        return components[0]
+    else:
+        return Product(components)
+
-- 
GitLab