diff --git a/src/mapper/__init__.py b/src/mapper/__init__.py index f2994bb4264a2d23539e9e8dc67921d149d3fd3b..d54d1dafd9c50b5a509196cde2e7fa14e7578740 100644 --- a/src/mapper/__init__.py +++ b/src/mapper/__init__.py @@ -124,9 +124,12 @@ class CombineMapper(RecursiveMapper): map_list = map_sum map_vector = map_sum - def map_numpy_array(self, expr): + def map_numpy_array(self, expr, *args): return self.combine(self.rec(el) for el in expr.flat) + def map_common_subexpression(self, expr, *args): + return self.rec(expr.child, *args) + class IdentityMapperBase(object): @@ -191,6 +194,10 @@ class IdentityMapperBase(object): result[i] = self.rec(expr[i]) return result + def map_common_subexpression(self, expr, *args, **kwargs): + return expr.__class__(self.rec(expr.child, *args, **kwargs)) + + class IdentityMapper(IdentityMapperBase, RecursiveMapper): diff --git a/src/mapper/c_code.py b/src/mapper/c_code.py new file mode 100644 index 0000000000000000000000000000000000000000..3dde4e9be85aedbd5bf58110d85032c1e1dfb5d4 --- /dev/null +++ b/src/mapper/c_code.py @@ -0,0 +1,48 @@ +from pymbolic.mapper.stringifier import SimplifyingSortingStringifyMapper + + + + +class CCodeMapper(SimplifyingSortingStringifyMapper): + def __init__(self, constant_mapper=repr, reverse=True, cse_prefix="cse"): + SimplifyingSortingStringifyMapper.__init__(self, constant_mapper, reverse) + self.cse_prefix = cse_prefix + self.cses = [] + self.cse_to_index = {} + + # mappings ---------------------------------------------------------------- + def map_power(self, expr, enclosing_prec): + from pymbolic.mapper.stringifier import PREC_NONE + from pymbolic.primitives import is_constant, is_zero + if is_constant(expr.exponent): + if is_zero(expr.exponent): + return "1" + elif is_zero(expr.exponent - 1): + return self.rec(expr.base, enclosing_prec) + elif is_zero(expr.exponent - 2): + return self.rec(expr.base*expr.base, enclosing_prec) + + return self.format("pow(%s, %s)", + self.rec(expr.base, PREC_NONE), + self.rec(expr.exponent, PREC_NONE)) + + def map_if_positive(self, expr, enclosing_prec): + # This occurs in hedge fluxes. We cheat and define it here. + from pymbolic.mapper.stringifier import PREC_NONE + return self.format("(%s > 0 ? %s : %s)", + self.rec(expr.criterion, PREC_NONE), + self.rec(expr.then, PREC_NONE), + self.rec(expr.else_, PREC_NONE), + ) + + def map_common_subexpression(self, expr, enclosing_prec): + try: + cse_index = self.cse_to_index[expr] + except KeyError: + from pymbolic.mapper.stringifier import PREC_NONE + my_cse_str = self.rec(expr.child, PREC_NONE) + cse_index = len(self.cses) + self.cse_to_index[expr] = cse_index + self.cses.append(my_cse_str) + + return self.cse_prefix + str(cse_index) diff --git a/src/mapper/dependency.py b/src/mapper/dependency.py index 8d1926891f023de99eda9b7a9574b4fd6a644bf2..389ddda36eebd8bc2dbab648dc1761e6815fafd9 100644 --- a/src/mapper/dependency.py +++ b/src/mapper/dependency.py @@ -8,6 +8,7 @@ class DependencyMapper(CombineMapper): include_subscripts=True, include_lookups=True, include_calls=True, + include_cses=False, composite_leaves=None): if composite_leaves == False: @@ -19,16 +20,22 @@ class DependencyMapper(CombineMapper): include_lookups = True include_calls = True - self.IncludeSubscripts = include_subscripts - self.IncludeLookups = include_lookups - self.IncludeCalls = include_calls + self.include_subscripts = include_subscripts + self.include_lookups = include_lookups + self.include_calls = include_calls + + self.include_cses = include_cses def combine(self, values): import operator return reduce(operator.or_, values, set()) - def handle_unsupported_expression(self, expr, *args, **kwargs): - return set([expr]) + def handle_unsupported_expression(self, expr): + from pymbolic.primitives import AlgebraicLeaf + if isinstance(expr, AlgebraicLeaf): + return set([expr]) + else: + CombineMapper.handle_unsupported_expression(self, expr) def map_constant(self, expr): return set() @@ -37,23 +44,29 @@ class DependencyMapper(CombineMapper): return set([expr]) def map_call(self, expr): - if self.IncludeCalls: + if self.include_calls: return set([expr]) else: return CombineMapper.map_call(self, expr) def map_lookup(self, expr): - if self.IncludeLookups: + if self.include_lookups: return set([expr]) else: return CombineMapper.map_lookup(self, expr) def map_subscript(self, expr): - if self.IncludeSubscripts: + if self.include_subscripts: return set([expr]) else: return CombineMapper.map_subscript(self, expr) + def map_common_subexpression(self, expr): + if self.include_cses: + return set([expr]) + else: + return CombineMapper.map_common_subexpression(self, expr) + diff --git a/src/mapper/evaluator.py b/src/mapper/evaluator.py index d5b35979add66da93147cabdd4aa354f91722002..98799c952c925e7b8eae0336135b39b50252bf7f 100644 --- a/src/mapper/evaluator.py +++ b/src/mapper/evaluator.py @@ -12,14 +12,15 @@ class UnknownVariableError(Exception): class EvaluationMapper(RecursiveMapper): def __init__(self, context={}): - self.Context = context + self.context = context + self.common_subexp_cache = {} def map_constant(self, expr): return expr def map_variable(self, expr): try: - return self.Context[expr.name] + return self.context[expr.name] except KeyError: raise UnknownVariableError, expr.name @@ -71,6 +72,14 @@ class EvaluationMapper(RecursiveMapper): result[i] = self.rec(expr[i]) return result + def map_common_subexpression(self, expr, out=None): + try: + return self.common_subexp_cache[expr.child] + except KeyError: + self.common_subexp_cache[expr.child] = value = self.rec(expr.child) + return value + + class FloatEvaluationMapper(EvaluationMapper): diff --git a/src/mapper/stringifier.py b/src/mapper/stringifier.py index 4005bc60972ed86812e2c7cae343836d3847fb44..0ae98c4d6944214b20c32e4d56d5626a89ab812f 100644 --- a/src/mapper/stringifier.py +++ b/src/mapper/stringifier.py @@ -16,6 +16,27 @@ class StringifyMapper(pymbolic.mapper.RecursiveMapper): def __init__(self, constant_mapper=str): self.constant_mapper = constant_mapper + # replaceable string composition interface -------------------------------- + def format(self, s, *args): + return s % args + + def join(self, joiner, iterable): + return self.format(joiner.join("%s" for i in iterable), *iterable) + + def join_rec(self, joiner, iterable, prec): + f = joiner.join("%s" for i in iterable) + return self.format(f, *[self.rec(i, prec) for i in iterable]) + + def parenthesize(self, s): + return "(%s)" % s + + def parenthesize_if_needed(self, s, enclosing_prec, my_prec): + if enclosing_prec > my_prec: + return "(%s)" % s + else: + return s + + # mappings ---------------------------------------------------------------- def handle_unsupported_expression(self, victim, enclosing_prec): strifier = victim.stringifier() if isinstance(self, strifier): @@ -31,7 +52,7 @@ class StringifyMapper(pymbolic.mapper.RecursiveMapper): or (isinstance(expr, complex) and expr.imag and expr.real) ) and (enclosing_prec > PREC_SUM): - return "(%s)" % result + return self.parenthesize(result) else: return result @@ -40,98 +61,64 @@ class StringifyMapper(pymbolic.mapper.RecursiveMapper): return expr.name def map_call(self, expr, enclosing_prec): - result = "%s(%s)" % \ - (self.rec(expr.function, PREC_CALL), - ", ".join(self.rec(i, PREC_NONE) - for i in expr.parameters)) - if enclosing_prec > PREC_CALL: - return "(%s)" % result - else: - return result + return self.format("%s(%s)", + self.rec(expr.function, PREC_CALL), + self.join_rec(", ", expr.parameters, PREC_NONE)) def map_subscript(self, expr, enclosing_prec): - result = "%s[%s]" % \ - (self.rec(expr.aggregate, PREC_CALL), self.rec(expr.index, PREC_CALL)) - if enclosing_prec > PREC_CALL: - return "(%s)" % result - else: - return result + return self.parenthesize_if_needed( + self.format("%s[%s]", + self.rec(expr.aggregate, PREC_CALL), + self.rec(expr.index, PREC_NONE)), + enclosing_prec, PREC_CALL) def map_lookup(self, expr, enclosing_prec): - result = "%s.%s" % (self.rec(expr.aggregate, PREC_CALL), expr.name) - if enclosing_prec > PREC_CALL: - return "(%s)" % result - else: - return result + return self.parenthesize_if_needed( + self.format("%s.%s", + self.rec(expr.aggregate, PREC_CALL), + expr.name), + enclosing_prec, PREC_CALL) def map_sum(self, expr, enclosing_prec): - result = " + ".join(self.rec(i, PREC_SUM) for i in expr.children) - if enclosing_prec > PREC_SUM: - return "(%s)" % result - else: - return result + return self.parenthesize_if_needed( + self.join_rec(" + ", expr.children, PREC_SUM), + enclosing_prec, PREC_SUM) def map_product(self, expr, enclosing_prec): - result = "*".join(self.rec(i, PREC_PRODUCT) for i in expr.children) - if enclosing_prec > PREC_PRODUCT: - return "(%s)" % result - else: - return result + return self.parenthesize_if_needed( + self.join_rec("*", expr.children, PREC_PRODUCT), + enclosing_prec, PREC_PRODUCT) def map_quotient(self, expr, enclosing_prec): - result = "%s/%s" % ( - self.rec(expr.numerator, PREC_PRODUCT), - self.rec(expr.denominator, PREC_PRODUCT) - ) - if enclosing_prec > PREC_PRODUCT: - return "(%s)" % result - else: - return result + return self.parenthesize_if_needed( + self.format("%s/%s", + self.rec(expr.numerator, PREC_PRODUCT), + self.rec(expr.denominator, PREC_POWER)), # analogous to ^{-1} + enclosing_prec, PREC_PRODUCT) def map_power(self, expr, enclosing_prec): - result = "%s**%s" % ( - self.rec(expr.base, PREC_POWER), - self.rec(expr.exponent, PREC_POWER) - ) - if enclosing_prec > PREC_POWER: - return "(%s)" % result - else: - return result + return self.parenthesize_if_needed( + self.format("%s**%s", + self.rec(expr.numerator, PREC_POWER), + self.rec(expr.denominator, PREC_POWER)), + enclosing_prec, PREC_POWER) def map_polynomial(self, expr, enclosing_prec): - sbase = self(expr.base, PREC_POWER) - def stringify_expcoeff((exp, coeff)): - if exp == 0: - return self(coeff, PREC_SUM) - elif exp == 1: - strexp = "" - else: - strexp = "**%s" % exp - - if not (coeff-1): - return "%s%s" % (sbase, strexp) - elif not (coeff+1): - return "-%s%s" % (sbase, strexp) - else: - return "%s*%s%s" % (self(coeff, PREC_PRODUCT), sbase, strexp) - - if not expr.data: - return "0" - - result = "%s" % " + ".join(stringify_expcoeff(i) for i in expr.data[::-1]) - if enclosing_prec > PREC_SUM and len(expr.data) > 1: - return "(%s)" % result - else: - return result + from pymbolic.primitives import flattened_sum + return self.rec(flattened_sum( + [coeff*base**exp for exp, coeff in expr.data[::-1]]), + enclosing_prec) def map_list(self, expr, enclosing_prec): - return "[%s]" % ", ".join([self.rec(i, PREC_NONE) for i in expr]) + return self.format("[%s]", self.join_rec(", ", expr, PREC_NONE)) map_vector = map_list def map_numpy_array(self, expr, enclosing_prec): - return 'array(%s)' % str(expr) + return self.format('array(%s)', str(expr)) + def map_common_subexpression(self, expr, enclosing_prec): + return self.format("CSE(%s)", self.rec(expr.child, PREC_NONE)) @@ -143,22 +130,16 @@ class SortingStringifyMapper(StringifyMapper): def map_sum(self, expr, enclosing_prec): entries = [self.rec(i, PREC_SUM) for i in expr.children] entries.sort(reverse=self.reverse) - result = " + ".join(entries) - - if enclosing_prec > PREC_SUM: - return "(%s)" % result - else: - return result + return self.parenthesize_if_needed( + self.join(" + ", entries), + enclosing_prec, PREC_SUM) def map_product(self, expr, enclosing_prec): entries = [self.rec(i, PREC_PRODUCT) for i in expr.children] entries.sort(reverse=self.reverse) - result = "*".join(entries) - - if enclosing_prec > PREC_PRODUCT: - return "(%s)" % result - else: - return result + return self.parenthesize_if_needed( + self.join("*", entries), + enclosing_prec, PREC_PRODUCT) @@ -195,14 +176,12 @@ class SimplifyingSortingStringifyMapper(StringifyMapper): positives.sort(reverse=self.reverse) positives = " + ".join(positives) negatives.sort(reverse=self.reverse) - negatives = "".join(" - " + entry for entry in negatives) + negatives = self.join("", + [self.format(" - %s", entry) for entry in negatives]) result = positives + negatives - if enclosing_prec > PREC_SUM: - return "(%s)" % result - else: - return result + return self.parenthesize_if_needed(result, enclosing_prec, PREC_SUM) def map_product(self, expr, enclosing_prec): entries = [] @@ -214,7 +193,8 @@ class SimplifyingSortingStringifyMapper(StringifyMapper): if False and is_zero(child+1) and i+1 < len(expr.children): # NOTE: That space needs to be there. # Otherwise two unary minus signs merge into a pre-decrement. - entries.append("- %s" % self.rec(expr.children[i+1], PREC_UNARY)) + entries.append( + self.format("- %s", self.rec(expr.children[i+1], PREC_UNARY))) i += 2 else: entries.append(self.rec(child, PREC_PRODUCT)) @@ -223,7 +203,4 @@ class SimplifyingSortingStringifyMapper(StringifyMapper): entries.sort(reverse=self.reverse) result = "*".join(entries) - if enclosing_prec > PREC_PRODUCT: - return "(%s)" % result - else: - return result + return self.parenthesize_if_needed(result, enclosing_prec, PREC_PRODUCT) diff --git a/src/primitives.py b/src/primitives.py index 1f61326f720ca937eafb7a859f4318af5f2c584e..14b066fabb321f0c4f3c8e5a3961995a35b34613 100644 --- a/src/primitives.py +++ b/src/primitives.py @@ -177,12 +177,17 @@ class Expression(object): class AlgebraicLeaf(Expression): + """An expression that serves as a leaf for arithmetic evaluation. + This may end up having child nodes still, but they're not reached by + ways of arithmetic.""" pass class Leaf(AlgebraicLeaf): + """An expression that is irreducible, i.e. has no Expression-type parts + whatsoever.""" pass @@ -529,6 +534,26 @@ class Vector(Expression): +class CommonSubexpression(Expression): + def __init__(self, child): + self.child = child + + def __getinitargs__(self): + return (self.child,) + + def get_hash(self): + return hash((self.__class__, self.child)) + + def is_equal(self, other): + return (other.__class__ == self.__class__ + and other.child == self.child) + + def get_mapper_method(self, mapper): + return mapper.map_common_subexpression + + + + # intelligent makers --------------------------------------------------------- def make_variable(var_or_string): if not isinstance(var_or_string, Expression):