From fc2658086e6315e37106d2af776e75fe407cf2a5 Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Sun, 8 Jan 2017 21:30:33 -0600 Subject: [PATCH 01/10] [WIP] --- examples/m2l-timing.py | 28 ++ sumpy/assignment_collection.py | 7 +- sumpy/cse.py | 660 +++++++++++++++++++++++++++++++++ sumpy/e2e.py | 17 + test/test_cse.py | 403 ++++++++++++++++++++ 5 files changed, 1112 insertions(+), 3 deletions(-) create mode 100644 examples/m2l-timing.py create mode 100644 sumpy/cse.py create mode 100644 test/test_cse.py diff --git a/examples/m2l-timing.py b/examples/m2l-timing.py new file mode 100644 index 00000000..67aaa655 --- /dev/null +++ b/examples/m2l-timing.py @@ -0,0 +1,28 @@ + +def test_m2l_creation(ctx, mpole_expn_class, local_expn_class, knl, order): + from sympy.core.cache import clear_cache + clear_cache() + m_expn = mpole_expn_class(knl, order=order) + l_expn = local_expn_class(knl, order=order) + from sumpy.e2e import E2EFromCSR + m2l = E2EFromCSR(ctx, m_expn, l_expn) + import time + start = time.time() + m2l.run_translation_and_cse() + return time.time() - start + +if __name__ == "__main__": + import logging + logging.basicConfig(level=logging.INFO) + from sumpy.kernel import LaplaceKernel + + import pyopencl as cl + ctx = cl._csc() + from sumpy.expansion.local import LaplaceConformingVolumeTaylorLocalExpansion as LExpn + from sumpy.expansion.multipole import LaplaceConformingVolumeTaylorMultipoleExpansion as MExpn + results = [] + for order in range(20, 22): + results.append((order, test_m2l_creation(ctx, MExpn, LExpn, LaplaceKernel(2), order))) + print("order\ttime (s)") + for order, time in results: + print("{}\t{:.2f}".format(order, time)) diff --git a/sumpy/assignment_collection.py b/sumpy/assignment_collection.py index f2bc607c..4094aa94 100644 --- a/sumpy/assignment_collection.py +++ b/sumpy/assignment_collection.py @@ -109,9 +109,10 @@ def cached_cse(exprs, symbols): frozenset(symbols.generated_names)) try: - result = cache_dict[key] + result = cache_dict[7] except KeyError: - result = sp.cse(exprs, symbols) + from sumpy.cse import cse + result = cse(exprs, symbols) cache_dict[key] = _map_cse_result(s2p, result) return result else: @@ -216,7 +217,7 @@ class SymbolicAssignmentCollection(object): # - cached_cse: Uses on-disk cache to speed up CSE. # - checked_cse: if you mistrust the result of the cse. # Uses maxima to verify. - # - sp.cse: The underlying sympy thing. + # - sumpy.cse: The underlying thing. #from sumpy.symbolic import checked_cse new_assignments, new_exprs = cached_cse(assign_exprs + extra_exprs, diff --git a/sumpy/cse.py b/sumpy/cse.py new file mode 100644 index 00000000..9094a3fc --- /dev/null +++ b/sumpy/cse.py @@ -0,0 +1,660 @@ +""" Tools for doing common subexpression elimination. +""" +from __future__ import print_function, division + +__copyright__ = """ +Copyright (C) 2017 Matt Wala +Copyright (C) 2006-2016 SymPy Development Team +""" + +# {{{ license and original license + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. + +=========================================================================== + +Based on sympy/simplify/cse_main.py from SymPy 1.0, original license as follows: + +Copyright (c) 2006-2016 SymPy Development Team + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + a. Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + b. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + c. Neither the name of SymPy nor the names of its contributors + may be used to endorse or promote products derived from this software + without specific prior written permission. + + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY +OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH +DAMAGE. +""" + +# }}} + +from sympy.core import Basic, Mul, Add, Pow, sympify, Symbol, Tuple +#from symengine.sympy_compat import (Basic, Mul, Add, Pow, sympify, Symbol) +from sympy.core.singleton import S +from sympy.core.function import _coeff_isneg +from sympy.core.exprtools import factor_terms +from sympy.core.compatibility import iterable, range +from sympy.utilities.iterables import filter_symbols, \ + numbered_symbols, sift, topological_sort, ordered + +#from . import cse_opts + +# (preprocessor, postprocessor) pairs which are commonly useful. They should +# each take a sympy expression and return a possibly transformed expression. +# When used in the function ``cse()``, the target expressions will be transformed +# by each of the preprocessor functions in order. After the common +# subexpressions are eliminated, each resulting expression will have the +# postprocessor functions transform them in *reverse* order in order to undo the +# transformation if necessary. This allows the algorithm to operate on +# a representation of the expressions that allows for more optimization +# opportunities. +# ``None`` can be used to specify no transformation for either the preprocessor or +# postprocessor. + + +#basic_optimizations = [(cse_opts.sub_pre, cse_opts.sub_post), +# (factor_terms, None)] + +# sometimes we want the output in a different format; non-trivial +# transformations can be put here for users +# =============================================================== + +def reps_toposort(r): + """Sort replacements `r` so (k1, v1) appears before (k2, v2) + if k2 is in v1's free symbols. This orders items in the + way that cse returns its results (hence, in order to use the + replacements in a substitution option it would make sense + to reverse the order). + + Examples + ======== + + >>> from sympy.simplify.cse_main import reps_toposort + >>> from sympy.abc import x, y + >>> from sympy import Eq + >>> for l, r in reps_toposort([(x, y + 1), (y, 2)]): + ... print(Eq(l, r)) + ... + Eq(y, 2) + Eq(x, y + 1) + + """ + r = sympify(r) + E = [] + for c1, (k1, v1) in enumerate(r): + for c2, (k2, v2) in enumerate(r): + if k1 in v2.free_symbols: + E.append((c1, c2)) + return [r[i] for i in topological_sort((range(len(r)), E))] + + +def cse_separate(r, e): + """Move expressions that are in the form (symbol, expr) out of the + expressions and sort them into the replacements using the reps_toposort. + + Examples + ======== + + >>> from sympy.simplify.cse_main import cse_separate + >>> from sympy.abc import x, y, z + >>> from sympy import cos, exp, cse, Eq, symbols + >>> x0, x1 = symbols('x:2') + >>> eq = (x + 1 + exp((x + 1)/(y + 1)) + cos(y + 1)) + >>> cse([eq, Eq(x, z + 1), z - 2], postprocess=cse_separate) in [ + ... [[(x0, y + 1), (x, z + 1), (x1, x + 1)], + ... [x1 + exp(x1/x0) + cos(x0), z - 2]], + ... [[(x1, y + 1), (x, z + 1), (x0, x + 1)], + ... [x0 + exp(x0/x1) + cos(x1), z - 2]]] + ... + True + """ + d = sift(e, lambda w: w.is_Equality and w.lhs.is_Symbol) + r = r + [w.args for w in d[True]] + e = d[False] + return [reps_toposort(r), e] + +# ====end of cse postprocess idioms=========================== + + +def preprocess_for_cse(expr, optimizations): + """ Preprocess an expression to optimize for common subexpression + elimination. + + Parameters + ---------- + expr : sympy expression + The target expression to optimize. + optimizations : list of (callable, callable) pairs + The (preprocessor, postprocessor) pairs. + + Returns + ------- + expr : sympy expression + The transformed expression. + """ + for pre, post in optimizations: + if pre is not None: + expr = pre(expr) + return expr + + +def postprocess_for_cse(expr, optimizations): + """ Postprocess an expression after common subexpression elimination to + return the expression to canonical sympy form. + + Parameters + ---------- + expr : sympy expression + The target expression to transform. + optimizations : list of (callable, callable) pairs, optional + The (preprocessor, postprocessor) pairs. The postprocessors will be + applied in reversed order to undo the effects of the preprocessors + correctly. + + Returns + ------- + expr : sympy expression + The transformed expression. + """ + for pre, post in reversed(optimizations): + if post is not None: + expr = post(expr) + return expr + + +class FuncArgTracker(object): + + def __init__(self, funcs): + import time + s = time.time() + from collections import defaultdict + from sortedcontainers import SortedSet + self.arg_to_func_idx_map = defaultdict(lambda: SortedSet()) + func_args = [] + func_count = 0 + func_arg_count = 0 + for func_i, func in enumerate(funcs): + func_count += 1 + func_argset = set(func.args) + func_args.append(func_argset) + for func_arg in func_argset: + func_arg_count += 1 + self.arg_to_func_idx_map[func_arg].add(func_i) + #print("func count, func arg count", func_count, func_arg_count) + self.func_args = func_args + self.changed = set() + + def gen_common_arg_candidates(self, argset, min_idx=0, threshold=1, kind="k"): + last_seen_idx = min_idx - 1 + max_idx = len(self.func_args) + + while True: + from collections import defaultdict + count_map = defaultdict(lambda: 0) + + # TODO: Dynamically update count_map, this is stupid. + next_idx = max_idx + for arg in argset: + m = self.arg_to_func_idx_map[arg] + for other_idx in m.islice(m.bisect_right(last_seen_idx)): + count_map[other_idx] += 1 + + for item in count_map.keys(): + if count_map[item] >= threshold: + next_idx = min(item, next_idx) + + # Delete old items in count_map. + + if next_idx == max_idx: + return + + last_seen_idx = next_idx + yield next_idx + + def gen_subset_candidates(self, argset, min_idx): + indices = min( + (self.arg_to_func_idx_map[arg] for arg in argset), + key=lambda s: len(s) - s.bisect_right(min_idx - 1)) + + indices = indices[indices.bisect_right(min_idx - 1):] + for arg in argset: + indices &= self.arg_to_func_idx_map[arg] + + for i in indices: + yield i + + def update_func_args(self, func_idx, new_args): + new_args = set(new_args) + old_args = self.func_args[func_idx] + for deleted_arg in old_args - new_args: + self.arg_to_func_idx_map[deleted_arg].remove(func_idx) + for added_arg in new_args - old_args: + self.arg_to_func_idx_map[added_arg].add(func_idx) + self.func_args[func_idx].clear() + self.func_args[func_idx].update(new_args) + self.changed.add(func_idx) + + +def match_common_args(Func, funcs, kind, opt_subs, order="canonical"): + #if order != 'none': + # funcs = list(ordered(funcs)) + #else: + + + #func_args = [set(e.args) for e in funcs] + import time + #if order == "canonical": + # funcs = list(ordered(funcs)) + #else: + funcs = sorted(funcs, key=lambda x: len(x.args)) + arg_tracker = FuncArgTracker(funcs) + + for i in range(len(funcs)): + for j in arg_tracker.gen_common_arg_candidates(arg_tracker.func_args[i], i + 1, threshold=2, kind=kind): + com_args = arg_tracker.func_args[i].intersection(arg_tracker.func_args[j]) + assert len(com_args) > 1 + com_func = Func(*com_args) + + # for all sets, replace the common symbols by the function + # over them, to allow recursive matches + + diff_i = arg_tracker.func_args[i].difference(com_args) + arg_tracker.update_func_args(i, diff_i | set([com_func])) + #if diff_i: + # opt_subs[funcs[i]] = Func(Func(*diff_i), com_func, + #evaluate=False) + + diff_j = arg_tracker.func_args[j].difference(com_args) + arg_tracker.update_func_args(j, diff_j | set([com_func])) + #opt_subs[funcs[j]] = Func(Func(*diff_j), com_func, + # evaluate=False) + + for k in arg_tracker.gen_subset_candidates(com_args, j + 1): + diff_k = arg_tracker.func_args[k].difference(com_args) + arg_tracker.update_func_args(k, diff_k | set([com_func])) + # opt_subs[funcs[k]] = Func(Func(*diff_k), com_func, evaluate=False) + if i in arg_tracker.changed: + opt_subs[funcs[i]] = Func(*arg_tracker.func_args[i], evaluate=False) + #print("made subst", opt_subs[funcs[i]]) + + +def opt_cse(exprs, order='canonical'): + """Find optimization opportunities in Adds, Muls, Pows and negative + coefficient Muls + + Parameters + ---------- + exprs : list of sympy expressions + The expressions to optimize. + order : string, 'none' or 'canonical' + The order by which Mul and Add arguments are processed. For large + expressions where speed is a concern, use the setting order='none'. + + Returns + ------- + opt_subs : dictionary of expression substitutions + The expression substitutions which can be useful to optimize CSE. + + Examples + ======== + + >>> from sympy.simplify.cse_main import opt_cse + >>> from sympy.abc import x + >>> opt_subs = opt_cse([x**-2]) + >>> print(opt_subs) + {x**(-2): 1/(x**2)} + """ + from sympy.matrices.expressions import MatAdd, MatMul, MatPow + opt_subs = dict() + + adds = set() + muls = set() + + seen_subexp = set() + + def _find_opts(expr): + + if not isinstance(expr, Basic): + return + + if expr.is_Atom: # or expr.is_Order: + return + + if iterable(expr): + list(map(_find_opts, expr)) + return + + if expr in seen_subexp: + return expr + seen_subexp.add(expr) + + list(map(_find_opts, expr.args)) + + """ + def _coeff_isneg(expr): + return expr.is_Number and expr < 0 + """ + + if _coeff_isneg(expr): + neg_expr = -expr + if not neg_expr.is_Atom: + opt_subs[expr] = Mul(S.NegativeOne, neg_expr, evaluate=False) + seen_subexp.add(neg_expr) + expr = neg_expr + + if isinstance(expr, (Mul, MatMul)): + muls.add(expr) + + elif isinstance(expr, (Add, MatAdd)): + adds.add(expr) + + elif isinstance(expr, (Pow, MatPow)): + try: + # symengine + base, exp = expr.args + except ValueError: + base = expr.base + exp = expr.exp + if _coeff_isneg(exp): + opt_subs[expr] = Pow(Pow(base, -exp), -1, evaluate=False) + + for e in exprs: + if isinstance(e, Basic): + _find_opts(e) + + ## Process Adds and commutative Muls + + import time + + start = time.time() + + import logging + + logger = logging.getLogger(__name__) + logger.info("cse: start") + + niter = 0 + # split muls into commutative + comutative_muls = set() + for m in muls: + if len(m.args) > 1: + comutative_muls.add(m) + """ + c, nc = m.args_cnc(cset=True) + if c: + c_mul = m.func(*c) + if nc: + if c_mul == 1: + new_obj = m.func(*nc) + else: + new_obj = m.func(c_mul, m.func(*nc), evaluate=False) + opt_subs[m] = new_obj + if len(c) > 1: + comutative_muls.add(c_mul) + """ + + match_common_args(Add, adds, "add", opt_subs) + match_common_args(Mul, comutative_muls, "mul", opt_subs) + logger.info("cse: done {}".format(time.time() - start)) + return opt_subs + + +def tree_cse(exprs, symbols, opt_subs=None, order='none'): + """Perform raw CSE on expression tree, taking opt_subs into account. + + Parameters + ========== + + exprs : list of sympy expressions + The expressions to reduce. + symbols : infinite iterator yielding unique Symbols + The symbols used to label the common subexpressions which are pulled + out. + opt_subs : dictionary of expression substitutions + The expressions to be substituted before any CSE action is performed. + order : string, 'none' or 'canonical' + The order by which Mul and Add arguments are processed. For large + expressions where speed is a concern, use the setting order='none'. + """ + from sympy.matrices.expressions import MatrixExpr, MatrixSymbol, MatMul, MatAdd + + if opt_subs is None: + opt_subs = dict() + + ## Find repeated sub-expressions + + to_eliminate = set() + + seen_subexp = set() + + def _find_repeated(expr): + if not isinstance(expr, Basic): + return + + if expr.is_Atom: #or expr.is_Order: + return + + if iterable(expr): + args = expr + + else: + if expr in seen_subexp: + to_eliminate.add(expr) + return + + seen_subexp.add(expr) + + if expr in opt_subs: + expr = opt_subs[expr] + + args = expr.args + + list(map(_find_repeated, args)) + + for e in exprs: + if isinstance(e, Basic): + _find_repeated(e) + + ## Rebuild tree + + replacements = [] + + subs = dict() + + def _rebuild(expr): + if not isinstance(expr, Basic): + return expr + + if not expr.args: + return expr + + if iterable(expr): + new_args = [_rebuild(arg) for arg in expr] + return expr.func(*new_args) + + if expr in subs: + return subs[expr] + + orig_expr = expr + if expr in opt_subs: + expr = opt_subs[expr] + + # If enabled, parse Muls and Adds arguments by order to ensure + # replacement order independent from hashes + if order != 'none': + if isinstance(expr, (Mul, MatMul)): + #c, nc = expr.args_cnc() + #if c == [1]: + # args = nc + #else: + # args = list(ordered(c)) + nc + args = expr.args + elif isinstance(expr, (Add, MatAdd)): + args = list(ordered(expr.args)) + else: + args = expr.args + else: + args = expr.args + + new_args = list(map(_rebuild, args)) + if new_args != args: + new_expr = expr.func(*new_args) + else: + new_expr = expr + + if orig_expr in to_eliminate: + try: + sym = next(symbols) + except StopIteration: + raise ValueError("Symbols iterator ran out of symbols.") + + if isinstance(orig_expr, MatrixExpr): + sym = MatrixSymbol(sym.name, orig_expr.rows, + orig_expr.cols) + + subs[orig_expr] = sym + replacements.append((sym, new_expr)) + return sym + + else: + return new_expr + + reduced_exprs = [] + for e in exprs: + if isinstance(e, Basic): + reduced_e = _rebuild(e) + else: + reduced_e = e + reduced_exprs.append(reduced_e) + + return replacements, reduced_exprs + + +def cse(exprs, symbols=None, optimizations=None, postprocess=None, + order='none'): + """ Perform common subexpression elimination on an expression. + + Parameters + ========== + + exprs : list of sympy expressions, or a single sympy expression + The expressions to reduce. + symbols : infinite iterator yielding unique Symbols + The symbols used to label the common subexpressions which are pulled + out. The ``numbered_symbols`` generator is useful. The default is a + stream of symbols of the form "x0", "x1", etc. This must be an + infinite iterator. + optimizations : list of (callable, callable) pairs + The (preprocessor, postprocessor) pairs of external optimization + functions. Optionally 'basic' can be passed for a set of predefined + basic optimizations. Such 'basic' optimizations were used by default + in old implementation, however they can be really slow on larger + expressions. Now, no pre or post optimizations are made by default. + postprocess : a function which accepts the two return values of cse and + returns the desired form of output from cse, e.g. if you want the + replacements reversed the function might be the following lambda: + lambda r, e: return reversed(r), e + order : string, 'none' or 'canonical' + The order by which Mul and Add arguments are processed. If set to + 'canonical', arguments will be canonically ordered. If set to 'none', + ordering will be faster but dependent on expressions hashes, thus + machine dependent and variable. For large expressions where speed is a + concern, use the setting order='none'. + + Returns + ======= + + replacements : list of (Symbol, expression) pairs + All of the common subexpressions that were replaced. Subexpressions + earlier in this list might show up in subexpressions later in this + list. + reduced_exprs : list of sympy expressions + The reduced expressions with all of the replacements above. + + Examples + ======== + + >>> from sympy import cse, SparseMatrix + >>> from sympy.abc import x, y, z, w + >>> cse(((w + x + y + z)*(w + y + z))/(w + x)**3) + ([(x0, y + z), (x1, w + x)], [(w + x0)*(x0 + x1)/x1**3]) + + Note that currently, y + z will not get substituted if -y - z is used. + + >>> cse(((w + x + y + z)*(w - y - z))/(w + x)**3) + ([(x0, w + x)], [(w - y - z)*(x0 + y + z)/x0**3]) + """ + if isinstance(exprs, Basic): + exprs = [exprs] + + exprs = list(exprs) + + if optimizations is None: + optimizations = [] + + # Preprocess the expressions to give us better optimization opportunities. + reduced_exprs = [preprocess_for_cse(e, optimizations) for e in exprs] + + if symbols is None: + symbols = numbered_symbols() + else: + # In case we get passed an iterable with an __iter__ method instead of + # an actual iterator. + symbols = iter(symbols) + + # Remove symbols from the generator that conflict with names in the expressions. + excluded_symbols = set().union(*[expr.atoms(Symbol) for expr in reduced_exprs]) + symbols = (symbol for symbol in symbols if symbol not in excluded_symbols) + + # Find other optimization opportunities. + opt_subs = opt_cse(reduced_exprs, order) + + # Main CSE algorithm. + replacements, reduced_exprs = tree_cse(reduced_exprs, symbols, opt_subs, order) + + # Postprocess the expressions to return the expressions to canonical form. + for i, (sym, subtree) in enumerate(replacements): + subtree = postprocess_for_cse(subtree, optimizations) + replacements[i] = (sym, subtree) + reduced_exprs = [postprocess_for_cse(e, optimizations) for e in reduced_exprs] + + if postprocess is None: + return replacements, reduced_exprs + + return postprocess(replacements, reduced_exprs) diff --git a/sumpy/e2e.py b/sumpy/e2e.py index f5b34940..7c28c968 100644 --- a/sumpy/e2e.py +++ b/sumpy/e2e.py @@ -76,6 +76,23 @@ class E2EBase(KernelCacheWrapper): self.dim = src_expansion.dim + def run_translation_and_cse(self): + from sumpy.symbolic import make_sympy_vector + dvec = make_sympy_vector("d", self.dim) + + src_coeff_exprs = [sp.Symbol("src_coeff%d" % i) + for i in range(len(self.src_expansion))] + + from sumpy.assignment_collection import SymbolicAssignmentCollection + sac = SymbolicAssignmentCollection() + tgt_coeff_names = [ + sac.assign_unique("coeff%d" % i, coeff_i) + for i, coeff_i in enumerate( + self.tgt_expansion.translate_from( + self.src_expansion, src_coeff_exprs, dvec))] + + sac.run_global_cse() + def get_translation_loopy_insns(self): from sumpy.symbolic import make_sympy_vector dvec = make_sympy_vector("d", self.dim) diff --git a/test/test_cse.py b/test/test_cse.py new file mode 100644 index 00000000..185359cb --- /dev/null +++ b/test/test_cse.py @@ -0,0 +1,403 @@ +from __future__ import print_function, division + +__copyright__ = """ +Copyright (C) 2017 Matt Wala +Copyright (C) 2006-2016 SymPy Development Team +""" + +# {{{ license and original license + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. + +=========================================================================== + +Based on sympy/simplify/tests/test_cse.py from SymPy 1.0, original license as +follows: + +Copyright (c) 2006-2016 SymPy Development Team + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + a. Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + b. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + c. Neither the name of SymPy nor the names of its contributors + may be used to endorse or promote products derived from this software + without specific prior written permission. + + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY +OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH +DAMAGE. +""" + +# }}} + +import pytest +import itertools +import sys + +from sympy import (Add, Pow, Symbol, exp, sqrt, symbols, sympify, cse, + Matrix, S, cos, sin, Eq, Function, Tuple, CRootOf, + IndexedBase, Idx, Piecewise, O) +from sympy.simplify.cse_opts import sub_pre, sub_post +from sympy.functions.special.hyper import meijerg +from sympy.simplify import cse_main, cse_opts +from sympy.utilities.pytest import XFAIL +from sympy.matrices import (eye, SparseMatrix, MutableDenseMatrix, + MutableSparseMatrix, ImmutableDenseMatrix, ImmutableSparseMatrix) +from sympy.matrices.expressions import MatrixSymbol + +from sympy.core.compatibility import range + + +from sumpy.cse import cse + + +w, x, y, z = symbols('w,x,y,z') +x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12 = symbols('x:13') + + +def test_numbered_symbols(): + ns = cse_main.numbered_symbols(prefix='y') + assert list(itertools.islice( + ns, 0, 10)) == [Symbol('y%s' % i) for i in range(0, 10)] + ns = cse_main.numbered_symbols(prefix='y') + assert list(itertools.islice( + ns, 10, 20)) == [Symbol('y%s' % i) for i in range(10, 20)] + ns = cse_main.numbered_symbols() + assert list(itertools.islice( + ns, 0, 10)) == [Symbol('x%s' % i) for i in range(0, 10)] + +# Dummy "optimization" functions for testing. + + +def opt1(expr): + return expr + y + + +def opt2(expr): + return expr*z + + +def test_preprocess_for_cse(): + assert cse_main.preprocess_for_cse(x, [(opt1, None)]) == x + y + assert cse_main.preprocess_for_cse(x, [(None, opt1)]) == x + assert cse_main.preprocess_for_cse(x, [(None, None)]) == x + assert cse_main.preprocess_for_cse(x, [(opt1, opt2)]) == x + y + assert cse_main.preprocess_for_cse( + x, [(opt1, None), (opt2, None)]) == (x + y)*z + + +def test_postprocess_for_cse(): + assert cse_main.postprocess_for_cse(x, [(opt1, None)]) == x + assert cse_main.postprocess_for_cse(x, [(None, opt1)]) == x + y + assert cse_main.postprocess_for_cse(x, [(None, None)]) == x + assert cse_main.postprocess_for_cse(x, [(opt1, opt2)]) == x*z + # Note the reverse order of application. + assert cse_main.postprocess_for_cse( + x, [(None, opt1), (None, opt2)]) == x*z + y + + +def test_cse_single(): + # Simple substitution. + e = Add(Pow(x + y, 2), sqrt(x + y)) + substs, reduced = cse([e]) + assert substs == [(x0, x + y)] + assert reduced == [sqrt(x0) + x0**2] + + +def test_cse_not_possible(): + # No substitution possible. + e = Add(x, y) + substs, reduced = cse([e]) + assert substs == [] + assert reduced == [x + y] + # issue 6329 + eq = (meijerg((1, 2), (y, 4), (5,), [], x) + + meijerg((1, 3), (y, 4), (5,), [], x)) + assert cse(eq) == ([], [eq]) + + +def test_nested_substitution(): + # Substitution within a substitution. + e = Add(Pow(w*x + y, 2), sqrt(w*x + y)) + substs, reduced = cse([e]) + assert substs == [(x0, w*x + y)] + assert reduced == [sqrt(x0) + x0**2] + + +def test_subtraction_opt(): + # Make sure subtraction is optimized. + e = (x - y)*(z - y) + exp((x - y)*(z - y)) + substs, reduced = cse( + [e], optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)]) + assert substs == [(x0, (x - y)*(y - z))] + assert reduced == [-x0 + exp(-x0)] + e = -(x - y)*(z - y) + exp(-(x - y)*(z - y)) + substs, reduced = cse( + [e], optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)]) + assert substs == [(x0, (x - y)*(y - z))] + assert reduced == [x0 + exp(x0)] + # issue 4077 + n = -1 + 1/x + e = n/x/(-n)**2 - 1/n/x + assert cse(e, optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)]) == \ + ([], [0]) + + +def test_multiple_expressions(): + e1 = (x + y)*z + e2 = (x + y)*w + substs, reduced = cse([e1, e2]) + assert substs == [(x0, x + y)] + assert reduced == [x0*z, x0*w] + l = [w*x*y + z, w*y] + substs, reduced = cse(l) + rsubsts, _ = cse(reversed(l)) + assert substs == rsubsts + assert reduced == [z + x*x0, x0] + l = [w*x*y, w*x*y + z, w*y] + substs, reduced = cse(l) + rsubsts, _ = cse(reversed(l)) + assert substs == rsubsts + assert reduced == [x1, x1 + z, x0] + l = [(x - z)*(y - z), x - z, y - z] + substs, reduced = cse(l) + rsubsts, _ = cse(reversed(l)) + assert substs == [(x0, -z), (x1, x + x0), (x2, x0 + y)] + assert rsubsts == [(x0, -z), (x1, x0 + y), (x2, x + x0)] + assert reduced == [x1*x2, x1, x2] + l = [w*y + w + x + y + z, w*x*y] + assert cse(l) == ([(x0, w*y)], [w + x + x0 + y + z, x*x0]) + assert cse([x + y, x + y + z]) == ([(x0, x + y)], [x0, z + x0]) + assert cse([x + y, x + z]) == ([], [x + y, x + z]) + assert cse([x*y, z + x*y, x*y*z + 3]) == \ + ([(x0, x*y)], [x0, z + x0, 3 + x0*z]) + + +""" +def test_issue_4498(): + assert cse(w/(x - y) + z/(y - x), optimizations='basic') == \ + ([], [(w - z)/(x - y)]) + + +def test_issue_4020(): + assert cse(x**5 + x**4 + x**3 + x**2, optimizations='basic') \ + == ([(x0, x**2)], [x0*(x**3 + x + x0 + 1)]) + + +def test_issue_4203(): + assert cse(sin(x**x)/x**x) == ([(x0, x**x)], [sin(x0)/x0]) + + +def test_issue_6263(): + e = Eq(x*(-x + 1) + x*(x - 1), 0) + assert cse(e, optimizations='basic') == ([], [True]) +""" + + +def test_dont_cse_tuples(): + from sympy import Subs + f = Function("f") + g = Function("g") + + name_val, (expr,) = cse( + Subs(f(x, y), (x, y), (0, 1)) + + Subs(g(x, y), (x, y), (0, 1))) + + assert name_val == [] + assert expr == (Subs(f(x, y), (x, y), (0, 1)) + + Subs(g(x, y), (x, y), (0, 1))) + + name_val, (expr,) = cse( + Subs(f(x, y), (x, y), (0, x + y)) + + Subs(g(x, y), (x, y), (0, x + y))) + + assert name_val == [(x0, x + y)] + assert expr == Subs(f(x, y), (x, y), (0, x0)) + \ + Subs(g(x, y), (x, y), (0, x0)) + + +def test_pow_invpow(): + assert cse(1/x**2 + x**2) == \ + ([(x0, x**2)], [x0 + 1/x0]) + assert cse(x**2 + (1 + 1/x**2)/x**2) == \ + ([(x0, x**2), (x1, 1/x0)], [x0 + x1*(x1 + 1)]) + assert cse(1/x**2 + (1 + 1/x**2)*x**2) == \ + ([(x0, x**2), (x1, 1/x0)], [x0*(x1 + 1) + x1]) + assert cse(cos(1/x**2) + sin(1/x**2)) == \ + ([(x0, x**(-2))], [sin(x0) + cos(x0)]) + assert cse(cos(x**2) + sin(x**2)) == \ + ([(x0, x**2)], [sin(x0) + cos(x0)]) + assert cse(y/(2 + x**2) + z/x**2/y) == \ + ([(x0, x**2)], [y/(x0 + 2) + z/(x0*y)]) + assert cse(exp(x**2) + x**2*cos(1/x**2)) == \ + ([(x0, x**2)], [x0*cos(1/x0) + exp(x0)]) + assert cse((1 + 1/x**2)/x**2) == \ + ([(x0, x**(-2))], [x0*(x0 + 1)]) + assert cse(x**(2*y) + x**(-2*y)) == \ + ([(x0, x**(2*y))], [x0 + 1/x0]) + + +""" +fixme +def test_postprocess(): + eq = (x + 1 + exp((x + 1)/(y + 1)) + cos(y + 1)) + assert cse([eq, Eq(x, z + 1), z - 2, (z + 1)*(x + 1)], + postprocess=cse_main.cse_separate) == \ + [[(x0, y + 1), (x2, z + 1), (x, x2), (x1, x + 1)], + [x1 + exp(x1/x0) + cos(x0), z - 2, x1*x2]] +""" + + +def test_issue_4499(): + # previously, this gave 16 constants + from sympy.abc import a, b + B = Function('B') + G = Function('G') + t = Tuple(* + (a, a + S(1)/2, 2*a, b, 2*a - b + 1, (sqrt(z)/2)**(-2*a + 1)*B(2*a - + b, sqrt(z))*B(b - 1, sqrt(z))*G(b)*G(2*a - b + 1), + sqrt(z)*(sqrt(z)/2)**(-2*a + 1)*B(b, sqrt(z))*B(2*a - b, + sqrt(z))*G(b)*G(2*a - b + 1), sqrt(z)*(sqrt(z)/2)**(-2*a + 1)*B(b - 1, + sqrt(z))*B(2*a - b + 1, sqrt(z))*G(b)*G(2*a - b + 1), + (sqrt(z)/2)**(-2*a + 1)*B(b, sqrt(z))*B(2*a - b + 1, + sqrt(z))*G(b)*G(2*a - b + 1), 1, 0, S(1)/2, z/2, -b + 1, -2*a + b, + -2*a)) + c = cse(t) + ans = ( + [(x0, 2*a), (x1, -b), (x2, x1 + 1), (x3, x0 + x2), (x4, sqrt(z)), (x5, + B(x0 + x1, x4)), (x6, G(b)), (x7, G(x3)), (x8, -x0), (x9, + (x4/2)**(x8 + 1)), (x10, x6*x7*x9*B(b - 1, x4)), (x11, x6*x7*x9*B(b, + x4)), (x12, B(x3, x4))], [(a, a + S(1)/2, x0, b, x3, x10*x5, + x11*x4*x5, x10*x12*x4, x11*x12, 1, 0, S(1)/2, z/2, x2, b + x8, x8)]) + assert len(c[0]) == 13 + + +def test_issue_6169(): + r = CRootOf(x**6 - 4*x**5 - 2, 1) + assert cse(r) == ([], [r]) + # and a check that the right thing is done with the new + # mechanism + assert sub_post(sub_pre((-x - y)*z - x - y)) == -z*(x + y) - x - y + + +def test_cse_Indexed(): + len_y = 5 + y = IndexedBase('y', shape=(len_y,)) + x = IndexedBase('x', shape=(len_y,)) + Dy = IndexedBase('Dy', shape=(len_y-1,)) + i = Idx('i', len_y-1) + + expr1 = (y[i+1]-y[i])/(x[i+1]-x[i]) + expr2 = 1/(x[i+1]-x[i]) + replacements, reduced_exprs = cse([expr1, expr2]) + assert len(replacements) > 0 + + +def test_Piecewise(): + f = Piecewise((-z + x*y, Eq(y, 0)), (-z - x*y, True)) + ans = cse(f) + actual_ans = ([(x0, -z), (x1, x*y)], [Piecewise((x0+x1, Eq(y, 0)), (x0 - x1, True))]) + assert ans == actual_ans + + +def test_name_conflict(): + z1 = x0 + y + z2 = x2 + x3 + l = [cos(z1) + z1, cos(z2) + z2, x0 + x2] + substs, reduced = cse(l) + assert [e.subs(reversed(substs)) for e in reduced] == l + + +def test_name_conflict_cust_symbols(): + z1 = x0 + y + z2 = x2 + x3 + l = [cos(z1) + z1, cos(z2) + z2, x0 + x2] + substs, reduced = cse(l, symbols("x:10")) + assert [e.subs(reversed(substs)) for e in reduced] == l + + +def test_symbols_exhausted_error(): + l = cos(x+y)+x+y+cos(w+y)+sin(w+y) + sym = [x, y, z] + with pytest.raises(ValueError) as excinfo: + print(cse(l, symbols=sym)) + + +def test_issue_7840(): + # daveknippers' example + C393 = sympify( \ + 'Piecewise((C391 - 1.65, C390 < 0.5), (Piecewise((C391 - 1.65, \ + C391 > 2.35), (C392, True)), True))' + ) + C391 = sympify( \ + 'Piecewise((2.05*C390**(-1.03), C390 < 0.5), (2.5*C390**(-0.625), True))' + ) + C393 = C393.subs('C391',C391) + # simple substitution + sub = {} + sub['C390'] = 0.703451854 + sub['C392'] = 1.01417794 + ss_answer = C393.subs(sub) + # cse + substitutions,new_eqn = cse(C393) + for pair in substitutions: + sub[pair[0].name] = pair[1].subs(sub) + cse_answer = new_eqn[0].subs(sub) + # both methods should be the same + assert ss_answer == cse_answer + + # GitRay's example + expr = sympify( + "Piecewise((Symbol('ON'), Equality(Symbol('mode'), Symbol('ON'))), \ + (Piecewise((Piecewise((Symbol('OFF'), StrictLessThan(Symbol('x'), \ + Symbol('threshold'))), (Symbol('ON'), S.true)), Equality(Symbol('mode'), \ + Symbol('AUTO'))), (Symbol('OFF'), S.true)), S.true))" + ) + substitutions, new_eqn = cse(expr) + # this Piecewise should be exactly the same + assert new_eqn[0] == expr + # there should not be any replacements + assert len(substitutions) < 1 + + +if __name__ == "__main__": + if len(sys.argv) > 1: + exec(sys.argv[1]) + else: + from py.test.cmdline import main + main([__file__]) + +# vim: fdm=marker -- GitLab From c2815afc14bda389ecd1d0774d56ba6b2a37d453 Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Tue, 10 Jan 2017 00:17:20 -0600 Subject: [PATCH 02/10] [wip --- examples/m2l-timing.py | 12 +- sumpy/assignment_collection.py | 6 +- sumpy/cse.py | 319 ++++++++++++++------------------- test/test_cse.py | 105 +++-------- 4 files changed, 178 insertions(+), 264 deletions(-) diff --git a/examples/m2l-timing.py b/examples/m2l-timing.py index f2fd3a26..8596f891 100644 --- a/examples/m2l-timing.py +++ b/examples/m2l-timing.py @@ -14,15 +14,17 @@ def test_m2l_creation(ctx, mpole_expn_class, local_expn_class, knl, order): if __name__ == "__main__": import logging logging.basicConfig(level=logging.INFO) - from sumpy.kernel import LaplaceKernel + from sumpy.kernel import HelmholtzKernel, LaplaceKernel import pyopencl as cl ctx = cl._csc() - from sumpy.expansion.local import LaplaceConformingVolumeTaylorLocalExpansion as LExpn - from sumpy.expansion.multipole import LaplaceConformingVolumeTaylorMultipoleExpansion as MExpn + from sumpy.expansion.local import HelmholtzConformingVolumeTaylorLocalExpansion as LExpn + from sumpy.expansion.multipole import HelmholtzConformingVolumeTaylorMultipoleExpansion as MExpn + #from sumpy.expansion.local import H2DLocalExpansion as LExpn + #from sumpy.expansion.multipole import H2DMultipoleExpansion as MExpn results = [] - for order in range(30, 31): - results.append((order, test_m2l_creation(ctx, MExpn, LExpn, LaplaceKernel(2), order))) + for order in range(1, 2): + results.append((order, test_m2l_creation(ctx, MExpn, LExpn, HelmholtzKernel(2, 'k'), order))) print("order\ttime (s)") for order, time in results: print("{}\t{:.2f}".format(order, time)) diff --git a/sumpy/assignment_collection.py b/sumpy/assignment_collection.py index f7c96b0b..ce19f418 100644 --- a/sumpy/assignment_collection.py +++ b/sumpy/assignment_collection.py @@ -103,10 +103,12 @@ def cached_cse(exprs, symbols): s2p = SympyToPymbolicMapper() p2s = PymbolicToSympyMapper() + print(exprs) key_exprs = tuple(s2p(expr) for expr in exprs) - key = (key_exprs, frozenset(symbols.taken_symbols), - frozenset(symbols.generated_names)) + key = (key_exprs) + + print(key) try: result = cache_dict[key] diff --git a/sumpy/cse.py b/sumpy/cse.py index a7b6496e..9bea6e56 100644 --- a/sumpy/cse.py +++ b/sumpy/cse.py @@ -1,5 +1,3 @@ -""" Tools for doing common subexpression elimination. -""" from __future__ import print_function, division __copyright__ = """ @@ -64,45 +62,23 @@ DAMAGE. # }}} -from sympy.core import Basic, Mul, Add, Pow, sympify, Symbol, Tuple -#from symengine.sympy_compat import (Basic, Mul, Add, Pow, sympify, Symbol) -from sympy.core.singleton import S +from sympy.core import Basic, Mul, Add, Pow, Symbol from sympy.core.function import _coeff_isneg -from sympy.core.exprtools import factor_terms -from sympy.core.compatibility import iterable, range -from sympy.utilities.iterables import filter_symbols, \ - numbered_symbols, sift, topological_sort, ordered - -#from . import cse_opts - -# (preprocessor, postprocessor) pairs which are commonly useful. They should -# each take a sympy expression and return a possibly transformed expression. -# When used in the function ``cse()``, the target expressions will be transformed -# by each of the preprocessor functions in order. After the common -# subexpressions are eliminated, each resulting expression will have the -# postprocessor functions transform them in *reverse* order in order to undo the -# transformation if necessary. This allows the algorithm to operate on -# a representation of the expressions that allows for more optimization -# opportunities. -# ``None`` can be used to specify no transformation for either the preprocessor or -# postprocessor. +from sympy.core.compatibility import iterable +from sympy.utilities.iterables import numbered_symbols + +# {{{ cse pre/postprocessing def preprocess_for_cse(expr, optimizations): - """ Preprocess an expression to optimize for common subexpression - elimination. + """ + Preprocess an expression to optimize for common subexpression elimination. - Parameters - ---------- - expr : sympy expression - The target expression to optimize. - optimizations : list of (callable, callable) pairs - The (preprocessor, postprocessor) pairs. + :arg expr: A sympy expression, the target expression to optimize. + :arg optimizations: A list of (callable, callable) pairs, + the (preprocessor, postprocessor) pairs. - Returns - ------- - expr : sympy expression - The transformed expression. + :return: The transformed expression. """ for pre, post in optimizations: if pre is not None: @@ -111,30 +87,32 @@ def preprocess_for_cse(expr, optimizations): def postprocess_for_cse(expr, optimizations): - """ Postprocess an expression after common subexpression elimination to + """ + Postprocess an expression after common subexpression elimination to return the expression to canonical sympy form. - Parameters - ---------- - expr : sympy expression - The target expression to transform. - optimizations : list of (callable, callable) pairs, optional - The (preprocessor, postprocessor) pairs. The postprocessors will be + :arg expr: sympy expression, the target expression to transform. + :arg optimizations: A list of (callable, callable) pairs (optional), + the (preprocessor, postprocessor) pairs. The postprocessors will be applied in reversed order to undo the effects of the preprocessors correctly. - Returns - ------- - expr : sympy expression - The transformed expression. + :return: The transformed expression. """ for pre, post in reversed(optimizations): if post is not None: expr = post(expr) return expr +# }}} + + +# {{{ opt cse class FuncArgTracker(object): + """ + A class which manages an inverse mapping from arguments to functions. + """ def __init__(self, funcs): # To minimize the number of symbolic comparisons, all function arguments @@ -157,9 +135,16 @@ class FuncArgTracker(object): self.func_to_argset.append(func_argset) def get_args_in_value_order(self, argset): + """ + Return the list of arguments in sorted order according to their value + numbers. + """ return [self.value_number_to_value[argn] for argn in sorted(argset)] def get_or_add_value_number(self, value): + """ + Return the value number for the given argument. + """ nvalues = len(self.value_numbers) value_number = self.value_numbers.setdefault(value, nvalues) if value_number == nvalues: @@ -168,10 +153,17 @@ class FuncArgTracker(object): return value_number def stop_arg_tracking(self, func_i): + """ + Remove the function func_i from the argument to function mapping. + """ for arg in self.func_to_argset[func_i]: self.arg_to_funcset[arg].remove(func_i) def gen_common_arg_candidates(self, argset, min_func_i, threshold=2): + """ + Generate the list of functions which have at least `threshold` arguments in + common from `argset`. + """ from collections import defaultdict count_map = defaultdict(lambda: 0) @@ -188,6 +180,9 @@ class FuncArgTracker(object): yield item def gen_subset_candidates(self, argset, min_func_i): + """ + Generate the list of functions whose set of arguments contains `argset`. + """ iarg = iter(argset) indices = set( @@ -201,6 +196,9 @@ class FuncArgTracker(object): yield item def update_func_argset(self, func_i, new_argset): + """ + Update a function with a new set of arguments. + """ new_args = set(new_argset) old_args = self.func_to_argset[func_i] @@ -213,15 +211,33 @@ class FuncArgTracker(object): self.func_to_argset[func_i].update(new_args) -def match_common_args(Func, funcs, kind, opt_subs): - funcs = list(funcs) +def match_common_args(func_class, funcs, opt_subs): + """ + Recognize and extract common subexpressions of function arguments within a + set of function calls. For instance, for the following function calls:: + + x + z + y + sin(x + y) + + this will extract a common subexpression of `x + y`:: + + w = x + y + w + z + sin(w) + + The function we work with is assumed to be associative and commutative. + + :arg func_class: The function class (e.g. Add, Mul) + :arg funcs: A list of function calls + :arg opt_subs: A dictionary of substitutions which this function may update + """ arg_tracker = FuncArgTracker(funcs) changed = set() for i in range(len(funcs)): for j in arg_tracker.gen_common_arg_candidates( - arg_tracker.func_to_argset[i], i + 1, threshold=2): + arg_tracker.func_to_argset[i], i + 1, threshold=2): com_args = arg_tracker.func_to_argset[i].intersection( arg_tracker.func_to_argset[j]) @@ -231,11 +247,11 @@ def match_common_args(Func, funcs, kind, opt_subs): # combined in a previous iteration. continue - com_func = Func(*arg_tracker.get_args_in_value_order(com_args)) + com_func = func_class(*arg_tracker.get_args_in_value_order(com_args)) com_func_number = arg_tracker.get_or_add_value_number(com_func) - # for all sets, replace the common symbols by the function - # over them, to allow recursive matches + # For all sets, replace the common symbols by the function + # over them, to allow recursive matches. diff_i = arg_tracker.func_to_argset[i].difference(com_args) arg_tracker.update_func_argset(i, diff_i | set([com_func_number])) @@ -251,7 +267,7 @@ def match_common_args(Func, funcs, kind, opt_subs): changed.add(k) if i in changed: - opt_subs[funcs[i]] = Func( + opt_subs[funcs[i]] = func_class( *arg_tracker.get_args_in_value_order(arg_tracker.func_to_argset[i]), evaluate=False) @@ -259,115 +275,102 @@ def match_common_args(Func, funcs, kind, opt_subs): def opt_cse(exprs): - """Find optimization opportunities in Adds, Muls, Pows and negative - coefficient Muls - - Parameters - ---------- - exprs : list of sympy expressions - The expressions to optimize. + """ + Find optimization opportunities in Adds, Muls, Pows and negative coefficient + Muls - Returns - ------- - opt_subs : dictionary of expression substitutions - The expression substitutions which can be useful to optimize CSE. + :arg exprs: A list of sympy expressions: the expressions to optimize. + :return: A dictionary of expression substitutions """ opt_subs = dict() - adds = set() - muls = set() + adds = [] + muls = [] seen_subexp = set() - def _find_opts(expr): + # {{{ look for optimization opportunities, clean up minus signs + + def find_opts(expr): if not isinstance(expr, Basic): return - if expr.is_Atom: # or expr.is_Order: + if expr.is_Atom: return if iterable(expr): - list(map(_find_opts, expr)) + for item in expr: + find_opts(item) return if expr in seen_subexp: return expr - seen_subexp.add(expr) - list(map(_find_opts, expr.args)) + seen_subexp.add(expr) - """ - def _coeff_isneg(expr): - return expr.is_Number and expr < 0 - """ + for arg in expr.args: + find_opts(arg) if _coeff_isneg(expr): neg_expr = -expr if not neg_expr.is_Atom: - opt_subs[expr] = Mul(S.NegativeOne, neg_expr, evaluate=False) + opt_subs[expr] = Mul(-1, neg_expr, evaluate=False) seen_subexp.add(neg_expr) expr = neg_expr if isinstance(expr, Mul): - muls.add(expr) + muls.append(expr) elif isinstance(expr, Add): - adds.add(expr) + adds.append(expr) elif isinstance(expr, Pow): - try: - # symengine - base, exp = expr.args - except ValueError: - # sympy - base = expr.base - exp = expr.exp + base, exp = expr.args if _coeff_isneg(exp): opt_subs[expr] = Pow(Pow(base, -exp), -1, evaluate=False) + # }}} + for e in exprs: if isinstance(e, Basic): - _find_opts(e) + find_opts(e) - ## Process Adds and Muls - - match_common_args(Add, adds, "add", opt_subs) - match_common_args(Mul, muls, "mul", opt_subs) + match_common_args(Add, adds, opt_subs) + match_common_args(Mul, muls, opt_subs) return opt_subs +# }}} -def tree_cse(exprs, symbols, opt_subs=None, order='none'): - """Perform raw CSE on expression tree, taking opt_subs into account. - Parameters - ========== +# {{{ tree cse + +def tree_cse(exprs, symbols, opt_subs=None): + """ + Perform raw CSE on an expression tree, taking opt_subs into account. + + :arg exprs: A list of sympy expressions to reduce + :arg symbols: An infinite iterator yielding unique Symbols used to label + the common subexpressions which are pulled out. + :arg opt_subs: A dictionary of expression substitutions to be + substituted before any CSE action is performed. - exprs : list of sympy expressions - The expressions to reduce. - symbols : infinite iterator yielding unique Symbols - The symbols used to label the common subexpressions which are pulled - out. - opt_subs : dictionary of expression substitutions - The expressions to be substituted before any CSE action is performed. - order : string, 'none' or 'canonical' - The order by which Mul and Add arguments are processed. For large - expressions where speed is a concern, use the setting order='none'. + :return: A pair (replacements, reduced exprs) """ if opt_subs is None: opt_subs = dict() - ## Find repeated sub-expressions + # {{{ find repeated sub-expressions to_eliminate = set() seen_subexp = set() - def _find_repeated(expr): + def find_repeated(expr): if not isinstance(expr, Basic): return - if expr.is_Atom: #or expr.is_Order: + if expr.is_Atom: return if iterable(expr): @@ -385,19 +388,22 @@ def tree_cse(exprs, symbols, opt_subs=None, order='none'): args = expr.args - list(map(_find_repeated, args)) + for arg in args: + find_repeated(arg) + + # }}} for e in exprs: if isinstance(e, Basic): - _find_repeated(e) + find_repeated(e) - ## Rebuild tree + # {{{ rebuild tree replacements = [] subs = dict() - def _rebuild(expr): + def rebuild(expr): if not isinstance(expr, Basic): return expr @@ -405,7 +411,7 @@ def tree_cse(exprs, symbols, opt_subs=None, order='none'): return expr if iterable(expr): - new_args = [_rebuild(arg) for arg in expr] + new_args = [rebuild(arg) for arg in expr] return expr.func(*new_args) if expr in subs: @@ -415,25 +421,8 @@ def tree_cse(exprs, symbols, opt_subs=None, order='none'): if expr in opt_subs: expr = opt_subs[expr] - # If enabled, parse Muls and Adds arguments by order to ensure - # replacement order independent from hashes - if order != 'none': - if isinstance(expr, (Mul, MatMul)): - #c, nc = expr.args_cnc() - #if c == [1]: - # args = nc - #else: - # args = list(ordered(c)) + nc - args = expr.args - elif isinstance(expr, (Add, MatAdd)): - args = list(ordered(expr.args)) - else: - args = expr.args - else: - args = expr.args - - new_args = list(map(_rebuild, args)) - if new_args != args: + new_args = [rebuild(arg) for arg in expr.args] + if new_args != expr.args: new_expr = expr.func(*new_args) else: new_expr = expr @@ -448,50 +437,33 @@ def tree_cse(exprs, symbols, opt_subs=None, order='none'): replacements.append((sym, new_expr)) return sym - else: - return new_expr + return new_expr + + # }}} reduced_exprs = [] for e in exprs: if isinstance(e, Basic): - reduced_e = _rebuild(e) + reduced_e = rebuild(e) else: reduced_e = e reduced_exprs.append(reduced_e) return replacements, reduced_exprs +# }}} + + +def cse(exprs, symbols=None, optimizations=None): + """Perform common subexpression elimination on an expression. -def cse(exprs, symbols=None, optimizations=None, postprocess=None, - order='none'): - """ Perform common subexpression elimination on an expression. - - Parameters - ========== - - exprs : list of sympy expressions, or a single sympy expression - The expressions to reduce. - symbols : infinite iterator yielding unique Symbols - The symbols used to label the common subexpressions which are pulled - out. The ``numbered_symbols`` generator is useful. The default is a - stream of symbols of the form "x0", "x1", etc. This must be an - infinite iterator. - optimizations : list of (callable, callable) pairs - The (preprocessor, postprocessor) pairs of external optimization - functions. Optionally 'basic' can be passed for a set of predefined - basic optimizations. Such 'basic' optimizations were used by default - in old implementation, however they can be really slow on larger - expressions. Now, no pre or post optimizations are made by default. - postprocess : a function which accepts the two return values of cse and - returns the desired form of output from cse, e.g. if you want the - replacements reversed the function might be the following lambda: - lambda r, e: return reversed(r), e - order : string, 'none' or 'canonical' - The order by which Mul and Add arguments are processed. If set to - 'canonical', arguments will be canonically ordered. If set to 'none', - ordering will be faster but dependent on expressions hashes, thus - machine dependent and variable. For large expressions where speed is a - concern, use the setting order='none'. + :arg exprs: A list of sympy expressions, or a single sympy expression to reduce + :arg symbols: An iterator yielding unique Symbols used to label the + common subexpressions which are pulled out. The ``numbered_symbols`` + generator is useful. The default is a stream of symbols of the form + "x0", "x1", etc. This must be an infinite iterator. + :arg optimizations: A list of (callable, callable) pairs consisting of + (preprocessor, postprocessor) pairs of external optimization functions. Returns ======= @@ -503,18 +475,6 @@ def cse(exprs, symbols=None, optimizations=None, postprocess=None, reduced_exprs : list of sympy expressions The reduced expressions with all of the replacements above. - Examples - ======== - - >>> from sympy import cse, SparseMatrix - >>> from sympy.abc import x, y, z, w - >>> cse(((w + x + y + z)*(w + y + z))/(w + x)**3) - ([(x0, y + z), (x1, w + x)], [(w + x0)*(x0 + x1)/x1**3]) - - Note that currently, y + z will not get substituted if -y - z is used. - - >>> cse(((w + x + y + z)*(w - y - z))/(w + x)**3) - ([(x0, w + x)], [(w - y - z)*(x0 + y + z)/x0**3]) """ if isinstance(exprs, Basic): exprs = [exprs] @@ -553,7 +513,7 @@ def cse(exprs, symbols=None, optimizations=None, postprocess=None, logger.info("cse: done after {}".format(time.time() - start)) # Main CSE algorithm. - replacements, reduced_exprs = tree_cse(reduced_exprs, symbols, opt_subs, order) + replacements, reduced_exprs = tree_cse(reduced_exprs, symbols, opt_subs) # Postprocess the expressions to return the expressions to canonical form. for i, (sym, subtree) in enumerate(replacements): @@ -561,7 +521,4 @@ def cse(exprs, symbols=None, optimizations=None, postprocess=None, replacements[i] = (sym, subtree) reduced_exprs = [postprocess_for_cse(e, optimizations) for e in reduced_exprs] - if postprocess is None: - return replacements, reduced_exprs - - return postprocess(replacements, reduced_exprs) + return replacements, reduced_exprs diff --git a/test/test_cse.py b/test/test_cse.py index 185359cb..b6fefdbf 100644 --- a/test/test_cse.py +++ b/test/test_cse.py @@ -67,38 +67,24 @@ import pytest import itertools import sys -from sympy import (Add, Pow, Symbol, exp, sqrt, symbols, sympify, cse, - Matrix, S, cos, sin, Eq, Function, Tuple, CRootOf, - IndexedBase, Idx, Piecewise, O) +from sympy import (Add, Pow, Symbol, exp, sqrt, symbols, sympify, S, cos, + sin, Eq, Function, Tuple, CRootOf, IndexedBase, Idx, + Piecewise) from sympy.simplify.cse_opts import sub_pre, sub_post from sympy.functions.special.hyper import meijerg -from sympy.simplify import cse_main, cse_opts -from sympy.utilities.pytest import XFAIL -from sympy.matrices import (eye, SparseMatrix, MutableDenseMatrix, - MutableSparseMatrix, ImmutableDenseMatrix, ImmutableSparseMatrix) -from sympy.matrices.expressions import MatrixSymbol +from sympy.simplify import cse_opts from sympy.core.compatibility import range -from sumpy.cse import cse +from sumpy.cse import ( + cse, preprocess_for_cse, postprocess_for_cse) w, x, y, z = symbols('w,x,y,z') x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12 = symbols('x:13') -def test_numbered_symbols(): - ns = cse_main.numbered_symbols(prefix='y') - assert list(itertools.islice( - ns, 0, 10)) == [Symbol('y%s' % i) for i in range(0, 10)] - ns = cse_main.numbered_symbols(prefix='y') - assert list(itertools.islice( - ns, 10, 20)) == [Symbol('y%s' % i) for i in range(10, 20)] - ns = cse_main.numbered_symbols() - assert list(itertools.islice( - ns, 0, 10)) == [Symbol('x%s' % i) for i in range(0, 10)] - # Dummy "optimization" functions for testing. @@ -111,21 +97,21 @@ def opt2(expr): def test_preprocess_for_cse(): - assert cse_main.preprocess_for_cse(x, [(opt1, None)]) == x + y - assert cse_main.preprocess_for_cse(x, [(None, opt1)]) == x - assert cse_main.preprocess_for_cse(x, [(None, None)]) == x - assert cse_main.preprocess_for_cse(x, [(opt1, opt2)]) == x + y - assert cse_main.preprocess_for_cse( + assert preprocess_for_cse(x, [(opt1, None)]) == x + y + assert preprocess_for_cse(x, [(None, opt1)]) == x + assert preprocess_for_cse(x, [(None, None)]) == x + assert preprocess_for_cse(x, [(opt1, opt2)]) == x + y + assert preprocess_for_cse( x, [(opt1, None), (opt2, None)]) == (x + y)*z def test_postprocess_for_cse(): - assert cse_main.postprocess_for_cse(x, [(opt1, None)]) == x - assert cse_main.postprocess_for_cse(x, [(None, opt1)]) == x + y - assert cse_main.postprocess_for_cse(x, [(None, None)]) == x - assert cse_main.postprocess_for_cse(x, [(opt1, opt2)]) == x*z + assert postprocess_for_cse(x, [(opt1, None)]) == x + assert postprocess_for_cse(x, [(None, opt1)]) == x + y + assert postprocess_for_cse(x, [(None, None)]) == x + assert postprocess_for_cse(x, [(opt1, opt2)]) == x*z # Note the reverse order of application. - assert cse_main.postprocess_for_cse( + assert postprocess_for_cse( x, [(None, opt1), (None, opt2)]) == x*z + y @@ -206,27 +192,10 @@ def test_multiple_expressions(): ([(x0, x*y)], [x0, z + x0, 3 + x0*z]) -""" -def test_issue_4498(): - assert cse(w/(x - y) + z/(y - x), optimizations='basic') == \ - ([], [(w - z)/(x - y)]) - - -def test_issue_4020(): - assert cse(x**5 + x**4 + x**3 + x**2, optimizations='basic') \ - == ([(x0, x**2)], [x0*(x**3 + x + x0 + 1)]) - - def test_issue_4203(): assert cse(sin(x**x)/x**x) == ([(x0, x**x)], [sin(x0)/x0]) -def test_issue_6263(): - e = Eq(x*(-x + 1) + x*(x - 1), 0) - assert cse(e, optimizations='basic') == ([], [True]) -""" - - def test_dont_cse_tuples(): from sympy import Subs f = Function("f") @@ -270,22 +239,11 @@ def test_pow_invpow(): ([(x0, x**(2*y))], [x0 + 1/x0]) -""" -fixme -def test_postprocess(): - eq = (x + 1 + exp((x + 1)/(y + 1)) + cos(y + 1)) - assert cse([eq, Eq(x, z + 1), z - 2, (z + 1)*(x + 1)], - postprocess=cse_main.cse_separate) == \ - [[(x0, y + 1), (x2, z + 1), (x, x2), (x1, x + 1)], - [x1 + exp(x1/x0) + cos(x0), z - 2, x1*x2]] -""" - - def test_issue_4499(): # previously, this gave 16 constants from sympy.abc import a, b - B = Function('B') - G = Function('G') + B = Function('B') # noqa + G = Function('G') # noqa t = Tuple(* (a, a + S(1)/2, 2*a, b, 2*a - b + 1, (sqrt(z)/2)**(-2*a + 1)*B(2*a - b, sqrt(z))*B(b - 1, sqrt(z))*G(b)*G(2*a - b + 1), @@ -294,14 +252,8 @@ def test_issue_4499(): sqrt(z))*B(2*a - b + 1, sqrt(z))*G(b)*G(2*a - b + 1), (sqrt(z)/2)**(-2*a + 1)*B(b, sqrt(z))*B(2*a - b + 1, sqrt(z))*G(b)*G(2*a - b + 1), 1, 0, S(1)/2, z/2, -b + 1, -2*a + b, - -2*a)) + -2*a)) # noqa c = cse(t) - ans = ( - [(x0, 2*a), (x1, -b), (x2, x1 + 1), (x3, x0 + x2), (x4, sqrt(z)), (x5, - B(x0 + x1, x4)), (x6, G(b)), (x7, G(x3)), (x8, -x0), (x9, - (x4/2)**(x8 + 1)), (x10, x6*x7*x9*B(b - 1, x4)), (x11, x6*x7*x9*B(b, - x4)), (x12, B(x3, x4))], [(a, a + S(1)/2, x0, b, x3, x10*x5, - x11*x4*x5, x10*x12*x4, x11*x12, 1, 0, S(1)/2, z/2, x2, b + x8, x8)]) assert len(c[0]) == 13 @@ -313,11 +265,11 @@ def test_issue_6169(): assert sub_post(sub_pre((-x - y)*z - x - y)) == -z*(x + y) - x - y -def test_cse_Indexed(): +def test_cse_Indexed(): # noqa len_y = 5 y = IndexedBase('y', shape=(len_y,)) x = IndexedBase('x', shape=(len_y,)) - Dy = IndexedBase('Dy', shape=(len_y-1,)) + Dy = IndexedBase('Dy', shape=(len_y-1,)) # noqa i = Idx('i', len_y-1) expr1 = (y[i+1]-y[i])/(x[i+1]-x[i]) @@ -326,10 +278,11 @@ def test_cse_Indexed(): assert len(replacements) > 0 -def test_Piecewise(): +def test_Piecewise(): # noqa f = Piecewise((-z + x*y, Eq(y, 0)), (-z - x*y, True)) ans = cse(f) - actual_ans = ([(x0, -z), (x1, x*y)], [Piecewise((x0+x1, Eq(y, 0)), (x0 - x1, True))]) + actual_ans = ([(x0, -z), (x1, x*y)], + [Piecewise((x0+x1, Eq(y, 0)), (x0 - x1, True))]) assert ans == actual_ans @@ -352,27 +305,27 @@ def test_name_conflict_cust_symbols(): def test_symbols_exhausted_error(): l = cos(x+y)+x+y+cos(w+y)+sin(w+y) sym = [x, y, z] - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ValueError): print(cse(l, symbols=sym)) def test_issue_7840(): # daveknippers' example - C393 = sympify( \ + C393 = sympify( # noqa 'Piecewise((C391 - 1.65, C390 < 0.5), (Piecewise((C391 - 1.65, \ C391 > 2.35), (C392, True)), True))' ) - C391 = sympify( \ + C391 = sympify( # noqa 'Piecewise((2.05*C390**(-1.03), C390 < 0.5), (2.5*C390**(-0.625), True))' ) - C393 = C393.subs('C391',C391) + C393 = C393.subs('C391',C391) # noqa # simple substitution sub = {} sub['C390'] = 0.703451854 sub['C392'] = 1.01417794 ss_answer = C393.subs(sub) # cse - substitutions,new_eqn = cse(C393) + substitutions, new_eqn = cse(C393) for pair in substitutions: sub[pair[0].name] = pair[1].subs(sub) cse_answer = new_eqn[0].subs(sub) -- GitLab From 204cf3fca368e02fa348ed9574f3342a72226c3b Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Tue, 10 Jan 2017 08:08:28 -0600 Subject: [PATCH 03/10] Remove cached_cse. This was not working well anyway, since the caching used the pytools disk dict which assumes that hashes are persistent across runs. --- sumpy/assignment_collection.py | 52 ---------------------------------- 1 file changed, 52 deletions(-) diff --git a/sumpy/assignment_collection.py b/sumpy/assignment_collection.py index ce19f418..af568ddb 100644 --- a/sumpy/assignment_collection.py +++ b/sumpy/assignment_collection.py @@ -71,58 +71,6 @@ class _SymbolGenerator(object): __next__ = next -# {{{ CSE caching - -def _map_cse_result(mapper, cse_result): - replacements, reduced_exprs = cse_result - - new_replacements = [ - (sym, mapper(repl)) - for sym, repl in replacements] - new_reduced_exprs = [ - mapper(expr) - for expr in reduced_exprs] - - return new_replacements, new_reduced_exprs - - -def cached_cse(exprs, symbols): - assert isinstance(symbols, _SymbolGenerator) - - from pytools.diskdict import get_disk_dict - cache_dict = get_disk_dict("sumpy-cse-cache", version=1) - - # sympy expressions don't pickle properly :( - # (as of Jun 7, 2013) - # https://code.google.com/p/sympy/issues/detail?id=1198 - - from pymbolic.interop.sympy import ( - SympyToPymbolicMapper, - PymbolicToSympyMapper) - - s2p = SympyToPymbolicMapper() - p2s = PymbolicToSympyMapper() - - print(exprs) - key_exprs = tuple(s2p(expr) for expr in exprs) - - key = (key_exprs) - - print(key) - - try: - result = cache_dict[key] - except KeyError: - from sumpy.cse import cse - result = cse(exprs, symbols) - cache_dict[key] = _map_cse_result(s2p, result) - return result - else: - return _map_cse_result(p2s, result) - -# }}} - - # {{{ collection of assignments class SymbolicAssignmentCollection(object): -- GitLab From e94735550a20d3f119cce4c224dea4677a1b98b3 Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Tue, 10 Jan 2017 08:16:42 -0600 Subject: [PATCH 04/10] Documentation, various updates. --- doc/codegen.rst | 1 + sumpy/assignment_collection.py | 9 ++++--- sumpy/cse.py | 48 +++++++++++++++++----------------- 3 files changed, 30 insertions(+), 28 deletions(-) diff --git a/doc/codegen.rst b/doc/codegen.rst index 66fc5b2d..2d5d6f4d 100644 --- a/doc/codegen.rst +++ b/doc/codegen.rst @@ -3,3 +3,4 @@ Code Generation .. automodule:: sumpy.codegen .. automodule:: sumpy.assignment_collection +.. automodule:: sumpy.cse diff --git a/sumpy/assignment_collection.py b/sumpy/assignment_collection.py index af568ddb..00c0266b 100644 --- a/sumpy/assignment_collection.py +++ b/sumpy/assignment_collection.py @@ -160,17 +160,18 @@ class SymbolicAssignmentCollection(object): def run_global_cse(self, extra_exprs=[]): logger.info("common subexpression elimination: start") - assign_names = list(self.assignments) + assign_names = sorted(self.assign_names) assign_exprs = [self.assignments[name] for name in assign_names] # Options here: - # - cached_cse: Uses on-disk cache to speed up CSE. # - checked_cse: if you mistrust the result of the cse. # Uses maxima to verify. - # - sp.cse: The underlying sympy thing. + # - sp.cse: The sympy thing. + # - sumpy.cse.cse: Based on sympy, designed to go faster. #from sumpy.symbolic import checked_cse - new_assignments, new_exprs = cached_cse(assign_exprs + extra_exprs, + from sumpy.cse import cse + new_assignments, new_exprs = cse(assign_exprs + extra_exprs, symbols=self.symbol_generator) new_assign_exprs = new_exprs[:len(assign_exprs)] diff --git a/sumpy/cse.py b/sumpy/cse.py index 9bea6e56..466157c8 100644 --- a/sumpy/cse.py +++ b/sumpy/cse.py @@ -68,6 +68,16 @@ from sympy.core.compatibility import iterable from sympy.utilities.iterables import numbered_symbols +__doc__ = """ + +Common subexpression elimination +-------------------------------- + +.. autofunction:: cse + +""" + + # {{{ cse pre/postprocessing def preprocess_for_cse(expr, optimizations): @@ -111,7 +121,8 @@ def postprocess_for_cse(expr, optimizations): class FuncArgTracker(object): """ - A class which manages an inverse mapping from arguments to functions. + A class which manages a mapping from functions to arguments and an inverse + mapping from arguments to functions. """ def __init__(self, funcs): @@ -455,39 +466,28 @@ def tree_cse(exprs, symbols, opt_subs=None): def cse(exprs, symbols=None, optimizations=None): - """Perform common subexpression elimination on an expression. + """ + Perform common subexpression elimination on an expression. :arg exprs: A list of sympy expressions, or a single sympy expression to reduce :arg symbols: An iterator yielding unique Symbols used to label the common subexpressions which are pulled out. The ``numbered_symbols`` - generator is useful. The default is a stream of symbols of the form - "x0", "x1", etc. This must be an infinite iterator. + generator from sympy is useful. The default is a stream of symbols of the + form "x0", "x1", etc. This must be an infinite iterator. :arg optimizations: A list of (callable, callable) pairs consisting of (preprocessor, postprocessor) pairs of external optimization functions. - Returns - ======= - - replacements : list of (Symbol, expression) pairs - All of the common subexpressions that were replaced. Subexpressions - earlier in this list might show up in subexpressions later in this - list. - reduced_exprs : list of sympy expressions - The reduced expressions with all of the replacements above. + :return: This returns a pair ``(replacements, reduced_exprs)``. + * ``replacements`` is a list of (Symbol, expression) pairs consisting of + all of the common subexpressions that were replaced. Subexpressions + earlier in this list might show up in subexpressions later in this list. + * ``reduced_exprs`` is a list of sympy expressions. This contains the + reduced expressions with all of the replacements above. """ if isinstance(exprs, Basic): exprs = [exprs] - import time - - start = time.time() - - import logging - - logger = logging.getLogger(__name__) - logger.info("cse: start") - exprs = list(exprs) if optimizations is None: @@ -510,8 +510,6 @@ def cse(exprs, symbols=None, optimizations=None): # Find other optimization opportunities. opt_subs = opt_cse(reduced_exprs) - logger.info("cse: done after {}".format(time.time() - start)) - # Main CSE algorithm. replacements, reduced_exprs = tree_cse(reduced_exprs, symbols, opt_subs) @@ -522,3 +520,5 @@ def cse(exprs, symbols=None, optimizations=None): reduced_exprs = [postprocess_for_cse(e, optimizations) for e in reduced_exprs] return replacements, reduced_exprs + +# vim: fdm=marker -- GitLab From 8f184c6cb6c176546c1fc90a58c85be664efdc42 Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Tue, 10 Jan 2017 09:06:44 -0600 Subject: [PATCH 05/10] SymbolGenerator: Remove O(n^2) behavior. --- examples/m2l-timing.py | 30 -------------------------- sumpy/assignment_collection.py | 39 +++++++++++++++++----------------- 2 files changed, 20 insertions(+), 49 deletions(-) delete mode 100644 examples/m2l-timing.py diff --git a/examples/m2l-timing.py b/examples/m2l-timing.py deleted file mode 100644 index 8596f891..00000000 --- a/examples/m2l-timing.py +++ /dev/null @@ -1,30 +0,0 @@ - -def test_m2l_creation(ctx, mpole_expn_class, local_expn_class, knl, order): - from sympy.core.cache import clear_cache - clear_cache() - m_expn = mpole_expn_class(knl, order=order) - l_expn = local_expn_class(knl, order=order) - from sumpy.e2e import E2EFromCSR - m2l = E2EFromCSR(ctx, m_expn, l_expn) - import time - start = time.time() - m2l.run_translation_and_cse() - return time.time() - start - -if __name__ == "__main__": - import logging - logging.basicConfig(level=logging.INFO) - from sumpy.kernel import HelmholtzKernel, LaplaceKernel - - import pyopencl as cl - ctx = cl._csc() - from sumpy.expansion.local import HelmholtzConformingVolumeTaylorLocalExpansion as LExpn - from sumpy.expansion.multipole import HelmholtzConformingVolumeTaylorMultipoleExpansion as MExpn - #from sumpy.expansion.local import H2DLocalExpansion as LExpn - #from sumpy.expansion.multipole import H2DMultipoleExpansion as MExpn - results = [] - for order in range(1, 2): - results.append((order, test_m2l_creation(ctx, MExpn, LExpn, HelmholtzKernel(2, 'k'), order))) - print("order\ttime (s)") - for order, time in results: - print("{}\t{:.2f}".format(order, time)) diff --git a/sumpy/assignment_collection.py b/sumpy/assignment_collection.py index 00c0266b..2bfdbc2a 100644 --- a/sumpy/assignment_collection.py +++ b/sumpy/assignment_collection.py @@ -41,26 +41,29 @@ Manipulating batches of assignments """ -def _generate_unique_possibilities(prefix): - yield prefix - - try_num = 0 - while True: - yield "%s_%d" % (prefix, try_num) - try_num += 1 - - class _SymbolGenerator(object): + def __init__(self, taken_symbols): self.taken_symbols = taken_symbols - self.generated_names = set() + from collections import defaultdict + self.base_to_count = defaultdict(lambda: 0) def __call__(self, base="expr"): - for id_str in _generate_unique_possibilities(base): - if id_str not in self.taken_symbols \ - and id_str not in self.generated_names: - self.generated_names.add(id_str) - return sp.Symbol(id_str) + count = self.base_to_count[base] + + def make_id_str(base, count): + return "{base}{suffix}".format( + base=base, + suffix="" if count == 0 else "_" + str(count - 1)) + + id_str = make_id_str(base, count) + while id_str in self.taken_symbols: + count += 1 + id_str = make_id_str(base, count) + + self.base_to_count[base] = count + 1 + + return sp.Symbol(id_str) def __iter__(self): return self @@ -149,9 +152,7 @@ class SymbolicAssignmentCollection(object): """Assign *expr* to a new variable whose name is based on *name_base*. Return the new variable name. """ - for new_name in _generate_unique_possibilities(name_base): - if new_name not in self.assignments: - break + new_name = self.symbol_generator(name_base).name self.add_assignment(new_name, expr) self.user_symbols.add(new_name) @@ -160,7 +161,7 @@ class SymbolicAssignmentCollection(object): def run_global_cse(self, extra_exprs=[]): logger.info("common subexpression elimination: start") - assign_names = sorted(self.assign_names) + assign_names = sorted(self.assignments) assign_exprs = [self.assignments[name] for name in assign_names] # Options here: -- GitLab From 4f9b02ae8a0f743a220063ea24f5e00104a7b2f8 Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Tue, 10 Jan 2017 09:13:52 -0600 Subject: [PATCH 06/10] Add a timing message. --- sumpy/assignment_collection.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sumpy/assignment_collection.py b/sumpy/assignment_collection.py index 2bfdbc2a..840b04da 100644 --- a/sumpy/assignment_collection.py +++ b/sumpy/assignment_collection.py @@ -159,6 +159,9 @@ class SymbolicAssignmentCollection(object): return new_name def run_global_cse(self, extra_exprs=[]): + import time + start_time = time.time() + logger.info("common subexpression elimination: start") assign_names = sorted(self.assignments) @@ -185,7 +188,8 @@ class SymbolicAssignmentCollection(object): assert isinstance(name, sp.Symbol) self.add_assignment(name.name, value) - logger.info("common subexpression elimination: done") + logger.info("common subexpression elimination: done after {dur:.2f} s" + .format(dur=time.time() - start_time)) return new_extra_exprs def kill_trivial_assignments(self, exprs): -- GitLab From add4ee09fc625e3f9a5981e1ff7b32dfea7fd689 Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Tue, 10 Jan 2017 09:20:02 -0600 Subject: [PATCH 07/10] Flake8 fixes. --- sumpy/e2e.py | 17 ----------------- test/test_cse.py | 8 ++------ 2 files changed, 2 insertions(+), 23 deletions(-) diff --git a/sumpy/e2e.py b/sumpy/e2e.py index 7c28c968..f5b34940 100644 --- a/sumpy/e2e.py +++ b/sumpy/e2e.py @@ -76,23 +76,6 @@ class E2EBase(KernelCacheWrapper): self.dim = src_expansion.dim - def run_translation_and_cse(self): - from sumpy.symbolic import make_sympy_vector - dvec = make_sympy_vector("d", self.dim) - - src_coeff_exprs = [sp.Symbol("src_coeff%d" % i) - for i in range(len(self.src_expansion))] - - from sumpy.assignment_collection import SymbolicAssignmentCollection - sac = SymbolicAssignmentCollection() - tgt_coeff_names = [ - sac.assign_unique("coeff%d" % i, coeff_i) - for i, coeff_i in enumerate( - self.tgt_expansion.translate_from( - self.src_expansion, src_coeff_exprs, dvec))] - - sac.run_global_cse() - def get_translation_loopy_insns(self): from sumpy.symbolic import make_sympy_vector dvec = make_sympy_vector("d", self.dim) diff --git a/test/test_cse.py b/test/test_cse.py index b6fefdbf..f5fbcaaf 100644 --- a/test/test_cse.py +++ b/test/test_cse.py @@ -64,18 +64,14 @@ DAMAGE. # }}} import pytest -import itertools import sys -from sympy import (Add, Pow, Symbol, exp, sqrt, symbols, sympify, S, cos, - sin, Eq, Function, Tuple, CRootOf, IndexedBase, Idx, - Piecewise) +from sympy import (Add, Pow, exp, sqrt, symbols, sympify, S, cos, sin, Eq, + Function, Tuple, CRootOf, IndexedBase, Idx, Piecewise) from sympy.simplify.cse_opts import sub_pre, sub_post from sympy.functions.special.hyper import meijerg from sympy.simplify import cse_opts -from sympy.core.compatibility import range - from sumpy.cse import ( cse, preprocess_for_cse, postprocess_for_cse) -- GitLab From 335b822924a8f23f871761088a09b78710541b2c Mon Sep 17 00:00:00 2001 From: repo sync repo start master --all Date: Wed, 11 Jan 2017 02:15:23 -0600 Subject: [PATCH 08/10] CSE: Take into account the common subexpressions to speed up looking for excluded symbols. --- sumpy/cse.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/sumpy/cse.py b/sumpy/cse.py index 466157c8..f9a4daba 100644 --- a/sumpy/cse.py +++ b/sumpy/cse.py @@ -62,7 +62,7 @@ DAMAGE. # }}} -from sympy.core import Basic, Mul, Add, Pow, Symbol +from sympy.core import Basic, Mul, Add, Pow from sympy.core.function import _coeff_isneg from sympy.core.compatibility import iterable from sympy.utilities.iterables import numbered_symbols @@ -371,17 +371,20 @@ def tree_cse(exprs, symbols, opt_subs=None): if opt_subs is None: opt_subs = dict() - # {{{ find repeated sub-expressions + # {{{ find repeated sub-expressions and used symbols to_eliminate = set() seen_subexp = set() + excluded_symbols = set() def find_repeated(expr): if not isinstance(expr, Basic): return if expr.is_Atom: + if expr.is_Symbol: + excluded_symbols.add(expr) return if iterable(expr): @@ -410,6 +413,9 @@ def tree_cse(exprs, symbols, opt_subs=None): # {{{ rebuild tree + # Remove symbols from the generator that conflict with names in the expressions. + symbols = (symbol for symbol in symbols if symbol not in excluded_symbols) + replacements = [] subs = dict() @@ -503,10 +509,6 @@ def cse(exprs, symbols=None, optimizations=None): # an actual iterator. symbols = iter(symbols) - # Remove symbols from the generator that conflict with names in the expressions. - excluded_symbols = set().union(*[expr.atoms(Symbol) for expr in reduced_exprs]) - symbols = (symbol for symbol in symbols if symbol not in excluded_symbols) - # Find other optimization opportunities. opt_subs = opt_cse(reduced_exprs) -- GitLab From e8c38e74f0668b3d40e6923ebb582be5fc88c554 Mon Sep 17 00:00:00 2001 From: repo sync repo start master --all Date: Wed, 11 Jan 2017 04:46:05 -0600 Subject: [PATCH 09/10] CSE: com_args can be empty too. --- sumpy/cse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sumpy/cse.py b/sumpy/cse.py index f9a4daba..01b61e83 100644 --- a/sumpy/cse.py +++ b/sumpy/cse.py @@ -253,7 +253,7 @@ def match_common_args(func_class, funcs, opt_subs): com_args = arg_tracker.func_to_argset[i].intersection( arg_tracker.func_to_argset[j]) - if len(com_args) == 1: + if len(com_args) <= 1: # This may happen if a set of common arguments was already # combined in a previous iteration. continue -- GitLab From ef36a0f535b4fcaae80ca16e815daeb81d41af58 Mon Sep 17 00:00:00 2001 From: repo sync repo start master --all Date: Wed, 11 Jan 2017 04:54:17 -0600 Subject: [PATCH 10/10] Update license part to make modification/original license distinction clear. --- sumpy/cse.py | 6 +++++- test/test_cse.py | 7 +++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/sumpy/cse.py b/sumpy/cse.py index 01b61e83..ca82b971 100644 --- a/sumpy/cse.py +++ b/sumpy/cse.py @@ -8,6 +8,10 @@ Copyright (C) 2006-2016 SymPy Development Team # {{{ license and original license __license__ = """ +Modifications from original are under the following license: + +Copyright (C) 2017 Matt Wala + Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights @@ -28,7 +32,7 @@ THE SOFTWARE. =========================================================================== -Based on sympy/simplify/cse_main.py from SymPy 1.0, original license as follows: +Based on sympy/simplify/cse_main.py from SymPy 1.0, license as follows: Copyright (c) 2006-2016 SymPy Development Team diff --git a/test/test_cse.py b/test/test_cse.py index f5fbcaaf..b5697ac4 100644 --- a/test/test_cse.py +++ b/test/test_cse.py @@ -8,6 +8,10 @@ Copyright (C) 2006-2016 SymPy Development Team # {{{ license and original license __license__ = """ +Modifications from original are under the following license: + +Copyright (C) 2017 Matt Wala + Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights @@ -28,8 +32,7 @@ THE SOFTWARE. =========================================================================== -Based on sympy/simplify/tests/test_cse.py from SymPy 1.0, original license as -follows: +Based on sympy/simplify/tests/test_cse.py from SymPy 1.0, license as follows: Copyright (c) 2006-2016 SymPy Development Team -- GitLab