From 86f5a6f4029a0b40049ad46be76a2398b1f868cf Mon Sep 17 00:00:00 2001
From: Isuru Fernando <idf2@illinois.edu>
Date: Thu, 2 Sep 2021 09:11:13 -0500
Subject: [PATCH] support quotient and power in CoefficientCollector

---
 pymbolic/mapper/coefficient.py | 23 +++++++++++++++++++++++
 test/test_pymbolic.py          | 12 ++++++++++++
 2 files changed, 35 insertions(+)

diff --git a/pymbolic/mapper/coefficient.py b/pymbolic/mapper/coefficient.py
index 73276ed..1d23ae5 100644
--- a/pymbolic/mapper/coefficient.py
+++ b/pymbolic/mapper/coefficient.py
@@ -71,6 +71,29 @@ class CoefficientCollector(Mapper):
 
         return result
 
+    def map_quotient(self, expr):
+        from pymbolic.primitives import Quotient
+        d_num = self.rec(expr.numerator)
+        d_den = self.rec(expr.denominator)
+        # d_den should look like {1: k}
+        if len(d_den) > 1 or 1 not in d_den:
+            raise RuntimeError("nonlinear expression")
+        val = d_den[1]
+        for k in d_num.keys():
+            d_num[k] *= Quotient(1, val)
+        return d_num
+
+    def map_power(self, expr):
+        d_base = self.rec(expr.base)
+        d_exponent = self.rec(expr.exponent)
+        # d_exponent should look like {1: k}
+        if len(d_exponent) > 1 or 1 not in d_exponent:
+            raise RuntimeError("nonlinear expression")
+        # d_base should look like {1: k}
+        if len(d_base) > 1 or 1 not in d_base:
+            raise RuntimeError("nonlinear expression")
+        return {1: expr}
+
     def map_constant(self, expr):
         return {1: expr}
 
diff --git a/test/test_pymbolic.py b/test/test_pymbolic.py
index 6f58587..428623d 100644
--- a/test/test_pymbolic.py
+++ b/test/test_pymbolic.py
@@ -659,6 +659,18 @@ def test_differentiator_flags_for_nonsmooth_and_discontinuous():
     assert result == 0
 
 
+def test_coefficient_collector():
+    from pymbolic.mapper.coefficient import CoefficientCollector
+    x = prim.Variable("x")
+    y = prim.Variable("y")
+    z = prim.Variable("z")
+
+    cc = CoefficientCollector([x.name, y.name])
+    assert cc(2*x + y) == {x: 2, y: 1}
+    assert cc(2*x + y - z) == {x: 2, y: 1, 1: -z}
+    assert cc(x/2 + z**2) == {x: prim.Quotient(1, 2), 1: z**2}
+
+
 def test_np_bool_handling():
     from pymbolic.mapper.evaluator import evaluate
     numpy = pytest.importorskip("numpy")
-- 
GitLab