diff --git a/pymbolic/mapper/differentiator.py b/pymbolic/mapper/differentiator.py index 17c23f49f3e4056b1d05e72c56e69cc0c91885bd..f924f71602fda87d5601c0ce4b35fd4f17dacc63 100644 --- a/pymbolic/mapper/differentiator.py +++ b/pymbolic/mapper/differentiator.py @@ -33,7 +33,7 @@ import pymbolic.mapper.evaluator def map_math_functions_by_name(i, func, pars, allow_non_smooth=False, - allow_discontinuity=False): + allow_discontinuous=False): try: f = pymbolic.evaluate(func, {"math": math, "cmath": cmath}) except pymbolic.mapper.evaluator.UnknownVariableError: @@ -62,16 +62,17 @@ def map_math_functions_by_name(i, func, pars, return make_f("exp")(*pars) elif f is math.fabs and len(pars) == 1: if allow_non_smooth: - return make_f("sign")(*pars) + from pymbolic.functions import sign + return sign(*pars) else: raise ValueError("fabs is not smooth" ", pass allow_non_smooth=True to return sign") elif f is math.copysign and len(pars) == 2: - if allow_discontinuity: + if allow_discontinuous: return 0 else: raise ValueError("sign is discontinuous" - ", pass allow_discontinuity=True to return 0") + ", pass allow_discontinuous=True to return 0") else: raise RuntimeError("unrecognized function, cannot differentiate") @@ -93,7 +94,7 @@ class DifferentiationMapper(pymbolic.mapper.RecursiveMapper): """ def __init__(self, variable, func_map=map_math_functions_by_name, - allow_non_smooth=False, allow_discontinuity=False): + allow_non_smooth=False, allow_discontinuous=False): """ :arg variable: A :class:`pymbolic.primitives.Variable` instance by which to differentiate. @@ -103,20 +104,20 @@ class DifferentiationMapper(pymbolic.mapper.RecursiveMapper): which are not smooth (e.g., ``fabs``), i.e., by ignoring the discontinuity in the resulting derivative. Defaults to *False*. - :arg allow_discontinuity: Whether to allow differentiation of + :arg allow_discontinuous: Whether to allow differentiation of which are not continuous (e.g., ``sign``), i.e., by ignoring the discontinuity. Defaults to *False*. .. versionchanged:: 2019.2 - Added *allow_non_smooth* and *allow_discontinuity*. + Added *allow_non_smooth* and *allow_discontinuous*. """ self.variable = variable self.function_map = func_map self.allow_non_smooth = allow_non_smooth - self.allow_discontinuity = allow_discontinuity + self.allow_discontinuous = allow_discontinuous def rec_undiff(self, expr, *args): """This method exists for the benefit of subclasses that may need to @@ -138,7 +139,7 @@ class DifferentiationMapper(pymbolic.mapper.RecursiveMapper): self.function_map( i, expr.function, self.rec_undiff(expr.parameters, *args), allow_non_smooth=self.allow_non_smooth, - allow_discontinuity=self.allow_discontinuity) + allow_discontinuous=self.allow_discontinuous) * self.rec(par, *args) for i, par in enumerate(expr.parameters) ) @@ -230,9 +231,9 @@ def differentiate(expression, variable, func_mapper=map_math_functions_by_name, allow_non_smooth=False, - allow_discontinuity=False): + allow_discontinuous=False): if not isinstance(variable, (primitives.Variable, primitives.Subscript)): variable = primitives.make_variable(variable) return DifferentiationMapper(variable, func_mapper, allow_non_smooth=allow_non_smooth, - allow_discontinuity=allow_discontinuity)(expression) + allow_discontinuous=allow_discontinuous)(expression) diff --git a/test/test_pymbolic.py b/test/test_pymbolic.py index 8ed651bbb92e89c243b6d12064b509fd0a3222bf..220bc7919d8c8b914255f51905d5b978797a7b52 100644 --- a/test/test_pymbolic.py +++ b/test/test_pymbolic.py @@ -641,6 +641,25 @@ def test_multiplicative_stringify_preserves_association(): assert_parse_roundtrip("(-1)*(((-1)*x) / 5)") +def test_differentiator_flags_for_nonsmooth_and_discontinuous(): + import pymbolic.functions as pf + from pymbolic.mapper.differentiator import differentiate + + x = prim.Variable('x') + + with pytest.raises(ValueError): + differentiate(pf.fabs(x), x) + + result = differentiate(pf.fabs(x), x, allow_non_smooth=True) + assert result == pf.sign(x) + + with pytest.raises(ValueError): + differentiate(pf.sign(x), x) + + result = differentiate(pf.sign(x), x, allow_discontinuous=True) + assert result == 0 + + if __name__ == "__main__": import sys if len(sys.argv) > 1: