diff --git a/src/__init__.py b/src/__init__.py index 5498747a9a909fdda0204f535d3bb9a862afb51f..ac86a9521398271be1e58d42dac3636aa576b16e 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -3,12 +3,14 @@ import evaluator import compiler import stringifier import constant_detector +import substitutor parse = parser.parse evaluate = evaluator.evaluate compile = compiler.compile stringify = stringifier.stringify is_constant = constant_detector.is_constant +substitute = substitutor.substitute if __name__ == "__main__": import math @@ -18,3 +20,4 @@ if __name__ == "__main__": print evaluate(ex, {"alpha":5, "cos":math.cos, "x":-math.pi, "pi":math.pi}) print is_constant(ex) + print substitute(ex, {"alpha": ex}) diff --git a/src/constant_detector.py b/src/constant_detector.py index b286059c2b523e6f3d1109b0e4bc77aee0e0301b..28224d5b859f024a801c6e4e7e1fa9ab7e5b6763 100644 --- a/src/constant_detector.py +++ b/src/constant_detector.py @@ -22,5 +22,6 @@ class ConstantDetectionMapper(mapper.CombineMapper): + def is_constant(expr, with_respect_to=None): return expr.invoke_mapper(ConstantDetectionMapper(with_respect_to)) diff --git a/src/differentiator.py b/src/differentiator.py new file mode 100644 index 0000000000000000000000000000000000000000..bcd8b9630256738977546bf30e2c12918f77c45c --- /dev/null +++ b/src/differentiator.py @@ -0,0 +1,130 @@ +import constant_detector +import evaluator +import primitives +import mapper +import math + + + + +def map_math_functions_by_name(i, func, pars): + if not isinstance(func, primitives.Constant): + raise RuntimeError, "No derivative of non-constant function "+str(func) + name = func.Name + + if name == "sin" and len(pars) == 1: + return primitives.Constant("cos")(pars) + elif name == "cos" and len(pars) == 1: + return -primitives.Constant("sin")(pars) + elif name == "tan" and len(pars) == 1: + return primitives.Constant("tan")(pars)**2+1 + elif name == "log" and len(pars) == 1: + return primitives.Constant(1)/pars[0] + elif name == "exp" and len(pars) == 1: + return primitives.Constant("exp")(pars) + else + return primitives.Constant(name+"'")(pars) + + +class DifferentiationMapper(mapper.IdentityMapper): + def __init__(self, variable, parameters=[], func_map): + self.Variable = variable + self.Parameters = parameters + self.FunctionMap = func_map + + def map_constant(self, expr): + return primitives.Constant(0) + + def map_variable(self, expr): + if expr.Name is self.Variable: + return primitives.Constant(1) + elif expr.Name in self.Parameters: + return expr + else: + return primitives.Constant(0) + + def Call(self, expr): + result = tuple( + self.FunctionMap(i, expr.Function, expr.Parameters) + * par.invoke_mapper(self) + for i, par in enumerate(expr.Parameters) + if not self._isc(par)) + + if len(result) == 0: + return primitives.Constant(0) + elif len(result) == 1: + return result[0] + else: + return primitives.Sum(result) + + def map_product(self, expr): + result = tuple( + Primitives.Product( + expr.Children[0:i] + + child.invoke_mapper(+ + expr.Children[i+1:] + for i, child in enumerate(expr.Children) + if not self._isc(child)) + + if len(result) == 0: + return primitives.Constant(0) + elif len(result) == 1: + return result[0] + else: + return primitives.Sum(result) + + + def map_quotient(self, expr): + f = expr.Child1 + g = expr.Child2 + f_const = self._isc(f) + g_const = self._isc(g) + + if f_const and g_const: + return primitives.Constant(0) + elif f_const: + f = self._eval(f) + return -f*g.invoke_mapper(self)/g**2 + elif g_const: + g = self._eval(g) + return f.invoke_mapper(self)/g + else: + return (f.invoke_mapper(self)*g-g.invoke_mapper(self)*f)/g**2 + + def map_power(self, expr): + f = expr.Child1 + g = expr.Child2 + f_const = self._isc(f) + g_const = self._isc(g) + + log = primitives.Constant("log") + + if f_const and g_const: + return primitives.Constant(0) + elif f_const: + f = self._eval(f) + return log(f) * f**g * g.invoke_mapper(self) + elif g_const: + g = self._eval(g) + return g * f**(g-1) * f.invoke_mapper(self) + else: + return log(f) * f**g * g.invoke_mapper(self) + \ + g * f**(g-1) * f.invoke_mapper(self) + + def map_polynomial(self, expr): + raise NotImplementedError + + def _isc(subexp): + return constant_detector.is_constant(subexp, [self.Variable]) + + def _eval(subexp): + try: + return primitives.Constant(evaluator.evaluate(subexp)) + except KeyError: + return subexp + + + + +def differentiate(expression, variable, func_mapper=map_math_functions_by_name): + return expression.invoke_mapper(DifferentiationMapper(variable)) diff --git a/src/primitives.py b/src/primitives.py index 91ffe9ef8267243b26e11b63d8d6c3138bfe0934..b68b51343faeeedde3cb641a2bb9cb65e0e6ac82 100644 --- a/src/primitives.py +++ b/src/primitives.py @@ -58,7 +58,14 @@ class Expression: return Negation(self, other) def __call__(self, other): - return Call(self, other) + processed = [] + for par in other: + if isinstance(par, Expression): + processed.append(par) + else: + processed.append(Constant(par)) + + return Call(self, processed) def __str__(self): return self.invoke_mapper(stringifier.StringifyMapper()) @@ -89,6 +96,65 @@ class Constant(Expression): def __init__(self, value): self.Value = value + def __add__(self, other): + if not isinstance(other, Expression): + return Constant(self.Value + other) + return Sum((self, other)) + + def __radd__(self, other): + if not isinstance(other, Expression): + return Constant(other + self.Value) + return Sum((other, self)) + + def __sub__(self, other): + if not isinstance(other, Expression): + return Constant(self.Value - other) + return Sum((self, -other)) + + def __rsub__(self, other): + if not isinstance(other, Expression): + return Constant(other - self.Value) + return Sum((other, -self)) + + def __mul__(self, other): + if not isinstance(other, Expression): + return Constant(self.Value * other) + return Product((self, other)) + + def __rmul__(self, other): + if not isinstance(other, Expression): + return Constant(other * self.Value) + return Product((other, self)) + + def __div__(self, other): + if not isinstance(other, Expression): + return Constant(self.Value / other) + return Quotient(self, other) + + def __rdiv__(self, other): + if not isinstance(other, Expression): + return Constant(other / self.Value) + return Quotient(other, self) + + def __pow__(self, other): + if not isinstance(other, Expression): + return Constant(self.Value ** other) + return Power(self, other) + + def __rpow__(self, other): + if not isinstance(other, Expression): + return Constant(other ** self.Value) + return Power(other, self) + + def __neg__(self, other): + return Constant(-self.Value) + + def __call__(self, pars): + for par in pars: + if isinstance(par, Expression): + return Expression.__call__(self, pars) + return self.Value(*pars) + def __hash__(self): return hash(self.Value) diff --git a/src/substitutor.py b/src/substitutor.py new file mode 100644 index 0000000000000000000000000000000000000000..c49b5ccb7466d37ca88d720d04338ab6a7ef53e9 --- /dev/null +++ b/src/substitutor.py @@ -0,0 +1,20 @@ +import mapper + + + + +class SubstitutionMapper(mapper.IdentityMapper): + def __init__(self, variable_assignments): + self.Assignments = variable_assignments + + def map_variable(self, expr): + try: + return self.Assignments[expr.Name] + except KeyError: + return expr + + + + +def substitute(expression, variable_assignments = {}): + return expression.invoke_mapper(SubstitutionMapper(variable_assignments))