From e94735550a20d3f119cce4c224dea4677a1b98b3 Mon Sep 17 00:00:00 2001 From: Matt Wala <wala1@illinois.edu> Date: Tue, 10 Jan 2017 08:16:42 -0600 Subject: [PATCH] Documentation, various updates. --- doc/codegen.rst | 1 + sumpy/assignment_collection.py | 9 ++++--- sumpy/cse.py | 48 +++++++++++++++++----------------- 3 files changed, 30 insertions(+), 28 deletions(-) diff --git a/doc/codegen.rst b/doc/codegen.rst index 66fc5b2d..2d5d6f4d 100644 --- a/doc/codegen.rst +++ b/doc/codegen.rst @@ -3,3 +3,4 @@ Code Generation .. automodule:: sumpy.codegen .. automodule:: sumpy.assignment_collection +.. automodule:: sumpy.cse diff --git a/sumpy/assignment_collection.py b/sumpy/assignment_collection.py index af568ddb..00c0266b 100644 --- a/sumpy/assignment_collection.py +++ b/sumpy/assignment_collection.py @@ -160,17 +160,18 @@ class SymbolicAssignmentCollection(object): def run_global_cse(self, extra_exprs=[]): logger.info("common subexpression elimination: start") - assign_names = list(self.assignments) + assign_names = sorted(self.assign_names) assign_exprs = [self.assignments[name] for name in assign_names] # Options here: - # - cached_cse: Uses on-disk cache to speed up CSE. # - checked_cse: if you mistrust the result of the cse. # Uses maxima to verify. - # - sp.cse: The underlying sympy thing. + # - sp.cse: The sympy thing. + # - sumpy.cse.cse: Based on sympy, designed to go faster. #from sumpy.symbolic import checked_cse - new_assignments, new_exprs = cached_cse(assign_exprs + extra_exprs, + from sumpy.cse import cse + new_assignments, new_exprs = cse(assign_exprs + extra_exprs, symbols=self.symbol_generator) new_assign_exprs = new_exprs[:len(assign_exprs)] diff --git a/sumpy/cse.py b/sumpy/cse.py index 9bea6e56..466157c8 100644 --- a/sumpy/cse.py +++ b/sumpy/cse.py @@ -68,6 +68,16 @@ from sympy.core.compatibility import iterable from sympy.utilities.iterables import numbered_symbols +__doc__ = """ + +Common subexpression elimination +-------------------------------- + +.. autofunction:: cse + +""" + + # {{{ cse pre/postprocessing def preprocess_for_cse(expr, optimizations): @@ -111,7 +121,8 @@ def postprocess_for_cse(expr, optimizations): class FuncArgTracker(object): """ - A class which manages an inverse mapping from arguments to functions. + A class which manages a mapping from functions to arguments and an inverse + mapping from arguments to functions. """ def __init__(self, funcs): @@ -455,39 +466,28 @@ def tree_cse(exprs, symbols, opt_subs=None): def cse(exprs, symbols=None, optimizations=None): - """Perform common subexpression elimination on an expression. + """ + Perform common subexpression elimination on an expression. :arg exprs: A list of sympy expressions, or a single sympy expression to reduce :arg symbols: An iterator yielding unique Symbols used to label the common subexpressions which are pulled out. The ``numbered_symbols`` - generator is useful. The default is a stream of symbols of the form - "x0", "x1", etc. This must be an infinite iterator. + generator from sympy is useful. The default is a stream of symbols of the + form "x0", "x1", etc. This must be an infinite iterator. :arg optimizations: A list of (callable, callable) pairs consisting of (preprocessor, postprocessor) pairs of external optimization functions. - Returns - ======= - - replacements : list of (Symbol, expression) pairs - All of the common subexpressions that were replaced. Subexpressions - earlier in this list might show up in subexpressions later in this - list. - reduced_exprs : list of sympy expressions - The reduced expressions with all of the replacements above. + :return: This returns a pair ``(replacements, reduced_exprs)``. + * ``replacements`` is a list of (Symbol, expression) pairs consisting of + all of the common subexpressions that were replaced. Subexpressions + earlier in this list might show up in subexpressions later in this list. + * ``reduced_exprs`` is a list of sympy expressions. This contains the + reduced expressions with all of the replacements above. """ if isinstance(exprs, Basic): exprs = [exprs] - import time - - start = time.time() - - import logging - - logger = logging.getLogger(__name__) - logger.info("cse: start") - exprs = list(exprs) if optimizations is None: @@ -510,8 +510,6 @@ def cse(exprs, symbols=None, optimizations=None): # Find other optimization opportunities. opt_subs = opt_cse(reduced_exprs) - logger.info("cse: done after {}".format(time.time() - start)) - # Main CSE algorithm. replacements, reduced_exprs = tree_cse(reduced_exprs, symbols, opt_subs) @@ -522,3 +520,5 @@ def cse(exprs, symbols=None, optimizations=None): reduced_exprs = [postprocess_for_cse(e, optimizations) for e in reduced_exprs] return replacements, reduced_exprs + +# vim: fdm=marker -- GitLab