diff --git a/src/__init__.py b/src/__init__.py index ac86a9521398271be1e58d42dac3636aa576b16e..7b488a05f7fdd02fa851c6e84825963db063179b 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -4,6 +4,7 @@ import compiler import stringifier import constant_detector import substitutor +import differentiator parse = parser.parse evaluate = evaluator.evaluate @@ -11,13 +12,15 @@ compile = compiler.compile stringify = stringifier.stringify is_constant = constant_detector.is_constant substitute = substitutor.substitute +differentiate = differentiator.differentiate if __name__ == "__main__": import math ex = parse("0 + 4.3e3j * alpha * cos(x+pi)") + 5 print ex + #print evaluate(ex, {"alpha":5, "cos":math.cos, "x":-math.pi, "pi":math.pi}) + #print is_constant(ex) + #print substitute(ex, {"alpha": ex}) + print differentiate(ex, "x") - print evaluate(ex, {"alpha":5, "cos":math.cos, "x":-math.pi, "pi":math.pi}) - print is_constant(ex) - print substitute(ex, {"alpha": ex}) diff --git a/src/differentiator.py b/src/differentiator.py index bcd8b9630256738977546bf30e2c12918f77c45c..f7ec8dbb744f10280d774a9d49d30841fe9c3250 100644 --- a/src/differentiator.py +++ b/src/differentiator.py @@ -8,7 +8,7 @@ import math def map_math_functions_by_name(i, func, pars): - if not isinstance(func, primitives.Constant): + if not isinstance(func, primitives.Variable): raise RuntimeError, "No derivative of non-constant function "+str(func) name = func.Name @@ -22,12 +22,11 @@ def map_math_functions_by_name(i, func, pars): return primitives.Constant(1)/pars[0] elif name == "exp" and len(pars) == 1: return primitives.Constant("exp")(pars) - else + else: return primitives.Constant(name+"'")(pars) - -class DifferentiationMapper(mapper.IdentityMapper): - def __init__(self, variable, parameters=[], func_map): +class DifferentiationMapper: + def __init__(self, variable, parameters, func_map): self.Variable = variable self.Parameters = parameters self.FunctionMap = func_map @@ -43,36 +42,28 @@ class DifferentiationMapper(mapper.IdentityMapper): else: return primitives.Constant(0) - def Call(self, expr): - result = tuple( + def map_call(self, expr): + return primitives.make_sum(tuple( self.FunctionMap(i, expr.Function, expr.Parameters) * par.invoke_mapper(self) for i, par in enumerate(expr.Parameters) - if not self._isc(par)) + if not self._isc(par))) - if len(result) == 0: - return primitives.Constant(0) - elif len(result) == 1: - return result[0] - else: - return primitives.Sum(result) + def map_sum(self, expr): + return primitives.make_sum(tuple(child.invoke_mapper(self) + for child in expr.Children + if not self._isc(child))) + + def map_neg(self, expr): + return -expr.Child.invoke_mapper(self) def map_product(self, expr): - result = tuple( - Primitives.Product( - expr.Children[0:i] + - child.invoke_mapper(+ - expr.Children[i+1:] + return primitives.make_sum(tuple( + primitives.make_product(expr.Children[0:i] + + (child.invoke_mapper(self),) + + expr.Children[i+1:]) for i, child in enumerate(expr.Children) - if not self._isc(child)) - - if len(result) == 0: - return primitives.Constant(0) - elif len(result) == 1: - return result[0] - else: - return primitives.Sum(result) - + if not self._isc(child))) def map_quotient(self, expr): f = expr.Child1 @@ -114,10 +105,10 @@ class DifferentiationMapper(mapper.IdentityMapper): def map_polynomial(self, expr): raise NotImplementedError - def _isc(subexp): + def _isc(self,subexp): return constant_detector.is_constant(subexp, [self.Variable]) - def _eval(subexp): + def _eval(self,subexp): try: return primitives.Constant(evaluator.evaluate(subexp)) except KeyError: @@ -126,5 +117,10 @@ class DifferentiationMapper(mapper.IdentityMapper): -def differentiate(expression, variable, func_mapper=map_math_functions_by_name): - return expression.invoke_mapper(DifferentiationMapper(variable)) +def differentiate(expression, + variable, + parameters=[], + func_mapper=map_math_functions_by_name): + return expression.invoke_mapper(DifferentiationMapper(variable, + parameters, + func_mapper)) diff --git a/src/parser.py b/src/parser.py index ff0e244519eb23bbd4409fb1ba95477bf6abb507..aa0bb6337be55d6f2aabf13ed059b625e1e8a7a2 100644 --- a/src/parser.py +++ b/src/parser.py @@ -97,23 +97,19 @@ def parse(expr_str): if next_tag is _plus and _PREC_PLUS >= min_precedence: pstate.advance() - return primitives.Sum((left_exp, parse_expression(pstate, _PREC_PLUS))) + return left_exp+parse_expression(pstate, _PREC_PLUS) elif next_tag is _minus and _PREC_PLUS >= min_precedence: pstate.advance() - return primitives.Sum(( - left_exp, primitives.Negation(parse_expression(pstate, _PREC_PLUS)))) + return left_exp-primitives.Negation(parse_expression(pstate, _PREC_PLUS)) elif next_tag is _times and _PREC_TIMES >= min_precedence: pstate.advance() - return primitives.Product( - (left_exp, parse_expression(pstate, _PREC_TIMES))) + return left_exp*parse_expression(pstate, _PREC_TIMES) elif next_tag is _over and _PREC_TIMES >= min_precedence: pstate.advance() - return primitives.Quotient( - left_exp, parse_expression(pstate, _PREC_TIMES)) + return left_exp/parse_expression(pstate, _PREC_TIMES) elif next_tag is _power and _PREC_POWER >= min_precedence: pstate.advance() - return primitives.Power( - left_exp, parse_expression(pstate, _PREC_TIMES)) + return left_exp**parse_expression(pstate, _PREC_TIMES) else: return left_exp diff --git a/src/primitives.py b/src/primitives.py index b68b51343faeeedde3cb641a2bb9cb65e0e6ac82..d5511ce96aa89a77d0a0cca74baeccfadd45674d 100644 --- a/src/primitives.py +++ b/src/primitives.py @@ -7,36 +7,50 @@ class Expression: def __add__(self, other): if not isinstance(other, Expression): other = Constant(other) + if isinstance(other, Constant) and other.Value == 0: + return self return Sum((self, other)) def __radd__(self, other): if not isinstance(other, Expression): other = Constant(other) + if isinstance(other, Constant) and other.Value == 0: + return self return Sum((other, self)) def __sub__(self, other): if not isinstance(other, Expression): other = Constant(other) + if isinstance(other, Constant) and other.Value == 0: + return self return Sum((self, -other)) def __rsub__(self, other): if not isinstance(other, Expression): other = Constant(other) + if isinstance(other, Constant) and other.Value == 0: + return Negation(self) return Sum((other, -self)) def __mul__(self, other): if not isinstance(other, Expression): other = Constant(other) + if isinstance(other, Constant) and other.Value == 1: + return self return Product((self, other)) def __rmul__(self, other): if not isinstance(other, Expression): other = Constant(other) + if isinstance(other, Constant) and other.Value == 1: + return self return Product((other, self)) def __div__(self, other): if not isinstance(other, Expression): other = Constant(other) + if isinstance(other, Constant) and other.Value == 1: + return self return Quotient(self, other) def __rdiv__(self, other): @@ -47,6 +61,11 @@ class Expression: def __pow__(self, other): if not isinstance(other, Expression): other = Constant(other) + if isinstance(other, Constant): + if other.Value == 0: + return Constant(1) + elif other.Value == 1: + return self return Power(self, other) def __rpow__(self, other): @@ -54,8 +73,8 @@ class Expression: other = Constant(other) return Power(other, self) - def __neg__(self, other): - return Negation(self, other) + def __neg__(self): + return Negation(self) def __call__(self, other): processed = [] @@ -87,6 +106,7 @@ class BinaryExpression(Expression): class NAryExpression(Expression): def __init__(self, children): + assert isinstance(children, tuple) self.Children = children def __hash__(self): @@ -99,36 +119,50 @@ class Constant(Expression): def __add__(self, other): if not isinstance(other, Expression): return Constant(self.Value + other) + if self.Value == 0: + return other return Sum((self, other)) def __radd__(self, other): if not isinstance(other, Expression): return Constant(other + self.Value) + if self.Value == 0: + return other return Sum((other, self)) def __sub__(self, other): if not isinstance(other, Expression): return Constant(self.Value - other) + if self.Value == 0: + return Negation(other) return Sum((self, -other)) def __rsub__(self, other): if not isinstance(other, Expression): return Constant(other - self.Value) + if self.Value == 0: + return other return Sum((other, -self)) def __mul__(self, other): if not isinstance(other, Expression): return Constant(self.Value * other) + if self.Value == 1: + return other return Product((self, other)) def __rmul__(self, other): if not isinstance(other, Expression): return Constant(other * self.Value) + if self.Value == 1: + return other return Product((other, self)) def __div__(self, other): if not isinstance(other, Expression): return Constant(self.Value / other) + if self.Value == 1: + return other return Quotient(self, other) def __rdiv__(self, other): @@ -139,14 +173,22 @@ class Constant(Expression): def __pow__(self, other): if not isinstance(other, Expression): return Constant(self.Value ** other) + if self.Value == 0: + return self + if self.Value == 1: + return self return Power(self, other) def __rpow__(self, other): if not isinstance(other, Expression): return Constant(other ** self.Value) + if self.Value == 1: + return other + if self.Value == 0: + return Constant(1) return Power(other, self) - def __neg__(self, other): + def __neg__(self): return Constant(-self.Value) def __call__(self, pars): @@ -250,3 +292,25 @@ class Polynomial(Expression): class List(NAryExpression): def invoke_mapper(self, mapper): return mapper.map_list(self) + + + + + +# intelligent makers --------------------------------------------------------- +def make_sum(components): + if len(components) == 0: + return primitives.Constant(0) + elif len(components) == 1: + return components[0] + else: + return Sum(components) + +def make_product(components): + if len(components) == 0: + return primitives.Constant(1) + elif len(components) == 1: + return components[0] + else: + return Product(components) +