diff --git a/doc/codegen.rst b/doc/codegen.rst index 66fc5b2dccac2f1f3c16918d7325854004e54d31..2d5d6f4dcdc14f7259610585ffdd1478a807358d 100644 --- a/doc/codegen.rst +++ b/doc/codegen.rst @@ -3,3 +3,4 @@ Code Generation .. automodule:: sumpy.codegen .. automodule:: sumpy.assignment_collection +.. automodule:: sumpy.cse diff --git a/sumpy/assignment_collection.py b/sumpy/assignment_collection.py index af568ddb9de05d3a0126dd419cc8e207bcd0f88e..00c0266ba89b5f584ea7133b59308c2adbb96038 100644 --- a/sumpy/assignment_collection.py +++ b/sumpy/assignment_collection.py @@ -160,17 +160,18 @@ class SymbolicAssignmentCollection(object): def run_global_cse(self, extra_exprs=[]): logger.info("common subexpression elimination: start") - assign_names = list(self.assignments) + assign_names = sorted(self.assign_names) assign_exprs = [self.assignments[name] for name in assign_names] # Options here: - # - cached_cse: Uses on-disk cache to speed up CSE. # - checked_cse: if you mistrust the result of the cse. # Uses maxima to verify. - # - sp.cse: The underlying sympy thing. + # - sp.cse: The sympy thing. + # - sumpy.cse.cse: Based on sympy, designed to go faster. #from sumpy.symbolic import checked_cse - new_assignments, new_exprs = cached_cse(assign_exprs + extra_exprs, + from sumpy.cse import cse + new_assignments, new_exprs = cse(assign_exprs + extra_exprs, symbols=self.symbol_generator) new_assign_exprs = new_exprs[:len(assign_exprs)] diff --git a/sumpy/cse.py b/sumpy/cse.py index 9bea6e569721705d97506bafdce00c200ce0865b..466157c8f948d02e1d13b76a2848da3156f945a8 100644 --- a/sumpy/cse.py +++ b/sumpy/cse.py @@ -68,6 +68,16 @@ from sympy.core.compatibility import iterable from sympy.utilities.iterables import numbered_symbols +__doc__ = """ + +Common subexpression elimination +-------------------------------- + +.. autofunction:: cse + +""" + + # {{{ cse pre/postprocessing def preprocess_for_cse(expr, optimizations): @@ -111,7 +121,8 @@ def postprocess_for_cse(expr, optimizations): class FuncArgTracker(object): """ - A class which manages an inverse mapping from arguments to functions. + A class which manages a mapping from functions to arguments and an inverse + mapping from arguments to functions. """ def __init__(self, funcs): @@ -455,39 +466,28 @@ def tree_cse(exprs, symbols, opt_subs=None): def cse(exprs, symbols=None, optimizations=None): - """Perform common subexpression elimination on an expression. + """ + Perform common subexpression elimination on an expression. :arg exprs: A list of sympy expressions, or a single sympy expression to reduce :arg symbols: An iterator yielding unique Symbols used to label the common subexpressions which are pulled out. The ``numbered_symbols`` - generator is useful. The default is a stream of symbols of the form - "x0", "x1", etc. This must be an infinite iterator. + generator from sympy is useful. The default is a stream of symbols of the + form "x0", "x1", etc. This must be an infinite iterator. :arg optimizations: A list of (callable, callable) pairs consisting of (preprocessor, postprocessor) pairs of external optimization functions. - Returns - ======= - - replacements : list of (Symbol, expression) pairs - All of the common subexpressions that were replaced. Subexpressions - earlier in this list might show up in subexpressions later in this - list. - reduced_exprs : list of sympy expressions - The reduced expressions with all of the replacements above. + :return: This returns a pair ``(replacements, reduced_exprs)``. + * ``replacements`` is a list of (Symbol, expression) pairs consisting of + all of the common subexpressions that were replaced. Subexpressions + earlier in this list might show up in subexpressions later in this list. + * ``reduced_exprs`` is a list of sympy expressions. This contains the + reduced expressions with all of the replacements above. """ if isinstance(exprs, Basic): exprs = [exprs] - import time - - start = time.time() - - import logging - - logger = logging.getLogger(__name__) - logger.info("cse: start") - exprs = list(exprs) if optimizations is None: @@ -510,8 +510,6 @@ def cse(exprs, symbols=None, optimizations=None): # Find other optimization opportunities. opt_subs = opt_cse(reduced_exprs) - logger.info("cse: done after {}".format(time.time() - start)) - # Main CSE algorithm. replacements, reduced_exprs = tree_cse(reduced_exprs, symbols, opt_subs) @@ -522,3 +520,5 @@ def cse(exprs, symbols=None, optimizations=None): reduced_exprs = [postprocess_for_cse(e, optimizations) for e in reduced_exprs] return replacements, reduced_exprs + +# vim: fdm=marker