From 6a9b843a329956e7c8cd03db50f1fd71efb9ba81 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Fri, 8 Jul 2011 00:53:48 -0400 Subject: [PATCH] Faster mapper method selection, plus add FloorDiv class. --- pymbolic/mapper/__init__.py | 34 ++++++++++++++---------- pymbolic/mapper/stringifier.py | 39 ++++++++++++++++------------ pymbolic/polynomial.py | 3 +-- pymbolic/primitives.py | 47 +++++++++++++++------------------- pymbolic/rational.py | 3 +-- 5 files changed, 65 insertions(+), 61 deletions(-) diff --git a/pymbolic/mapper/__init__.py b/pymbolic/mapper/__init__.py index ffde8f0..2fba9ab 100644 --- a/pymbolic/mapper/__init__.py +++ b/pymbolic/mapper/__init__.py @@ -28,14 +28,17 @@ class Mapper(object): def __call__(self, expr, *args): try: - method = expr.get_mapper_method(self) + method = getattr(self, expr.mapper_method) except AttributeError: - if isinstance(expr, primitives.Expression): - return self.handle_unsupported_expression(expr, *args) - else: - return self.map_foreign(expr, *args) - else: - return method(expr, *args) + try: + method = expr.get_mapper_method(self) + except AttributeError: + if isinstance(expr, primitives.Expression): + return self.handle_unsupported_expression(expr, *args) + else: + return self.map_foreign(expr, *args) + + return method(expr, *args) def map_variable(self, expr, *args): return self.map_algebraic_leaf(expr, *args) @@ -73,14 +76,17 @@ class Mapper(object): class RecursiveMapper(Mapper): def rec(self, expr, *args, **kwargs): try: - method = expr.get_mapper_method(self) + method = getattr(self, expr.mapper_method) except AttributeError: - if isinstance(expr, primitives.Expression): - return self.handle_unsupported_expression(expr, *args, **kwargs) - else: - return self.map_foreign(expr, *args, **kwargs) - else: - return method(expr, *args) + try: + method = expr.get_mapper_method(self) + except AttributeError: + if isinstance(expr, primitives.Expression): + return self.handle_unsupported_expression(expr, *args, **kwargs) + else: + return self.map_foreign(expr, *args, **kwargs) + + return method(expr, *args) diff --git a/pymbolic/mapper/stringifier.py b/pymbolic/mapper/stringifier.py index 17f5aae..68c57b4 100644 --- a/pymbolic/mapper/stringifier.py +++ b/pymbolic/mapper/stringifier.py @@ -40,7 +40,7 @@ class StringifyMapper(pymbolic.mapper.RecursiveMapper): def handle_unsupported_expression(self, victim, enclosing_prec): strifier = victim.stringifier() if isinstance(self, strifier): - raise ValueError("stringifier '%s' can't handle '%s'" + raise ValueError("stringifier '%s' can't handle '%s'" % (self, victim.__class__)) return strifier(self.constant_mapper)(victim, enclosing_prec) @@ -68,15 +68,15 @@ class StringifyMapper(pymbolic.mapper.RecursiveMapper): def map_subscript(self, expr, enclosing_prec): return self.parenthesize_if_needed( - self.format("%s[%s]", - self.rec(expr.aggregate, PREC_CALL), + self.format("%s[%s]", + self.rec(expr.aggregate, PREC_CALL), self.rec(expr.index, PREC_NONE)), enclosing_prec, PREC_CALL) def map_lookup(self, expr, enclosing_prec): return self.parenthesize_if_needed( - self.format("%s.%s", - self.rec(expr.aggregate, PREC_CALL), + self.format("%s.%s", + self.rec(expr.aggregate, PREC_CALL), expr.name), enclosing_prec, PREC_CALL) @@ -92,22 +92,29 @@ class StringifyMapper(pymbolic.mapper.RecursiveMapper): def map_quotient(self, expr, enclosing_prec): return self.parenthesize_if_needed( - self.format("%s/%s", - self.rec(expr.numerator, PREC_PRODUCT), + self.format("%s/%s", + self.rec(expr.numerator, PREC_PRODUCT), + self.rec(expr.denominator, PREC_POWER)), # analogous to ^{-1} + enclosing_prec, PREC_PRODUCT) + + def map_floor_div(self, expr, enclosing_prec): + return self.parenthesize_if_needed( + self.format("%s//%s", + self.rec(expr.numerator, PREC_PRODUCT), self.rec(expr.denominator, PREC_POWER)), # analogous to ^{-1} enclosing_prec, PREC_PRODUCT) def map_power(self, expr, enclosing_prec): return self.parenthesize_if_needed( - self.format("%s**%s", - self.rec(expr.base, PREC_POWER), + self.format("%s**%s", + self.rec(expr.base, PREC_POWER), self.rec(expr.exponent, PREC_POWER)), enclosing_prec, PREC_POWER) def map_remainder(self, expr, enclosing_prec): return self.parenthesize_if_needed( - self.format("%s %% %s", - self.rec(expr.numerator, PREC_PRODUCT), + self.format("%s %% %s", + self.rec(expr.numerator, PREC_PRODUCT), self.rec(expr.denominator, PREC_POWER)), # analogous to ^{-1} enclosing_prec, PREC_PRODUCT) @@ -150,7 +157,7 @@ class StringifyMapper(pymbolic.mapper.RecursiveMapper): def map_if_positive(self, expr, enclosing_prec): return "If(%s > 0, %s, %s)" % ( - self.rec(expr.criterion, PREC_NONE), + self.rec(expr.criterion, PREC_NONE), self.rec(expr.then, PREC_NONE), self.rec(expr.else_, PREC_NONE)) @@ -171,7 +178,7 @@ class CSESplittingStringifyMapperMixin(object): cse_name = self.cse_to_name[expr.child] except KeyError: str_child = self.rec(expr.child, PREC_NONE) - + if expr.prefix is not None: def generate_cse_names(): yield expr.prefix @@ -198,7 +205,7 @@ class CSESplittingStringifyMapperMixin(object): def get_cse_strings(self): return [ "%s : %s" % (cse_name, cse_str) - for cse_name, cse_str in + for cse_name, cse_str in sorted(getattr(self, "cse_name_list", []))] @@ -258,7 +265,7 @@ class SimplifyingSortingStringifyMapper(StringifyMapper): positives.sort(reverse=self.reverse) positives = " + ".join(positives) negatives.sort(reverse=self.reverse) - negatives = self.join("", + negatives = self.join("", [self.format(" - %s", entry) for entry in negatives]) result = positives + negatives @@ -273,7 +280,7 @@ class SimplifyingSortingStringifyMapper(StringifyMapper): while i < len(expr.children): child = expr.children[i] if False and is_zero(child+1) and i+1 < len(expr.children): - # NOTE: That space needs to be there. + # NOTE: That space needs to be there. # Otherwise two unary minus signs merge into a pre-decrement. entries.append( self.format("- %s", self.rec(expr.children[i+1], PREC_UNARY))) diff --git a/pymbolic/polynomial.py b/pymbolic/polynomial.py index 632ef1f..30a67d3 100644 --- a/pymbolic/polynomial.py +++ b/pymbolic/polynomial.py @@ -254,8 +254,7 @@ class Polynomial(Expression): def __getinitargs__(self): return (self.Base, self.Data, self.Unit, self.VarLess) - def get_mapper_method(self, mapper): - return mapper.map_polynomial + mapper_method = intern("map_polynomial") def as_primitives(self): deps = pymbolic.get_dependencies(self) diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py index 401b0a5..b78ead5 100644 --- a/pymbolic/primitives.py +++ b/pymbolic/primitives.py @@ -238,8 +238,7 @@ class Variable(Leaf): def get_hash(self): return hash((self.__class__, self.name)) - def get_mapper_method(self, mapper): - return mapper.map_variable + mapper_method = intern("map_variable") @@ -260,8 +259,7 @@ class FunctionSymbol(AlgebraicLeaf): def get_hash(self): return hash(self.__class__) - def get_mapper_method(self, mapper): - return mapper.map_function_symbol + mapper_method = intern("map_function_symbol") @@ -293,8 +291,7 @@ class Call(AlgebraicLeaf): def get_hash(self): return hash((self.__class__, self.function, self.parameters)) - def get_mapper_method(self, mapper): - return mapper.map_call + mapper_method = intern("map_call") @@ -315,9 +312,8 @@ class Subscript(AlgebraicLeaf): def get_hash(self): return hash((self.__class__, self.aggregate, self.index)) - def get_mapper_method(self, mapper): - return mapper.map_subscript - + mapper_method = intern("map_subscript") + @@ -337,8 +333,7 @@ class Lookup(AlgebraicLeaf): def get_hash(self): return hash((self.__class__, self.aggregate, self.name)) - def get_mapper_method(self, mapper): - return mapper.map_lookup + mapper_method = intern("map_lookup") @@ -396,8 +391,7 @@ class Sum(Expression): def get_hash(self): return hash((self.__class__, self.children)) - def get_mapper_method(self, mapper): - return mapper.map_sum + mapper_method = intern("map_sum") @@ -445,8 +439,7 @@ class Product(Expression): def get_hash(self): return hash((self.__class__, self.children)) - def get_mapper_method(self, mapper): - return mapper.map_product + mapper_method = intern("map_product") @@ -483,8 +476,13 @@ class Quotient(QuotientBase): and (self.numerator == other.numerator) \ and (self.denominator == other.denominator) - def get_mapper_method(self, mapper): - return mapper.map_quotient + mapper_method = intern("map_quotient") + + + + +class FloorDiv(QuotientBase): + mapper_method = intern("map_floor_div") @@ -496,8 +494,7 @@ class Remainder(QuotientBase): and (self.numerator == other.numerator) \ and (self.denominator == other.denominator) - def get_mapper_method(self, mapper): - return mapper.map_remainder + mapper_method = intern("map_remainder") @@ -518,8 +515,7 @@ class Power(Expression): def get_hash(self): return hash((self.__class__, self.base, self.exponent)) - def get_mapper_method(self, mapper): - return mapper.map_power + mapper_method = intern("map_power") @@ -590,8 +586,7 @@ class Vector(Expression): def get_hash(self): return hash((self.__class__, self.children)) - def get_mapper_method(self, mapper): - return mapper.map_vector + mapper_method = intern("map_vector") @@ -614,8 +609,7 @@ class CommonSubexpression(Expression): def get_extra_properties(self): return {} - def get_mapper_method(self, mapper): - return mapper.map_common_subexpression + mapper_method = intern("map_common_subexpression") @@ -643,8 +637,7 @@ class IfPositive(Expression): self.then, self.else_)) - def get_mapper_method(self, mapper): - return mapper.map_if_positive + mapper_method = intern("map_if_positive") diff --git a/pymbolic/rational.py b/pymbolic/rational.py index 6252a96..444b21a 100644 --- a/pymbolic/rational.py +++ b/pymbolic/rational.py @@ -104,8 +104,7 @@ class Rational(primitives.Expression): def reciprocal(self): return Rational(self.Denominator, self.Numerator) - def get_mapper_method(self, mapper): - return mapper.map_rational + mapper_method = intern("map_rational") -- GitLab