From cd626ad0ca82378509a0aa7c94ac0c2feb605773 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Fri, 28 Nov 2014 13:41:40 -0600
Subject: [PATCH] Fix derivative binder to deal with nablas nested inside of
 derivative source

---
 pymbolic/geometric_algebra/mapper.py | 58 +++++++++++++++++++++-------
 1 file changed, 44 insertions(+), 14 deletions(-)

diff --git a/pymbolic/geometric_algebra/mapper.py b/pymbolic/geometric_algebra/mapper.py
index 0da2cf9..8129757 100644
--- a/pymbolic/geometric_algebra/mapper.py
+++ b/pymbolic/geometric_algebra/mapper.py
@@ -30,13 +30,13 @@ THE SOFTWARE.
 # Consider yourself warned.
 
 from pymbolic.geometric_algebra import MultiVector
-import pymbolic.primitives as pprim
 import pymbolic.geometric_algebra.primitives as prim
 from pymbolic.mapper import (
         CombineMapper as CombineMapperBase,
         Collector as CollectorBase,
         IdentityMapper as IdentityMapperBase,
-        WalkMapper as WalkMapperBase
+        WalkMapper as WalkMapperBase,
+        CachingMapperMixin
         )
 from pymbolic.mapper.stringifier import (
         StringifyMapper as StringifyMapperBase,
@@ -154,7 +154,7 @@ class Dimensionalizer(EvaluationMapper):
 
 # {{{ derivative binder
 
-class DerivativeSourceAndNablaComponentCollector(Collector):
+class DerivativeSourceAndNablaComponentCollector(CachingMapperMixin, Collector):
     def map_nabla(self, expr):
         raise RuntimeError("DerivativeOccurrenceMapper must be invoked after "
                 "Dimensionalizer--Nabla found, not allowed")
@@ -163,7 +163,7 @@ class DerivativeSourceAndNablaComponentCollector(Collector):
         return set([expr])
 
     def map_derivative_source(self, expr):
-        return set([expr])
+        return set([expr]) | self.rec(expr.operand)
 
 
 class NablaComponentToUnitVector(EvaluationMapper):
@@ -171,6 +171,9 @@ class NablaComponentToUnitVector(EvaluationMapper):
         self.nabla_id = nabla_id
         self.ambient_axis = ambient_axis
 
+    def map_variable(self, expr):
+        return expr
+
     def map_nabla_component(self, expr):
         if expr.nabla_id == self.nabla_id:
             if expr.ambient_axis == self.ambient_axis:
@@ -209,9 +212,9 @@ class DerivativeBinder(IdentityMapper):
         self.derivative_collector = \
                 self.derivative_source_and_nabla_component_collector()
 
-    def do_bind(self, rec_children):
+    def map_product(self, expr):
         # We may write to this below. Make a copy.
-        rec_children = list(rec_children)
+        children = list(expr.children)
 
         # {{{ gather NablaComponents and DerivativeSources
 
@@ -220,7 +223,7 @@ class DerivativeBinder(IdentityMapper):
         # id to set((child index, axis), ...)
         nabla_finder = {}
 
-        for child_idx, rec_child in enumerate(rec_children):
+        for child_idx, rec_child in enumerate(children):
             nabla_component_ids = set()
             derivative_source_ids = set()
 
@@ -245,10 +248,10 @@ class DerivativeBinder(IdentityMapper):
         # }}}
 
         # a list of lists, the outer level presenting a sum, the inner a product
-        result = [rec_children]
+        result = [children]
 
         for child_idx, (d_source_nabla_ids, child) in enumerate(
-                zip(d_source_nabla_ids_per_child, rec_children)):
+                zip(d_source_nabla_ids_per_child, children)):
             if not d_source_nabla_ids:
                 continue
 
@@ -257,7 +260,11 @@ class DerivativeBinder(IdentityMapper):
                         "child in a product")
 
             nabla_id, = d_source_nabla_ids
-            nablas = nabla_finder[nabla_id]
+            try:
+                nablas = nabla_finder[nabla_id]
+            except KeyError:
+                continue
+
             n_axes = max(axis for _, axis in nablas) + 1
 
             new_result = []
@@ -278,15 +285,38 @@ class DerivativeBinder(IdentityMapper):
 
         from pymbolic.primitives import flattened_sum
         return flattened_sum(
-                    pprim.Product(tuple(
+                    type(expr)(tuple(
                         self.rec(prod_term) for prod_term in prod_term_list))
                     for prod_term_list in result)
 
-    def map_product(self, expr):
-        return self.do_bind(expr.children)
+    map_bitwise_xor = map_product
+    map_bitwise_or = map_product
+    map_left_shift = map_product
+    map_right_shift = map_product
 
     def map_derivative_source(self, expr):
-        return self.do_bind([expr])
+        rec_operand = self.rec(expr.operand)
+
+        nablas = []
+        for d_or_n in self.derivative_collector(rec_operand):
+            if isinstance(d_or_n, prim.NablaComponent):
+                nablas.append(d_or_n)
+            elif isinstance(d_or_n, prim.DerivativeSource):
+                pass
+            else:
+                raise RuntimeError("unexpected result from "
+                        "DerivativeSourceAndNablaComponentCollector")
+
+        n_axes = max(n.ambient_axis for n in nablas) + 1
+        assert n_axes
+
+        from pymbolic.primitives import flattened_sum
+        return flattened_sum(
+                self.take_derivative(
+                    axis,
+                    self.nabla_component_to_unit_vector(expr.nabla_id, axis)
+                    (rec_operand))
+                for axis in range(n_axes))
 
 # }}}
 
-- 
GitLab