From 1d3f4a982afc25852d3cc2c639305ce88ed762ad Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Sun, 16 Nov 2008 13:03:17 -0600 Subject: [PATCH] Add SimplifyingSortingStringifyMapper. --- src/mapper/stringifier.py | 63 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/src/mapper/stringifier.py b/src/mapper/stringifier.py index a80a26b..861f101 100644 --- a/src/mapper/stringifier.py +++ b/src/mapper/stringifier.py @@ -131,3 +131,66 @@ class StringifyMapper(pymbolic.mapper.RecursiveMapper): def map_numpy_array(self, expr, enclosing_prec): return 'array(%s)' % str(expr) + + + + +class SimplifyingSortingStringifyMapper(StringifyMapper): + def map_sum(self, expr, enclosing_prec): + entries = [self.rec(i, PREC_SUM) for i in expr.children] + + def get_neg_product(expr): + from pymbolic.primitives import is_zero, Product + + if isinstance(expr, Product) \ + and len(expr.children) and is_zero(expr.children[0]+1): + return Product(expr.children[1:]) + else: + return None + + positives = [] + negatives = [] + + for ch in expr.children: + neg_prod = get_neg_product(ch) + if neg_prod is not None: + negatives.append(self.rec(neg_prod, PREC_SUM)) + else: + positives.append(self.rec(ch, PREC_SUM)) + + positives.sort() + positives = " + ".join(positives[::-1]) + negatives.sort() + negatives = "".join(" - " + entry for entry in negatives[::-1]) + + result = positives + negatives + + if enclosing_prec > PREC_SUM: + return "(%s)" % result + else: + return result + + def map_product(self, expr, enclosing_prec): + def generate_entries(): + i = 0 + from pymbolic.primitives import is_zero + + while i < len(expr.children): + child = expr.children[i] + if is_zero(child+1) and i+1 < len(expr.children): + # NOTE: That space needs to be there. + # Otherwise two unary minus signs merge into a pre-decrement. + yield "- %s" % self.rec(expr.children[i+1], PREC_UNARY) + i += 2 + else: + yield self.rec(child, PREC_PRODUCT) + i += 1 + + entries = list(generate_entries()) + entries.sort() + #entries = entries[::-1] + + if enclosing_prec > PREC_PRODUCT: + return "(%s)" % "*".join(entries) + else: + return "*".join(entries) -- GitLab