diff --git a/pymbolic/geometric_algebra/mapper.py b/pymbolic/geometric_algebra/mapper.py index a3824bcc2e4a39a3f10d6857288509b30b743613..a33fce06cbf7f4bab6db411032b7cff1e29608df 100644 --- a/pymbolic/geometric_algebra/mapper.py +++ b/pymbolic/geometric_algebra/mapper.py @@ -26,6 +26,7 @@ THE SOFTWARE. # Consider yourself warned. from pymbolic.geometric_algebra import MultiVector +import pymbolic.primitives as pprim import pymbolic.geometric_algebra.primitives as prim from pymbolic.mapper import ( CombineMapper as CombineMapperBase, @@ -179,27 +180,32 @@ class DerivativeSourceFinder(EvaluationMapper): class DerivativeBinder(IdentityMapper): + derivative_source_and_nabla_component_collector = \ + DerivativeSourceAndNablaComponentCollector + nabla_component_to_unit_vector = NablaComponentToUnitVector + derivative_source_finder = DerivativeSourceFinder + def __init__(self): - self.derivative_collector = DerivativeSourceAndNablaComponentCollector() + self.derivative_collector = \ + self.derivative_source_and_nabla_component_collector() + + def do_bind(self, rec_children): + # We may write to this below. Make a copy. + rec_children = list(rec_children) - def map_product(self, expr): # {{{ gather NablaComponents and DerivativeSources - rec_children = [] d_source_nabla_ids_per_child = [] # id to set((child index, axis), ...) nabla_finder = {} - for child_idx, child in enumerate(expr.children): - rec_expr = self.rec(child) - rec_children.append(rec_expr) - + for child_idx, rec_child in enumerate(rec_children): nabla_component_ids = set() derivative_source_ids = set() nablas = [] - for d_or_n in self.derivative_collector(rec_expr): + for d_or_n in self.derivative_collector(rec_child): if isinstance(d_or_n, prim.NablaComponent): nabla_component_ids.add(d_or_n.nabla_id) nablas.append(d_or_n) @@ -238,12 +244,12 @@ class DerivativeBinder(IdentityMapper): for prod_term_list in result: for axis in xrange(n_axes): new_ptl = prod_term_list[:] - dsfinder = DerivativeSourceFinder(nabla_id, self, axis) + dsfinder = self.derivative_source_finder(nabla_id, self, axis) new_ptl[child_idx] = dsfinder(new_ptl[child_idx]) for nabla_child_index, _ in nablas: new_ptl[nabla_child_index] = \ - NablaComponentToUnitVector(nabla_id, axis)( + self.nabla_component_to_unit_vector(nabla_id, axis)( new_ptl[nabla_child_index]) new_result.append(new_ptl) @@ -252,7 +258,15 @@ class DerivativeBinder(IdentityMapper): from pymbolic.primitives import flattened_sum return flattened_sum( - type(expr)(tuple(prod_term_list)) for prod_term_list in result) + pprim.Product(tuple( + self.rec(prod_term) for prod_term in prod_term_list)) + for prod_term_list in result) + + def map_product(self, expr): + return self.do_bind(expr.children) + + def map_derivative_source(self, expr): + return self.do_bind([expr]) # }}}