From f6f8365f1f55b832a08b65574f48785793361ddc Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <kloeckner@teramite.rice.edu>
Date: Wed, 3 Sep 2008 22:28:56 -0500
Subject: [PATCH] Improve pymbolic speed.

---
 src/compiler.py        |   1 -
 src/mapper/__init__.py | 165 +++++++++++++++++++++--------------------
 src/parser.py          |   3 +-
 src/primitives.py      |  45 +++++------
 4 files changed, 103 insertions(+), 111 deletions(-)

diff --git a/src/compiler.py b/src/compiler.py
index 0d58668..90974f7 100644
--- a/src/compiler.py
+++ b/src/compiler.py
@@ -1,7 +1,6 @@
 import math
 
 import pymbolic
-import pymbolic.mapper.dependency
 from pymbolic.mapper.stringifier import StringifyMapper, PREC_NONE, PREC_SUM, PREC_POWER
 
 
diff --git a/src/mapper/__init__.py b/src/mapper/__init__.py
index c8bcc75..28542fd 100644
--- a/src/mapper/__init__.py
+++ b/src/mapper/__init__.py
@@ -1,3 +1,8 @@
+import pymbolic.primitives as primitives
+
+
+
+
 try:
     import numpy
 
@@ -14,46 +19,43 @@ class Mapper(object):
     def __init__(self, recurse=True):
         self.Recurse = True
 
-    def handle_unsupported_expression(self, expr, *args, **kwargs):
+    def handle_unsupported_expression(self, expr, *args):
         raise ValueError, "%s cannot handle expressions of type %s" % (
                 self.__class__, expr.__class__)
 
-    def __call__(self, expr, *args, **kwargs):
-        import pymbolic.primitives as primitives
-        if isinstance(expr, primitives.Expression):
-            try:
-                method = expr.get_mapper_method(self)
-            except AttributeError:
-                return self.handle_unsupported_expression(expr, *args, **kwargs)
+    def __call__(self, expr, *args):
+        try:
+            method = expr.get_mapper_method(self)
+        except AttributeError:
+            if isinstance(expr, primitives.Expression):
+                return self.handle_unsupported_expression(expr, *args)
             else:
-                return method(expr, *args, **kwargs)
+                return self.map_foreign(expr, *args)
         else:
-            return self.map_foreign(expr, *args, **kwargs)
+            return method(expr, *args)
 
-    def map_variable(self, expr, *args, **kwargs):
-        return self.map_algebraic_leaf(expr, *args, **kwargs)
+    def map_variable(self, expr, *args):
+        return self.map_algebraic_leaf(expr, *args)
 
-    def map_subscript(self, expr, *args, **kwargs):
-        return self.map_algebraic_leaf(expr, *args, **kwargs)
+    def map_subscript(self, expr, *args):
+        return self.map_algebraic_leaf(expr, *args)
 
-    def map_call(self, expr, *args, **kwargs):
-        return self.map_algebraic_leaf(expr, *args, **kwargs)
+    def map_call(self, expr, *args):
+        return self.map_algebraic_leaf(expr, *args)
 
-    def map_lookup(self, expr, *args, **kwargs):
-        return self.map_algebraic_leaf(expr, *args, **kwargs)
+    def map_lookup(self, expr, *args):
+        return self.map_algebraic_leaf(expr, *args)
 
-    def map_rational(self, expr, *args, **kwargs):
-        return self.map_quotient(expr, *args, **kwargs)
+    def map_rational(self, expr, *args):
+        return self.map_quotient(expr, *args)
 
-    def map_foreign(self, expr, *args, **kwargs):
-        from pymbolic.primitives import is_constant
-        
-        if is_constant(expr):
-            return self.map_constant(expr, *args, **kwargs)
+    def map_foreign(self, expr, *args):
+        if isinstance(expr, primitives.VALID_CONSTANT_CLASSES):
+            return self.map_constant(expr, *args)
         elif isinstance(expr, list):
-            return self.map_list(expr, *args, **kwargs)
+            return self.map_list(expr, *args)
         elif is_numpy_array(expr):
-            return self.map_numpy_array(expr, *args, **kwargs)
+            return self.map_numpy_array(expr, *args)
         else:
             raise ValueError, "%s encountered invalid foreign object: %s" % (
                     self.__class__, repr(expr))
