From 86f5a6f4029a0b40049ad46be76a2398b1f868cf Mon Sep 17 00:00:00 2001 From: Isuru Fernando <idf2@illinois.edu> Date: Thu, 2 Sep 2021 09:11:13 -0500 Subject: [PATCH] support quotient and power in CoefficientCollector --- pymbolic/mapper/coefficient.py | 23 +++++++++++++++++++++++ test/test_pymbolic.py | 12 ++++++++++++ 2 files changed, 35 insertions(+) diff --git a/pymbolic/mapper/coefficient.py b/pymbolic/mapper/coefficient.py index 73276ed..1d23ae5 100644 --- a/pymbolic/mapper/coefficient.py +++ b/pymbolic/mapper/coefficient.py @@ -71,6 +71,29 @@ class CoefficientCollector(Mapper): return result + def map_quotient(self, expr): + from pymbolic.primitives import Quotient + d_num = self.rec(expr.numerator) + d_den = self.rec(expr.denominator) + # d_den should look like {1: k} + if len(d_den) > 1 or 1 not in d_den: + raise RuntimeError("nonlinear expression") + val = d_den[1] + for k in d_num.keys(): + d_num[k] *= Quotient(1, val) + return d_num + + def map_power(self, expr): + d_base = self.rec(expr.base) + d_exponent = self.rec(expr.exponent) + # d_exponent should look like {1: k} + if len(d_exponent) > 1 or 1 not in d_exponent: + raise RuntimeError("nonlinear expression") + # d_base should look like {1: k} + if len(d_base) > 1 or 1 not in d_base: + raise RuntimeError("nonlinear expression") + return {1: expr} + def map_constant(self, expr): return {1: expr} diff --git a/test/test_pymbolic.py b/test/test_pymbolic.py index 6f58587..428623d 100644 --- a/test/test_pymbolic.py +++ b/test/test_pymbolic.py @@ -659,6 +659,18 @@ def test_differentiator_flags_for_nonsmooth_and_discontinuous(): assert result == 0 +def test_coefficient_collector(): + from pymbolic.mapper.coefficient import CoefficientCollector + x = prim.Variable("x") + y = prim.Variable("y") + z = prim.Variable("z") + + cc = CoefficientCollector([x.name, y.name]) + assert cc(2*x + y) == {x: 2, y: 1} + assert cc(2*x + y - z) == {x: 2, y: 1, 1: -z} + assert cc(x/2 + z**2) == {x: prim.Quotient(1, 2), 1: z**2} + + def test_np_bool_handling(): from pymbolic.mapper.evaluator import evaluate numpy = pytest.importorskip("numpy") -- GitLab