diff --git a/pymbolic/geometric_algebra/mapper.py b/pymbolic/geometric_algebra/mapper.py index 0da2cf9e1fe9750340600bfbeb9aee1376cee488..8129757ca41c68300960b7f61f1045e19d3a7003 100644 --- a/pymbolic/geometric_algebra/mapper.py +++ b/pymbolic/geometric_algebra/mapper.py @@ -30,13 +30,13 @@ THE SOFTWARE. # Consider yourself warned. from pymbolic.geometric_algebra import MultiVector -import pymbolic.primitives as pprim import pymbolic.geometric_algebra.primitives as prim from pymbolic.mapper import ( CombineMapper as CombineMapperBase, Collector as CollectorBase, IdentityMapper as IdentityMapperBase, - WalkMapper as WalkMapperBase + WalkMapper as WalkMapperBase, + CachingMapperMixin ) from pymbolic.mapper.stringifier import ( StringifyMapper as StringifyMapperBase, @@ -154,7 +154,7 @@ class Dimensionalizer(EvaluationMapper): # {{{ derivative binder -class DerivativeSourceAndNablaComponentCollector(Collector): +class DerivativeSourceAndNablaComponentCollector(CachingMapperMixin, Collector): def map_nabla(self, expr): raise RuntimeError("DerivativeOccurrenceMapper must be invoked after " "Dimensionalizer--Nabla found, not allowed") @@ -163,7 +163,7 @@ class DerivativeSourceAndNablaComponentCollector(Collector): return set([expr]) def map_derivative_source(self, expr): - return set([expr]) + return set([expr]) | self.rec(expr.operand) class NablaComponentToUnitVector(EvaluationMapper): @@ -171,6 +171,9 @@ class NablaComponentToUnitVector(EvaluationMapper): self.nabla_id = nabla_id self.ambient_axis = ambient_axis + def map_variable(self, expr): + return expr + def map_nabla_component(self, expr): if expr.nabla_id == self.nabla_id: if expr.ambient_axis == self.ambient_axis: @@ -209,9 +212,9 @@ class DerivativeBinder(IdentityMapper): self.derivative_collector = \ self.derivative_source_and_nabla_component_collector() - def do_bind(self, rec_children): + def map_product(self, expr): # We may write to this below. Make a copy. - rec_children = list(rec_children) + children = list(expr.children) # {{{ gather NablaComponents and DerivativeSources @@ -220,7 +223,7 @@ class DerivativeBinder(IdentityMapper): # id to set((child index, axis), ...) nabla_finder = {} - for child_idx, rec_child in enumerate(rec_children): + for child_idx, rec_child in enumerate(children): nabla_component_ids = set() derivative_source_ids = set() @@ -245,10 +248,10 @@ class DerivativeBinder(IdentityMapper): # }}} # a list of lists, the outer level presenting a sum, the inner a product - result = [rec_children] + result = [children] for child_idx, (d_source_nabla_ids, child) in enumerate( - zip(d_source_nabla_ids_per_child, rec_children)): + zip(d_source_nabla_ids_per_child, children)): if not d_source_nabla_ids: continue @@ -257,7 +260,11 @@ class DerivativeBinder(IdentityMapper): "child in a product") nabla_id, = d_source_nabla_ids - nablas = nabla_finder[nabla_id] + try: + nablas = nabla_finder[nabla_id] + except KeyError: + continue + n_axes = max(axis for _, axis in nablas) + 1 new_result = [] @@ -278,15 +285,38 @@ class DerivativeBinder(IdentityMapper): from pymbolic.primitives import flattened_sum return flattened_sum( - pprim.Product(tuple( + type(expr)(tuple( self.rec(prod_term) for prod_term in prod_term_list)) for prod_term_list in result) - def map_product(self, expr): - return self.do_bind(expr.children) + map_bitwise_xor = map_product + map_bitwise_or = map_product + map_left_shift = map_product + map_right_shift = map_product def map_derivative_source(self, expr): - return self.do_bind([expr]) + rec_operand = self.rec(expr.operand) + + nablas = [] + for d_or_n in self.derivative_collector(rec_operand): + if isinstance(d_or_n, prim.NablaComponent): + nablas.append(d_or_n) + elif isinstance(d_or_n, prim.DerivativeSource): + pass + else: + raise RuntimeError("unexpected result from " + "DerivativeSourceAndNablaComponentCollector") + + n_axes = max(n.ambient_axis for n in nablas) + 1 + assert n_axes + + from pymbolic.primitives import flattened_sum + return flattened_sum( + self.take_derivative( + axis, + self.nabla_component_to_unit_vector(expr.nabla_id, axis) + (rec_operand)) + for axis in range(n_axes)) # }}}