@@ -63,61 +65,60 @@ class Mapper(object):
 
 
 class RecursiveMapper(Mapper):
-    def rec(self, expr, *args, **kwargs):
-        import pymbolic.primitives as primitives
-        if isinstance(expr, primitives.Expression):
-            try:
-                method = expr.get_mapper_method(self)
-            except AttributeError:
-                return self.handle_unsupported_expression(expr, *args, **kwargs)
+    def rec(self, expr, *args):
+        try:
+            method = expr.get_mapper_method(self)
+        except AttributeError:
+            if isinstance(expr, primitives.Expression):
+                return self.handle_unsupported_expression(expr, *args)
             else:
-                return method(expr, *args, **kwargs)
+                return self.map_foreign(expr, *args)
         else:
-            return self.map_foreign(expr, *args, **kwargs)
+            return method(expr, *args)
 
 
 
 
 class CombineMapper(RecursiveMapper):
-    def map_call(self, expr, *args, **kwargs):
+    def map_call(self, expr, *args):
         return self.combine(
-                (self.rec(expr.function, *args, **kwargs),) + 
+                (self.rec(expr.function, *args),) + 
                 tuple(
-                    self.rec(child, *args, **kwargs) for child in expr.parameters)
+                    self.rec(child, *args) for child in expr.parameters)
                 )
 
-    def map_subscript(self, expr, *args, **kwargs):
+    def map_subscript(self, expr, *args):
         return self.combine(
-                [self.rec(expr.aggregate, *args, **kwargs), 
-                    self.rec(expr.index, *args, **kwargs)])
+                [self.rec(expr.aggregate, *args), 
+                    self.rec(expr.index, *args)])
 
-    def map_lookup(self, expr, *args, **kwargs):
-        return self.rec(expr.aggregate, *args, **kwargs)
+    def map_lookup(self, expr, *args):
+        return self.rec(expr.aggregate, *args)
 
-    def map_negation(self, expr, *args, **kwargs):
-        return self.rec(expr.child, *args, **kwargs)
+    def map_negation(self, expr, *args):
+        return self.rec(expr.child, *args)
 
-    def map_sum(self, expr, *args, **kwargs):
-        return self.combine(self.rec(child, *args, **kwargs) 
+    def map_sum(self, expr, *args):
+        return self.combine(self.rec(child, *args) 
                 for child in expr.children)
 
     map_product = map_sum
 
-    def map_quotient(self, expr, *args, **kwargs):
+    def map_quotient(self, expr, *args):
         return self.combine((
-            self.rec(expr.numerator, *args, **kwargs), 
-            self.rec(expr.denominator, *args, **kwargs)))
+            self.rec(expr.numerator, *args), 
+            self.rec(expr.denominator, *args)))
 
-    def map_power(self, expr, *args, **kwargs):
+    def map_power(self, expr, *args):
         return self.combine((
-                self.rec(expr.base, *args, **kwargs), 
-                self.rec(expr.exponent, *args, **kwargs)))
+                self.rec(expr.base, *args), 
+                self.rec(expr.exponent, *args)))
 
-    def map_polynomial(self, expr, *args, **kwargs):
+    def map_polynomial(self, expr, *args):
         return self.combine(
-                (self.rec(expr.base, *args, **kwargs),) + 
+                (self.rec(expr.base, *args),) + 
                 tuple(
-                    self.rec(coeff, *args, **kwargs) for exp, coeff in expr.data)
+                    self.rec(coeff, *args) for exp, coeff in expr.data)
                 )
 
     map_list = map_sum
@@ -129,54 +130,54 @@ class CombineMapper(RecursiveMapper):
 
 
 class IdentityMapperBase(object):
-    def map_constant(self, expr, *args, **kwargs):
+    def map_constant(self, expr, *args):
         # leaf -- no need to rebuild
         return expr
 
-    def map_variable(self, expr, *args, **kwargs):
+    def map_variable(self, expr, *args):
         # leaf -- no need to rebuild
         return expr
 
-    def map_call(self, expr, *args, **kwargs):
+    def map_call(self, expr, *args):
         return expr.__class__(
-                self.rec(expr.function, *args, **kwargs),
-                tuple(self.rec(child, *args, **kwargs)
+                self.rec(expr.function, *args),
+                tuple(self.rec(child, *args)
                     for child in expr.parameters))
 
-    def map_subscript(self, expr, *args, **kwargs):
+    def map_subscript(self, expr, *args):
         return expr.__class__(
-                self.rec(expr.aggregate, *args, **kwargs), 
-                self.rec(expr.index, *args, **kwargs))
+                self.rec(expr.aggregate, *args), 
+                self.rec(expr.index, *args))
 
-    def map_lookup(self, expr, *args, **kwargs):
+    def map_lookup(self, expr, *args):
         return expr.__class__(
-                self.rec(expr.aggregate, *args, **kwargs), 
+                self.rec(expr.aggregate, *args), 
                 expr.name)
 
-    def map_negation(self, expr, *args, **kwargs):
-        return expr.__class__(self.rec(expr.child, *args, **kwargs))
+    def map_negation(self, expr, *args):
+        return expr.__class__(self.rec(expr.child, *args))
 
-    def map_sum(self, expr, *args, **kwargs):
+    def map_sum(self, expr, *args):
         from pymbolic.primitives import flattened_sum
         return flattened_sum(tuple(
-            self.rec(child, *args, **kwargs) for child in expr.children))
+            self.rec(child, *args) for child in expr.children))
     
-    def map_product(self, expr, *args, **kwargs):
+    def map_product(self, expr, *args):
         from pymbolic.primitives import flattened_product
         return flattened_product(tuple(
-            self.rec(child, *args, **kwargs) for child in expr.children))
+            self.rec(child, *args) for child in expr.children))
     
-    def map_quotient(self, expr, *args, **kwargs):
-        return expr.__class__(self.rec(expr.numerator, *args, **kwargs),
-                              self.rec(expr.denominator, *args, **kwargs))
+    def map_quotient(self, expr, *args):
+        return expr.__class__(self.rec(expr.numerator, *args),
+                              self.rec(expr.denominator, *args))
 
-    def map_power(self, expr, *args, **kwargs):
-        return expr.__class__(self.rec(expr.base, *args, **kwargs),
-                              self.rec(expr.exponent, *args, **kwargs))
+    def map_power(self, expr, *args):
+        return expr.__class__(self.rec(expr.base, *args),
+                              self.rec(expr.exponent, *args))
 
-    def map_polynomial(self, expr, *args, **kwargs):
-        return expr.__class__(self.rec(expr.base, *args, **kwargs),
-                              ((exp, self.rec(coeff, *args, **kwargs))
+    def map_polynomial(self, expr, *args):
+        return expr.__class__(self.rec(expr.base, *args),
+                              ((exp, self.rec(coeff, *args))
                                   for exp, coeff in expr.data))
 
     map_list = map_sum
diff --git a/src/parser.py b/src/parser.py
index 8507504..0e5cf6b 100644
--- a/src/parser.py
+++ b/src/parser.py
@@ -1,4 +1,3 @@
-import pymbolic.primitives as primitives
 import pytools.lex
 
 _imaginary = intern("imaginary")
@@ -48,6 +47,8 @@ _PREC_UNARY_MINUS = 40
 _PREC_CALL = 50
 
 def parse(expr_str):
+    import pymbolic.primitives as primitives
+
     def parse_terminal(pstate):
         next_tag = pstate.next_tag()
         if next_tag is _int:
diff --git a/src/primitives.py b/src/primitives.py
index 32db774..4740642 100644
--- a/src/primitives.py
+++ b/src/primitives.py
@@ -1,5 +1,4 @@
 import traits
-import pymbolic.mapper.stringifier
 
 
 
@@ -156,8 +155,7 @@ class Constant(Leaf):
         return self.value,
 
     def __hash__(self):
-        from pytools import hash_combine
-        return hash_combine(self.__class__, self.value)
+        return hash((self.__class__, self.value))
 
     def get_mapper_method(self, mapper):
         return mapper.map_constant
@@ -183,8 +181,7 @@ class Variable(Leaf):
                 and self.name == other.name)
 
     def __hash__(self):
-        from pytools import hash_combine
-        return hash_combine(self.__class__, self.name)
+        return hash((self.__class__, self.name))
 
     def get_mapper_method(self, mapper):
         return mapper.map_variable
@@ -203,8 +200,7 @@ class Call(AlgebraicLeaf):
                and (self.parameters == other.parameters)
 
     def __hash__(self):
-        from pytools import hash_combine
-        return hash_combine(self.__class__, self.function, self.parameters)
+        return hash((self.__class__, self.function, self.parameters))
 
     def get_mapper_method(self, mapper):
         return mapper.map_call
@@ -226,8 +222,7 @@ class Subscript(AlgebraicLeaf):
                and (self.index == other.index)
 
     def __hash__(self):
-        from pytools import hash_combine
-        return hash_combine(self.__class__, self.aggregate, self.index)
+        return hash((self.__class__, self.aggregate, self.index))
 
     def get_mapper_method(self, mapper):
         return mapper.map_subscript
@@ -249,8 +244,7 @@ class Lookup(AlgebraicLeaf):
                and (self.name == other.name)
 
     def __hash__(self):
-        from pytools import hash_combine
-        return hash_combine(self.__class__, self.aggregate, self.name)
+        return hash((self.__class__, self.aggregate, self.name))
 
     def get_mapper_method(self, mapper):
         return mapper.map_lookup
@@ -309,8 +303,7 @@ class Sum(Expression):
             return True
 
     def __hash__(self):
-        from pytools import hash_combine
-        return hash_combine(self.__class__, self.children)
+        return hash((self.__class__, self.children))
 
     def get_mapper_method(self, mapper):
         return mapper.map_sum
@@ -359,8 +352,7 @@ class Product(Expression):
         return True
 
     def __hash__(self):
-        from pytools import hash_combine
-        return hash_combine(self.__class__, self.children)
+        return hash((self.__class__, self.children))
 
     def get_mapper_method(self, mapper):
         return mapper.map_product
@@ -394,8 +386,7 @@ class Quotient(Expression):
         return bool(self.numerator)
 
     def __hash__(self):
-        from pytools import hash_combine
-        return hash_combine(self.__class__, self.numerator, self.denominator)
+        return hash((self.__class__, self.numerator, self.denominator))
 
     def get_mapper_method(self, mapper):
         return mapper.map_quotient
@@ -417,8 +408,7 @@ class Power(Expression):
                and (self.exponent == other.exponent)
 
     def __hash__(self):
-        from pytools import hash_combine
-        return hash_combine(self.__class__, self.base, self.exponent)
+        return hash((self.__class__, self.base, self.exponent))
 
     def get_mapper_method(self, mapper):
         return mapper.map_power
@@ -490,8 +480,7 @@ class Vector(Expression):
         return self.children
 
     def __hash__(self):
-        from pytools import hash_combine
-        return hash_combine(self.__class__, self.children)
+        return hash((self.__class__, self.children))
 
     def get_mapper_method(self, mapper):
         return mapper.map_vector
@@ -595,25 +584,27 @@ def quotient(numerator, denominator):
 
 
 # tool functions --------------------------------------------------------------
-VALID_CONSTANT_CLASSES = [int, float, complex]
-VALID_OPERANDS = [Expression]
+VALID_CONSTANT_CLASSES = (int, float, complex)
+VALID_OPERANDS = (Expression,)
 
 
 
 def is_constant(value):
-    return isinstance(value, tuple(VALID_CONSTANT_CLASSES))
+    return isinstance(value, VALID_CONSTANT_CLASSES)
 
 def is_valid_operand(value):
-    return isinstance(value, tuple(VALID_OPERANDS)) or is_constant(value)
+    return isinstance(value, VALID_OPERANDS) or is_constant(value)
 
 
 
 
 def register_constant_class(class_):
-    VALID_CONSTANT_CLASSES.append(class_)
+    VALID_CONSTANT_CLASSES += (class_,)
 
 def unregister_constant_class(class_):
-    VALID_CONSTANT_CLASSES.remove(class_)
+    tmp = list(VALID_CONSTANT_CLASSES)
+    tmp.remove(class_)
+    VALID_CONSTANT_CLASSES = tuple(tmp)
 
 
 
-- 
GitLab