Skip to content
Snippets Groups Projects
Commit e8de9f50 authored by zachjweiner's avatar zachjweiner
Browse files

add test for nonsmooth/discontinuous differentiating

parent 81afd8a2
No related branches found
No related tags found
1 merge request!26Add more math functions
......@@ -33,7 +33,7 @@ import pymbolic.mapper.evaluator
def map_math_functions_by_name(i, func, pars,
allow_non_smooth=False,
allow_discontinuity=False):
allow_discontinuous=False):
try:
f = pymbolic.evaluate(func, {"math": math, "cmath": cmath})
except pymbolic.mapper.evaluator.UnknownVariableError:
......@@ -62,16 +62,17 @@ def map_math_functions_by_name(i, func, pars,
return make_f("exp")(*pars)
elif f is math.fabs and len(pars) == 1:
if allow_non_smooth:
return make_f("sign")(*pars)
from pymbolic.functions import sign
return sign(*pars)
else:
raise ValueError("fabs is not smooth"
", pass allow_non_smooth=True to return sign")
elif f is math.copysign and len(pars) == 2:
if allow_discontinuity:
if allow_discontinuous:
return 0
else:
raise ValueError("sign is discontinuous"
", pass allow_discontinuity=True to return 0")
", pass allow_discontinuous=True to return 0")
else:
raise RuntimeError("unrecognized function, cannot differentiate")
......@@ -93,7 +94,7 @@ class DifferentiationMapper(pymbolic.mapper.RecursiveMapper):
"""
def __init__(self, variable, func_map=map_math_functions_by_name,
allow_non_smooth=False, allow_discontinuity=False):
allow_non_smooth=False, allow_discontinuous=False):
"""
:arg variable: A :class:`pymbolic.primitives.Variable` instance
by which to differentiate.
......@@ -103,20 +104,20 @@ class DifferentiationMapper(pymbolic.mapper.RecursiveMapper):
which are not smooth (e.g., ``fabs``), i.e., by ignoring the
discontinuity in the resulting derivative.
Defaults to *False*.
:arg allow_discontinuity: Whether to allow differentiation of
:arg allow_discontinuous: Whether to allow differentiation of
which are not continuous (e.g., ``sign``), i.e., by ignoring the
discontinuity.
Defaults to *False*.
.. versionchanged:: 2019.2
Added *allow_non_smooth* and *allow_discontinuity*.
Added *allow_non_smooth* and *allow_discontinuous*.
"""
self.variable = variable
self.function_map = func_map
self.allow_non_smooth = allow_non_smooth
self.allow_discontinuity = allow_discontinuity
self.allow_discontinuous = allow_discontinuous
def rec_undiff(self, expr, *args):
"""This method exists for the benefit of subclasses that may need to
......@@ -138,7 +139,7 @@ class DifferentiationMapper(pymbolic.mapper.RecursiveMapper):
self.function_map(
i, expr.function, self.rec_undiff(expr.parameters, *args),
allow_non_smooth=self.allow_non_smooth,
allow_discontinuity=self.allow_discontinuity)
allow_discontinuous=self.allow_discontinuous)
* self.rec(par, *args)
for i, par in enumerate(expr.parameters)
)
......@@ -230,9 +231,9 @@ def differentiate(expression,
variable,
func_mapper=map_math_functions_by_name,
allow_non_smooth=False,
allow_discontinuity=False):
allow_discontinuous=False):
if not isinstance(variable, (primitives.Variable, primitives.Subscript)):
variable = primitives.make_variable(variable)
return DifferentiationMapper(variable, func_mapper,
allow_non_smooth=allow_non_smooth,
allow_discontinuity=allow_discontinuity)(expression)
allow_discontinuous=allow_discontinuous)(expression)
......@@ -641,6 +641,25 @@ def test_multiplicative_stringify_preserves_association():
assert_parse_roundtrip("(-1)*(((-1)*x) / 5)")
def test_differentiator_flags_for_nonsmooth_and_discontinuous():
import pymbolic.functions as pf
from pymbolic.mapper.differentiator import differentiate
x = prim.Variable('x')
with pytest.raises(ValueError):
differentiate(pf.fabs(x), x)
result = differentiate(pf.fabs(x), x, allow_non_smooth=True)
assert result == pf.sign(x)
with pytest.raises(ValueError):
differentiate(pf.sign(x), x)
result = differentiate(pf.sign(x), x, allow_discontinuous=True)
assert result == 0
if __name__ == "__main__":
import sys
if len(sys.argv) > 1:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment