diff --git a/pymbolic/functions.py b/pymbolic/functions.py index 4a0ae78cbd497f063c65df1da8800c8a2232c919..b1a98df200e12514621a9830e88bbffe4c7eee65 100644 --- a/pymbolic/functions.py +++ b/pymbolic/functions.py @@ -23,29 +23,48 @@ THE SOFTWARE. """ -import pymbolic.primitives as primitives +import pymbolic.primitives as p def sin(x): - return primitives.Call( - primitives.Lookup(primitives.Variable("math"), "sin"), (x,)) + return p.Call(p.Lookup(p.Variable("math"), "sin"), (x,)) def cos(x): - return primitives.Call( - primitives.Lookup(primitives.Variable("math"), "cos"), (x,)) + return p.Call(p.Lookup(p.Variable("math"), "cos"), (x,)) def tan(x): - return primitives.Call( - primitives.Lookup(primitives.Variable("math"), "tan"), (x,)) + return p.Call(p.Lookup(p.Variable("math"), "tan"), (x,)) def log(x): - return primitives.Call( - primitives.Lookup(primitives.Variable("math"), "log"), (x,)) + return p.Call(p.Lookup(p.Variable("math"), "log"), (x,)) def exp(x): - return primitives.Call( - primitives.Lookup(primitives.Variable("math"), "exp"), (x,)) + return p.Call(p.Lookup(p.Variable("math"), "exp"), (x,)) + + +def sinh(x): + return p.Call(p.Lookup(p.Variable("math"), "sinh"), (x,)) + + +def cosh(x): + return p.Call(p.Lookup(p.Variable("math"), "cosh"), (x,)) + + +def tanh(x): + return p.Call(p.Lookup(p.Variable("math"), "tanh"), (x,)) + + +def expm1(x): + return p.Call(p.Lookup(p.Variable("math"), "expm1"), (x,)) + + +def fabs(x): + return p.Call(p.Lookup(p.Variable("math"), "fabs"), (x,)) + + +def sign(x): + return p.Call(p.Lookup(p.Variable("math"), "copysign"), (1, x,)) diff --git a/pymbolic/mapper/differentiator.py b/pymbolic/mapper/differentiator.py index b6ec695c77b72a8a61e662dbfa4c0cc5582e0326..beb901ae35d95a100653f63493a7f0420c48874b 100644 --- a/pymbolic/mapper/differentiator.py +++ b/pymbolic/mapper/differentiator.py @@ -22,34 +22,49 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -import math -import cmath - import pymbolic import pymbolic.primitives as primitives import pymbolic.mapper import pymbolic.mapper.evaluator -def map_math_functions_by_name(i, func, pars): - try: - f = pymbolic.evaluate(func, {"math": math, "cmath": cmath}) - except pymbolic.mapper.evaluator.UnknownVariableError: - raise RuntimeError("No derivative of non-constant function "+str(func)) - +def map_math_functions_by_name(i, func, pars, allowed_nonsmoothness="none"): def make_f(name): return primitives.Lookup(primitives.Variable("math"), name) - if f is math.sin and len(pars) == 1: + if func == make_f("sin") and len(pars) == 1: return make_f("cos")(*pars) - elif f is math.cos and len(pars) == 1: + elif func == make_f("cos") and len(pars) == 1: return -make_f("sin")(*pars) - elif f is math.tan and len(pars) == 1: + elif func == make_f("tan") and len(pars) == 1: return make_f("tan")(*pars)**2+1 - elif f is math.log and len(pars) == 1: + elif func == make_f("log") and len(pars) == 1: return primitives.quotient(1, pars[0]) - elif f is math.exp and len(pars) == 1: + elif func == make_f("exp") and len(pars) == 1: return make_f("exp")(*pars) + elif func == make_f("sinh") and len(pars) == 1: + return make_f("cosh")(*pars) + elif func == make_f("cosh") and len(pars) == 1: + return make_f("sinh")(*pars) + elif func == make_f("tanh") and len(pars) == 1: + return 1-make_f("tanh")(*pars)**2 + elif func == make_f("expm1") and len(pars) == 1: + return make_f("exp")(*pars) + elif func == make_f("fabs") and len(pars) == 1: + if allowed_nonsmoothness in ["continuous", "discontinuous"]: + from pymbolic.functions import sign + return sign(*pars) + else: + raise ValueError("fabs is not smooth" + ", pass allowed_nonsmoothness='continuous' " + "to return sign") + elif func == make_f("copysign") and len(pars) == 2: + if allowed_nonsmoothness == "discontinuous": + return 0 + else: + raise ValueError("sign is discontinuous" + ", pass allowed_nonsmoothness='discontinuous' " + "to return 0") else: raise RuntimeError("unrecognized function, cannot differentiate") @@ -70,16 +85,30 @@ class DifferentiationMapper(pymbolic.mapper.RecursiveMapper): / (x + -1)**2**2 """ - def __init__(self, variable, func_map=map_math_functions_by_name): + def __init__(self, variable, func_map=map_math_functions_by_name, + allowed_nonsmoothness="none"): """ :arg variable: A :class:`pymbolic.primitives.Variable` instance by which to differentiate. :arg func_map: A function for computing derivatives of function calls, signature ``(arg_index, function_variable, parameters)``. + :arg allowed_nonsmoothness: Whether to allow differentiation of + functions which are not smooth or continuous. + Pass ``"continuous"`` to allow nonsmooth but not discontinuous + functions or ``"discontinuous"`` to allow both. + Defaults to ``"none"``, in which case neither is allowed. + + .. versionchanged:: 2019.2 + + Added *allowed_nonsmoothness*. """ self.variable = variable self.function_map = func_map + if allowed_nonsmoothness not in ["none", "continuous", "discontinuous"]: + raise ValueError("allowed_nonsmoothness=%s is not a valid option" + % allowed_nonsmoothness) + self.allowed_nonsmoothness = allowed_nonsmoothness def rec_undiff(self, expr, *args): """This method exists for the benefit of subclasses that may need to @@ -99,7 +128,8 @@ class DifferentiationMapper(pymbolic.mapper.RecursiveMapper): def map_call(self, expr, *args): return pymbolic.flattened_sum( self.function_map( - i, expr.function, self.rec_undiff(expr.parameters, *args)) + i, expr.function, self.rec_undiff(expr.parameters, *args), + allowed_nonsmoothness=self.allowed_nonsmoothness) * self.rec(par, *args) for i, par in enumerate(expr.parameters) ) @@ -189,7 +219,10 @@ class DifferentiationMapper(pymbolic.mapper.RecursiveMapper): def differentiate(expression, variable, - func_mapper=map_math_functions_by_name): + func_mapper=map_math_functions_by_name, + allowed_nonsmoothness="none"): if not isinstance(variable, (primitives.Variable, primitives.Subscript)): variable = primitives.make_variable(variable) - return DifferentiationMapper(variable, func_mapper)(expression) + return DifferentiationMapper( + variable, func_mapper, allowed_nonsmoothness=allowed_nonsmoothness + )(expression) diff --git a/test/test_pymbolic.py b/test/test_pymbolic.py index 8ed651bbb92e89c243b6d12064b509fd0a3222bf..2d6f3f786a3ef9c60d9299a2907785a4370a3b42 100644 --- a/test/test_pymbolic.py +++ b/test/test_pymbolic.py @@ -641,6 +641,25 @@ def test_multiplicative_stringify_preserves_association(): assert_parse_roundtrip("(-1)*(((-1)*x) / 5)") +def test_differentiator_flags_for_nonsmooth_and_discontinuous(): + import pymbolic.functions as pf + from pymbolic.mapper.differentiator import differentiate + + x = prim.Variable('x') + + with pytest.raises(ValueError): + differentiate(pf.fabs(x), x) + + result = differentiate(pf.fabs(x), x, allowed_nonsmoothness="continuous") + assert result == pf.sign(x) + + with pytest.raises(ValueError): + differentiate(pf.sign(x), x) + + result = differentiate(pf.sign(x), x, allowed_nonsmoothness="discontinuous") + assert result == 0 + + if __name__ == "__main__": import sys if len(sys.argv) > 1: