From f6f8365f1f55b832a08b65574f48785793361ddc Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <kloeckner@teramite.rice.edu> Date: Wed, 3 Sep 2008 22:28:56 -0500 Subject: [PATCH] Improve pymbolic speed. --- src/compiler.py | 1 - src/mapper/__init__.py | 165 +++++++++++++++++++++-------------------- src/parser.py | 3 +- src/primitives.py | 45 +++++------ 4 files changed, 103 insertions(+), 111 deletions(-) diff --git a/src/compiler.py b/src/compiler.py index 0d58668..90974f7 100644 --- a/src/compiler.py +++ b/src/compiler.py @@ -1,7 +1,6 @@ import math import pymbolic -import pymbolic.mapper.dependency from pymbolic.mapper.stringifier import StringifyMapper, PREC_NONE, PREC_SUM, PREC_POWER diff --git a/src/mapper/__init__.py b/src/mapper/__init__.py index c8bcc75..28542fd 100644 --- a/src/mapper/__init__.py +++ b/src/mapper/__init__.py @@ -1,3 +1,8 @@ +import pymbolic.primitives as primitives + + + + try: import numpy @@ -14,46 +19,43 @@ class Mapper(object): def __init__(self, recurse=True): self.Recurse = True - def handle_unsupported_expression(self, expr, *args, **kwargs): + def handle_unsupported_expression(self, expr, *args): raise ValueError, "%s cannot handle expressions of type %s" % ( self.__class__, expr.__class__) - def __call__(self, expr, *args, **kwargs): - import pymbolic.primitives as primitives - if isinstance(expr, primitives.Expression): - try: - method = expr.get_mapper_method(self) - except AttributeError: - return self.handle_unsupported_expression(expr, *args, **kwargs) + def __call__(self, expr, *args): + try: + method = expr.get_mapper_method(self) + except AttributeError: + if isinstance(expr, primitives.Expression): + return self.handle_unsupported_expression(expr, *args) else: - return method(expr, *args, **kwargs) + return self.map_foreign(expr, *args) else: - return self.map_foreign(expr, *args, **kwargs) + return method(expr, *args) - def map_variable(self, expr, *args, **kwargs): - return self.map_algebraic_leaf(expr, *args, **kwargs) + def map_variable(self, expr, *args): + return self.map_algebraic_leaf(expr, *args) - def map_subscript(self, expr, *args, **kwargs): - return self.map_algebraic_leaf(expr, *args, **kwargs) + def map_subscript(self, expr, *args): + return self.map_algebraic_leaf(expr, *args) - def map_call(self, expr, *args, **kwargs): - return self.map_algebraic_leaf(expr, *args, **kwargs) + def map_call(self, expr, *args): + return self.map_algebraic_leaf(expr, *args) - def map_lookup(self, expr, *args, **kwargs): - return self.map_algebraic_leaf(expr, *args, **kwargs) + def map_lookup(self, expr, *args): + return self.map_algebraic_leaf(expr, *args) - def map_rational(self, expr, *args, **kwargs): - return self.map_quotient(expr, *args, **kwargs) + def map_rational(self, expr, *args): + return self.map_quotient(expr, *args) - def map_foreign(self, expr, *args, **kwargs): - from pymbolic.primitives import is_constant - - if is_constant(expr): - return self.map_constant(expr, *args, **kwargs) + def map_foreign(self, expr, *args): + if isinstance(expr, primitives.VALID_CONSTANT_CLASSES): + return self.map_constant(expr, *args) elif isinstance(expr, list): - return self.map_list(expr, *args, **kwargs) + return self.map_list(expr, *args) elif is_numpy_array(expr): - return self.map_numpy_array(expr, *args, **kwargs) + return self.map_numpy_array(expr, *args) else: raise ValueError, "%s encountered invalid foreign object: %s" % ( self.__class__, repr(expr)) @@ -63,61 +65,60 @@ class Mapper(object): class RecursiveMapper(Mapper): - def rec(self, expr, *args, **kwargs): - import pymbolic.primitives as primitives - if isinstance(expr, primitives.Expression): - try: - method = expr.get_mapper_method(self) - except AttributeError: - return self.handle_unsupported_expression(expr, *args, **kwargs) + def rec(self, expr, *args): + try: + method = expr.get_mapper_method(self) + except AttributeError: + if isinstance(expr, primitives.Expression): + return self.handle_unsupported_expression(expr, *args) else: - return method(expr, *args, **kwargs) + return self.map_foreign(expr, *args) else: - return self.map_foreign(expr, *args, **kwargs) + return method(expr, *args) class CombineMapper(RecursiveMapper): - def map_call(self, expr, *args, **kwargs): + def map_call(self, expr, *args): return self.combine( - (self.rec(expr.function, *args, **kwargs),) + + (self.rec(expr.function, *args),) + tuple( - self.rec(child, *args, **kwargs) for child in expr.parameters) + self.rec(child, *args) for child in expr.parameters) ) - def map_subscript(self, expr, *args, **kwargs): + def map_subscript(self, expr, *args): return self.combine( - [self.rec(expr.aggregate, *args, **kwargs), - self.rec(expr.index, *args, **kwargs)]) + [self.rec(expr.aggregate, *args), + self.rec(expr.index, *args)]) - def map_lookup(self, expr, *args, **kwargs): - return self.rec(expr.aggregate, *args, **kwargs) + def map_lookup(self, expr, *args): + return self.rec(expr.aggregate, *args) - def map_negation(self, expr, *args, **kwargs): - return self.rec(expr.child, *args, **kwargs) + def map_negation(self, expr, *args): + return self.rec(expr.child, *args) - def map_sum(self, expr, *args, **kwargs): - return self.combine(self.rec(child, *args, **kwargs) + def map_sum(self, expr, *args): + return self.combine(self.rec(child, *args) for child in expr.children) map_product = map_sum - def map_quotient(self, expr, *args, **kwargs): + def map_quotient(self, expr, *args): return self.combine(( - self.rec(expr.numerator, *args, **kwargs), - self.rec(expr.denominator, *args, **kwargs))) + self.rec(expr.numerator, *args), + self.rec(expr.denominator, *args))) - def map_power(self, expr, *args, **kwargs): + def map_power(self, expr, *args): return self.combine(( - self.rec(expr.base, *args, **kwargs), - self.rec(expr.exponent, *args, **kwargs))) + self.rec(expr.base, *args), + self.rec(expr.exponent, *args))) - def map_polynomial(self, expr, *args, **kwargs): + def map_polynomial(self, expr, *args): return self.combine( - (self.rec(expr.base, *args, **kwargs),) + + (self.rec(expr.base, *args),) + tuple( - self.rec(coeff, *args, **kwargs) for exp, coeff in expr.data) + self.rec(coeff, *args) for exp, coeff in expr.data) ) map_list = map_sum @@ -129,54 +130,54 @@ class CombineMapper(RecursiveMapper): class IdentityMapperBase(object): - def map_constant(self, expr, *args, **kwargs): + def map_constant(self, expr, *args): # leaf -- no need to rebuild return expr - def map_variable(self, expr, *args, **kwargs): + def map_variable(self, expr, *args): # leaf -- no need to rebuild return expr - def map_call(self, expr, *args, **kwargs): + def map_call(self, expr, *args): return expr.__class__( - self.rec(expr.function, *args, **kwargs), - tuple(self.rec(child, *args, **kwargs) + self.rec(expr.function, *args), + tuple(self.rec(child, *args) for child in expr.parameters)) - def map_subscript(self, expr, *args, **kwargs): + def map_subscript(self, expr, *args): return expr.__class__( - self.rec(expr.aggregate, *args, **kwargs), - self.rec(expr.index, *args, **kwargs)) + self.rec(expr.aggregate, *args), + self.rec(expr.index, *args)) - def map_lookup(self, expr, *args, **kwargs): + def map_lookup(self, expr, *args): return expr.__class__( - self.rec(expr.aggregate, *args, **kwargs), + self.rec(expr.aggregate, *args), expr.name) - def map_negation(self, expr, *args, **kwargs): - return expr.__class__(self.rec(expr.child, *args, **kwargs)) + def map_negation(self, expr, *args): + return expr.__class__(self.rec(expr.child, *args)) - def map_sum(self, expr, *args, **kwargs): + def map_sum(self, expr, *args): from pymbolic.primitives import flattened_sum return flattened_sum(tuple( - self.rec(child, *args, **kwargs) for child in expr.children)) + self.rec(child, *args) for child in expr.children)) - def map_product(self, expr, *args, **kwargs): + def map_product(self, expr, *args): from pymbolic.primitives import flattened_product return flattened_product(tuple( - self.rec(child, *args, **kwargs) for child in expr.children)) + self.rec(child, *args) for child in expr.children)) - def map_quotient(self, expr, *args, **kwargs): - return expr.__class__(self.rec(expr.numerator, *args, **kwargs), - self.rec(expr.denominator, *args, **kwargs)) + def map_quotient(self, expr, *args): + return expr.__class__(self.rec(expr.numerator, *args), + self.rec(expr.denominator, *args)) - def map_power(self, expr, *args, **kwargs): - return expr.__class__(self.rec(expr.base, *args, **kwargs), - self.rec(expr.exponent, *args, **kwargs)) + def map_power(self, expr, *args): + return expr.__class__(self.rec(expr.base, *args), + self.rec(expr.exponent, *args)) - def map_polynomial(self, expr, *args, **kwargs): - return expr.__class__(self.rec(expr.base, *args, **kwargs), - ((exp, self.rec(coeff, *args, **kwargs)) + def map_polynomial(self, expr, *args): + return expr.__class__(self.rec(expr.base, *args), + ((exp, self.rec(coeff, *args)) for exp, coeff in expr.data)) map_list = map_sum diff --git a/src/parser.py b/src/parser.py index 8507504..0e5cf6b 100644 --- a/src/parser.py +++ b/src/parser.py @@ -1,4 +1,3 @@ -import pymbolic.primitives as primitives import pytools.lex _imaginary = intern("imaginary") @@ -48,6 +47,8 @@ _PREC_UNARY_MINUS = 40 _PREC_CALL = 50 def parse(expr_str): + import pymbolic.primitives as primitives + def parse_terminal(pstate): next_tag = pstate.next_tag() if next_tag is _int: diff --git a/src/primitives.py b/src/primitives.py index 32db774..4740642 100644 --- a/src/primitives.py +++ b/src/primitives.py @@ -1,5 +1,4 @@ import traits -import pymbolic.mapper.stringifier @@ -156,8 +155,7 @@ class Constant(Leaf): return self.value, def __hash__(self): - from pytools import hash_combine - return hash_combine(self.__class__, self.value) + return hash((self.__class__, self.value)) def get_mapper_method(self, mapper): return mapper.map_constant @@ -183,8 +181,7 @@ class Variable(Leaf): and self.name == other.name) def __hash__(self): - from pytools import hash_combine - return hash_combine(self.__class__, self.name) + return hash((self.__class__, self.name)) def get_mapper_method(self, mapper): return mapper.map_variable @@ -203,8 +200,7 @@ class Call(AlgebraicLeaf): and (self.parameters == other.parameters) def __hash__(self): - from pytools import hash_combine - return hash_combine(self.__class__, self.function, self.parameters) + return hash((self.__class__, self.function, self.parameters)) def get_mapper_method(self, mapper): return mapper.map_call @@ -226,8 +222,7 @@ class Subscript(AlgebraicLeaf): and (self.index == other.index) def __hash__(self): - from pytools import hash_combine - return hash_combine(self.__class__, self.aggregate, self.index) + return hash((self.__class__, self.aggregate, self.index)) def get_mapper_method(self, mapper): return mapper.map_subscript @@ -249,8 +244,7 @@ class Lookup(AlgebraicLeaf): and (self.name == other.name) def __hash__(self): - from pytools import hash_combine - return hash_combine(self.__class__, self.aggregate, self.name) + return hash((self.__class__, self.aggregate, self.name)) def get_mapper_method(self, mapper): return mapper.map_lookup @@ -309,8 +303,7 @@ class Sum(Expression): return True def __hash__(self): - from pytools import hash_combine - return hash_combine(self.__class__, self.children) + return hash((self.__class__, self.children)) def get_mapper_method(self, mapper): return mapper.map_sum @@ -359,8 +352,7 @@ class Product(Expression): return True def __hash__(self): - from pytools import hash_combine - return hash_combine(self.__class__, self.children) + return hash((self.__class__, self.children)) def get_mapper_method(self, mapper): return mapper.map_product @@ -394,8 +386,7 @@ class Quotient(Expression): return bool(self.numerator) def __hash__(self): - from pytools import hash_combine - return hash_combine(self.__class__, self.numerator, self.denominator) + return hash((self.__class__, self.numerator, self.denominator)) def get_mapper_method(self, mapper): return mapper.map_quotient @@ -417,8 +408,7 @@ class Power(Expression): and (self.exponent == other.exponent) def __hash__(self): - from pytools import hash_combine - return hash_combine(self.__class__, self.base, self.exponent) + return hash((self.__class__, self.base, self.exponent)) def get_mapper_method(self, mapper): return mapper.map_power @@ -490,8 +480,7 @@ class Vector(Expression): return self.children def __hash__(self): - from pytools import hash_combine - return hash_combine(self.__class__, self.children) + return hash((self.__class__, self.children)) def get_mapper_method(self, mapper): return mapper.map_vector @@ -595,25 +584,27 @@ def quotient(numerator, denominator): # tool functions -------------------------------------------------------------- -VALID_CONSTANT_CLASSES = [int, float, complex] -VALID_OPERANDS = [Expression] +VALID_CONSTANT_CLASSES = (int, float, complex) +VALID_OPERANDS = (Expression,) def is_constant(value): - return isinstance(value, tuple(VALID_CONSTANT_CLASSES)) + return isinstance(value, VALID_CONSTANT_CLASSES) def is_valid_operand(value): - return isinstance(value, tuple(VALID_OPERANDS)) or is_constant(value) + return isinstance(value, VALID_OPERANDS) or is_constant(value) def register_constant_class(class_): - VALID_CONSTANT_CLASSES.append(class_) + VALID_CONSTANT_CLASSES += (class_,) def unregister_constant_class(class_): - VALID_CONSTANT_CLASSES.remove(class_) + tmp = list(VALID_CONSTANT_CLASSES) + tmp.remove(class_) + VALID_CONSTANT_CLASSES = tuple(tmp) -- GitLab