diff --git a/doc/codegen.rst b/doc/codegen.rst index 66fc5b2dccac2f1f3c16918d7325854004e54d31..2d5d6f4dcdc14f7259610585ffdd1478a807358d 100644 --- a/doc/codegen.rst +++ b/doc/codegen.rst @@ -3,3 +3,4 @@ Code Generation .. automodule:: sumpy.codegen .. automodule:: sumpy.assignment_collection +.. automodule:: sumpy.cse diff --git a/sumpy/assignment_collection.py b/sumpy/assignment_collection.py index f2bc607cb951bc02d8b1d3eb058a924230074f8c..840b04da0c65ad4a425cee47ab88964caf8a424c 100644 --- a/sumpy/assignment_collection.py +++ b/sumpy/assignment_collection.py @@ -41,26 +41,29 @@ Manipulating batches of assignments """ -def _generate_unique_possibilities(prefix): - yield prefix - - try_num = 0 - while True: - yield "%s_%d" % (prefix, try_num) - try_num += 1 - - class _SymbolGenerator(object): + def __init__(self, taken_symbols): self.taken_symbols = taken_symbols - self.generated_names = set() + from collections import defaultdict + self.base_to_count = defaultdict(lambda: 0) def __call__(self, base="expr"): - for id_str in _generate_unique_possibilities(base): - if id_str not in self.taken_symbols \ - and id_str not in self.generated_names: - self.generated_names.add(id_str) - return sp.Symbol(id_str) + count = self.base_to_count[base] + + def make_id_str(base, count): + return "{base}{suffix}".format( + base=base, + suffix="" if count == 0 else "_" + str(count - 1)) + + id_str = make_id_str(base, count) + while id_str in self.taken_symbols: + count += 1 + id_str = make_id_str(base, count) + + self.base_to_count[base] = count + 1 + + return sp.Symbol(id_str) def __iter__(self): return self @@ -71,55 +74,6 @@ class _SymbolGenerator(object): __next__ = next -# {{{ CSE caching - -def _map_cse_result(mapper, cse_result): - replacements, reduced_exprs = cse_result - - new_replacements = [ - (sym, mapper(repl)) - for sym, repl in replacements] - new_reduced_exprs = [ - mapper(expr) - for expr in reduced_exprs] - - return new_replacements, new_reduced_exprs - - -def cached_cse(exprs, symbols): - assert isinstance(symbols, _SymbolGenerator) - - from pytools.diskdict import get_disk_dict - cache_dict = get_disk_dict("sumpy-cse-cache", version=1) - - # sympy expressions don't pickle properly :( - # (as of Jun 7, 2013) - # https://code.google.com/p/sympy/issues/detail?id=1198 - - from pymbolic.interop.sympy import ( - SympyToPymbolicMapper, - PymbolicToSympyMapper) - - s2p = SympyToPymbolicMapper() - p2s = PymbolicToSympyMapper() - - key_exprs = tuple(s2p(expr) for expr in exprs) - - key = (key_exprs, frozenset(symbols.taken_symbols), - frozenset(symbols.generated_names)) - - try: - result = cache_dict[key] - except KeyError: - result = sp.cse(exprs, symbols) - cache_dict[key] = _map_cse_result(s2p, result) - return result - else: - return _map_cse_result(p2s, result) - -# }}} - - # {{{ collection of assignments class SymbolicAssignmentCollection(object): @@ -198,28 +152,30 @@ class SymbolicAssignmentCollection(object): """Assign *expr* to a new variable whose name is based on *name_base*. Return the new variable name. """ - for new_name in _generate_unique_possibilities(name_base): - if new_name not in self.assignments: - break + new_name = self.symbol_generator(name_base).name self.add_assignment(new_name, expr) self.user_symbols.add(new_name) return new_name def run_global_cse(self, extra_exprs=[]): + import time + start_time = time.time() + logger.info("common subexpression elimination: start") - assign_names = list(self.assignments) + assign_names = sorted(self.assignments) assign_exprs = [self.assignments[name] for name in assign_names] # Options here: - # - cached_cse: Uses on-disk cache to speed up CSE. # - checked_cse: if you mistrust the result of the cse. # Uses maxima to verify. - # - sp.cse: The underlying sympy thing. + # - sp.cse: The sympy thing. + # - sumpy.cse.cse: Based on sympy, designed to go faster. #from sumpy.symbolic import checked_cse - new_assignments, new_exprs = cached_cse(assign_exprs + extra_exprs, + from sumpy.cse import cse + new_assignments, new_exprs = cse(assign_exprs + extra_exprs, symbols=self.symbol_generator) new_assign_exprs = new_exprs[:len(assign_exprs)] @@ -232,7 +188,8 @@ class SymbolicAssignmentCollection(object): assert isinstance(name, sp.Symbol) self.add_assignment(name.name, value) - logger.info("common subexpression elimination: done") + logger.info("common subexpression elimination: done after {dur:.2f} s" + .format(dur=time.time() - start_time)) return new_extra_exprs def kill_trivial_assignments(self, exprs): diff --git a/sumpy/cse.py b/sumpy/cse.py new file mode 100644 index 0000000000000000000000000000000000000000..ca82b971c018326d95890bdebfd17e43443849a0 --- /dev/null +++ b/sumpy/cse.py @@ -0,0 +1,530 @@ +from __future__ import print_function, division + +__copyright__ = """ +Copyright (C) 2017 Matt Wala +Copyright (C) 2006-2016 SymPy Development Team +""" + +# {{{ license and original license + +__license__ = """ +Modifications from original are under the following license: + +Copyright (C) 2017 Matt Wala + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. + +=========================================================================== + +Based on sympy/simplify/cse_main.py from SymPy 1.0, license as follows: + +Copyright (c) 2006-2016 SymPy Development Team + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + a. Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + b. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + c. Neither the name of SymPy nor the names of its contributors + may be used to endorse or promote products derived from this software + without specific prior written permission. + + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY +OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH +DAMAGE. +""" + +# }}} + +from sympy.core import Basic, Mul, Add, Pow +from sympy.core.function import _coeff_isneg +from sympy.core.compatibility import iterable +from sympy.utilities.iterables import numbered_symbols + + +__doc__ = """ + +Common subexpression elimination +-------------------------------- + +.. autofunction:: cse + +""" + + +# {{{ cse pre/postprocessing + +def preprocess_for_cse(expr, optimizations): + """ + Preprocess an expression to optimize for common subexpression elimination. + + :arg expr: A sympy expression, the target expression to optimize. + :arg optimizations: A list of (callable, callable) pairs, + the (preprocessor, postprocessor) pairs. + + :return: The transformed expression. + """ + for pre, post in optimizations: + if pre is not None: + expr = pre(expr) + return expr + + +def postprocess_for_cse(expr, optimizations): + """ + Postprocess an expression after common subexpression elimination to + return the expression to canonical sympy form. + + :arg expr: sympy expression, the target expression to transform. + :arg optimizations: A list of (callable, callable) pairs (optional), + the (preprocessor, postprocessor) pairs. The postprocessors will be + applied in reversed order to undo the effects of the preprocessors + correctly. + + :return: The transformed expression. + """ + for pre, post in reversed(optimizations): + if post is not None: + expr = post(expr) + return expr + +# }}} + + +# {{{ opt cse + +class FuncArgTracker(object): + """ + A class which manages a mapping from functions to arguments and an inverse + mapping from arguments to functions. + """ + + def __init__(self, funcs): + # To minimize the number of symbolic comparisons, all function arguments + # get assigned a value number. + self.value_numbers = {} + self.value_number_to_value = [] + + # Both of these maps use integer indices for arguments / functions. + self.arg_to_funcset = [] + self.func_to_argset = [] + + for func_i, func in enumerate(funcs): + func_argset = set() + + for func_arg in func.args: + arg_number = self.get_or_add_value_number(func_arg) + func_argset.add(arg_number) + self.arg_to_funcset[arg_number].add(func_i) + + self.func_to_argset.append(func_argset) + + def get_args_in_value_order(self, argset): + """ + Return the list of arguments in sorted order according to their value + numbers. + """ + return [self.value_number_to_value[argn] for argn in sorted(argset)] + + def get_or_add_value_number(self, value): + """ + Return the value number for the given argument. + """ + nvalues = len(self.value_numbers) + value_number = self.value_numbers.setdefault(value, nvalues) + if value_number == nvalues: + self.value_number_to_value.append(value) + self.arg_to_funcset.append(set()) + return value_number + + def stop_arg_tracking(self, func_i): + """ + Remove the function func_i from the argument to function mapping. + """ + for arg in self.func_to_argset[func_i]: + self.arg_to_funcset[arg].remove(func_i) + + def gen_common_arg_candidates(self, argset, min_func_i, threshold=2): + """ + Generate the list of functions which have at least `threshold` arguments in + common from `argset`. + """ + from collections import defaultdict + count_map = defaultdict(lambda: 0) + + for arg in argset: + for func_i in self.arg_to_funcset[arg]: + if func_i >= min_func_i: + count_map[func_i] += 1 + + count_map_keys_in_order = sorted( + key for key, val in count_map.items() + if val >= threshold) + + for item in count_map_keys_in_order: + yield item + + def gen_subset_candidates(self, argset, min_func_i): + """ + Generate the list of functions whose set of arguments contains `argset`. + """ + iarg = iter(argset) + + indices = set( + fi for fi in self.arg_to_funcset[next(iarg)] + if fi >= min_func_i) + + for arg in iarg: + indices &= self.arg_to_funcset[arg] + + for item in indices: + yield item + + def update_func_argset(self, func_i, new_argset): + """ + Update a function with a new set of arguments. + """ + new_args = set(new_argset) + old_args = self.func_to_argset[func_i] + + for deleted_arg in old_args - new_args: + self.arg_to_funcset[deleted_arg].remove(func_i) + for added_arg in new_args - old_args: + self.arg_to_funcset[added_arg].add(func_i) + + self.func_to_argset[func_i].clear() + self.func_to_argset[func_i].update(new_args) + + +def match_common_args(func_class, funcs, opt_subs): + """ + Recognize and extract common subexpressions of function arguments within a + set of function calls. For instance, for the following function calls:: + + x + z + y + sin(x + y) + + this will extract a common subexpression of `x + y`:: + + w = x + y + w + z + sin(w) + + The function we work with is assumed to be associative and commutative. + + :arg func_class: The function class (e.g. Add, Mul) + :arg funcs: A list of function calls + :arg opt_subs: A dictionary of substitutions which this function may update + """ + arg_tracker = FuncArgTracker(funcs) + + changed = set() + + for i in range(len(funcs)): + for j in arg_tracker.gen_common_arg_candidates( + arg_tracker.func_to_argset[i], i + 1, threshold=2): + + com_args = arg_tracker.func_to_argset[i].intersection( + arg_tracker.func_to_argset[j]) + + if len(com_args) <= 1: + # This may happen if a set of common arguments was already + # combined in a previous iteration. + continue + + com_func = func_class(*arg_tracker.get_args_in_value_order(com_args)) + com_func_number = arg_tracker.get_or_add_value_number(com_func) + + # For all sets, replace the common symbols by the function + # over them, to allow recursive matches. + + diff_i = arg_tracker.func_to_argset[i].difference(com_args) + arg_tracker.update_func_argset(i, diff_i | set([com_func_number])) + changed.add(i) + + diff_j = arg_tracker.func_to_argset[j].difference(com_args) + arg_tracker.update_func_argset(j, diff_j | set([com_func_number])) + changed.add(j) + + for k in arg_tracker.gen_subset_candidates(com_args, j + 1): + diff_k = arg_tracker.func_to_argset[k].difference(com_args) + arg_tracker.update_func_argset(k, diff_k | set([com_func_number])) + changed.add(k) + + if i in changed: + opt_subs[funcs[i]] = func_class( + *arg_tracker.get_args_in_value_order(arg_tracker.func_to_argset[i]), + evaluate=False) + + arg_tracker.stop_arg_tracking(i) + + +def opt_cse(exprs): + """ + Find optimization opportunities in Adds, Muls, Pows and negative coefficient + Muls + + :arg exprs: A list of sympy expressions: the expressions to optimize. + :return: A dictionary of expression substitutions + """ + opt_subs = dict() + + adds = [] + muls = [] + + seen_subexp = set() + + # {{{ look for optimization opportunities, clean up minus signs + + def find_opts(expr): + + if not isinstance(expr, Basic): + return + + if expr.is_Atom: + return + + if iterable(expr): + for item in expr: + find_opts(item) + return + + if expr in seen_subexp: + return expr + + seen_subexp.add(expr) + + for arg in expr.args: + find_opts(arg) + + if _coeff_isneg(expr): + neg_expr = -expr + if not neg_expr.is_Atom: + opt_subs[expr] = Mul(-1, neg_expr, evaluate=False) + seen_subexp.add(neg_expr) + expr = neg_expr + + if isinstance(expr, Mul): + muls.append(expr) + + elif isinstance(expr, Add): + adds.append(expr) + + elif isinstance(expr, Pow): + base, exp = expr.args + if _coeff_isneg(exp): + opt_subs[expr] = Pow(Pow(base, -exp), -1, evaluate=False) + + # }}} + + for e in exprs: + if isinstance(e, Basic): + find_opts(e) + + match_common_args(Add, adds, opt_subs) + match_common_args(Mul, muls, opt_subs) + return opt_subs + +# }}} + + +# {{{ tree cse + +def tree_cse(exprs, symbols, opt_subs=None): + """ + Perform raw CSE on an expression tree, taking opt_subs into account. + + :arg exprs: A list of sympy expressions to reduce + :arg symbols: An infinite iterator yielding unique Symbols used to label + the common subexpressions which are pulled out. + :arg opt_subs: A dictionary of expression substitutions to be + substituted before any CSE action is performed. + + :return: A pair (replacements, reduced exprs) + """ + if opt_subs is None: + opt_subs = dict() + + # {{{ find repeated sub-expressions and used symbols + + to_eliminate = set() + + seen_subexp = set() + excluded_symbols = set() + + def find_repeated(expr): + if not isinstance(expr, Basic): + return + + if expr.is_Atom: + if expr.is_Symbol: + excluded_symbols.add(expr) + return + + if iterable(expr): + args = expr + + else: + if expr in seen_subexp: + to_eliminate.add(expr) + return + + seen_subexp.add(expr) + + if expr in opt_subs: + expr = opt_subs[expr] + + args = expr.args + + for arg in args: + find_repeated(arg) + + # }}} + + for e in exprs: + if isinstance(e, Basic): + find_repeated(e) + + # {{{ rebuild tree + + # Remove symbols from the generator that conflict with names in the expressions. + symbols = (symbol for symbol in symbols if symbol not in excluded_symbols) + + replacements = [] + + subs = dict() + + def rebuild(expr): + if not isinstance(expr, Basic): + return expr + + if not expr.args: + return expr + + if iterable(expr): + new_args = [rebuild(arg) for arg in expr] + return expr.func(*new_args) + + if expr in subs: + return subs[expr] + + orig_expr = expr + if expr in opt_subs: + expr = opt_subs[expr] + + new_args = [rebuild(arg) for arg in expr.args] + if new_args != expr.args: + new_expr = expr.func(*new_args) + else: + new_expr = expr + + if orig_expr in to_eliminate: + try: + sym = next(symbols) + except StopIteration: + raise ValueError("Symbols iterator ran out of symbols.") + + subs[orig_expr] = sym + replacements.append((sym, new_expr)) + return sym + + return new_expr + + # }}} + + reduced_exprs = [] + for e in exprs: + if isinstance(e, Basic): + reduced_e = rebuild(e) + else: + reduced_e = e + reduced_exprs.append(reduced_e) + + return replacements, reduced_exprs + +# }}} + + +def cse(exprs, symbols=None, optimizations=None): + """ + Perform common subexpression elimination on an expression. + + :arg exprs: A list of sympy expressions, or a single sympy expression to reduce + :arg symbols: An iterator yielding unique Symbols used to label the + common subexpressions which are pulled out. The ``numbered_symbols`` + generator from sympy is useful. The default is a stream of symbols of the + form "x0", "x1", etc. This must be an infinite iterator. + :arg optimizations: A list of (callable, callable) pairs consisting of + (preprocessor, postprocessor) pairs of external optimization functions. + + :return: This returns a pair ``(replacements, reduced_exprs)``. + + * ``replacements`` is a list of (Symbol, expression) pairs consisting of + all of the common subexpressions that were replaced. Subexpressions + earlier in this list might show up in subexpressions later in this list. + * ``reduced_exprs`` is a list of sympy expressions. This contains the + reduced expressions with all of the replacements above. + """ + if isinstance(exprs, Basic): + exprs = [exprs] + + exprs = list(exprs) + + if optimizations is None: + optimizations = [] + + # Preprocess the expressions to give us better optimization opportunities. + reduced_exprs = [preprocess_for_cse(e, optimizations) for e in exprs] + + if symbols is None: + symbols = numbered_symbols() + else: + # In case we get passed an iterable with an __iter__ method instead of + # an actual iterator. + symbols = iter(symbols) + + # Find other optimization opportunities. + opt_subs = opt_cse(reduced_exprs) + + # Main CSE algorithm. + replacements, reduced_exprs = tree_cse(reduced_exprs, symbols, opt_subs) + + # Postprocess the expressions to return the expressions to canonical form. + for i, (sym, subtree) in enumerate(replacements): + subtree = postprocess_for_cse(subtree, optimizations) + replacements[i] = (sym, subtree) + reduced_exprs = [postprocess_for_cse(e, optimizations) for e in reduced_exprs] + + return replacements, reduced_exprs + +# vim: fdm=marker diff --git a/test/test_cse.py b/test/test_cse.py new file mode 100644 index 0000000000000000000000000000000000000000..b5697ac4fa2bf67a0f5ed6f8eafa64f9b379059a --- /dev/null +++ b/test/test_cse.py @@ -0,0 +1,355 @@ +from __future__ import print_function, division + +__copyright__ = """ +Copyright (C) 2017 Matt Wala +Copyright (C) 2006-2016 SymPy Development Team +""" + +# {{{ license and original license + +__license__ = """ +Modifications from original are under the following license: + +Copyright (C) 2017 Matt Wala + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. + +=========================================================================== + +Based on sympy/simplify/tests/test_cse.py from SymPy 1.0, license as follows: + +Copyright (c) 2006-2016 SymPy Development Team + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + a. Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + b. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + c. Neither the name of SymPy nor the names of its contributors + may be used to endorse or promote products derived from this software + without specific prior written permission. + + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY +OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH +DAMAGE. +""" + +# }}} + +import pytest +import sys + +from sympy import (Add, Pow, exp, sqrt, symbols, sympify, S, cos, sin, Eq, + Function, Tuple, CRootOf, IndexedBase, Idx, Piecewise) +from sympy.simplify.cse_opts import sub_pre, sub_post +from sympy.functions.special.hyper import meijerg +from sympy.simplify import cse_opts + + +from sumpy.cse import ( + cse, preprocess_for_cse, postprocess_for_cse) + + +w, x, y, z = symbols('w,x,y,z') +x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12 = symbols('x:13') + + +# Dummy "optimization" functions for testing. + + +def opt1(expr): + return expr + y + + +def opt2(expr): + return expr*z + + +def test_preprocess_for_cse(): + assert preprocess_for_cse(x, [(opt1, None)]) == x + y + assert preprocess_for_cse(x, [(None, opt1)]) == x + assert preprocess_for_cse(x, [(None, None)]) == x + assert preprocess_for_cse(x, [(opt1, opt2)]) == x + y + assert preprocess_for_cse( + x, [(opt1, None), (opt2, None)]) == (x + y)*z + + +def test_postprocess_for_cse(): + assert postprocess_for_cse(x, [(opt1, None)]) == x + assert postprocess_for_cse(x, [(None, opt1)]) == x + y + assert postprocess_for_cse(x, [(None, None)]) == x + assert postprocess_for_cse(x, [(opt1, opt2)]) == x*z + # Note the reverse order of application. + assert postprocess_for_cse( + x, [(None, opt1), (None, opt2)]) == x*z + y + + +def test_cse_single(): + # Simple substitution. + e = Add(Pow(x + y, 2), sqrt(x + y)) + substs, reduced = cse([e]) + assert substs == [(x0, x + y)] + assert reduced == [sqrt(x0) + x0**2] + + +def test_cse_not_possible(): + # No substitution possible. + e = Add(x, y) + substs, reduced = cse([e]) + assert substs == [] + assert reduced == [x + y] + # issue 6329 + eq = (meijerg((1, 2), (y, 4), (5,), [], x) + + meijerg((1, 3), (y, 4), (5,), [], x)) + assert cse(eq) == ([], [eq]) + + +def test_nested_substitution(): + # Substitution within a substitution. + e = Add(Pow(w*x + y, 2), sqrt(w*x + y)) + substs, reduced = cse([e]) + assert substs == [(x0, w*x + y)] + assert reduced == [sqrt(x0) + x0**2] + + +def test_subtraction_opt(): + # Make sure subtraction is optimized. + e = (x - y)*(z - y) + exp((x - y)*(z - y)) + substs, reduced = cse( + [e], optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)]) + assert substs == [(x0, (x - y)*(y - z))] + assert reduced == [-x0 + exp(-x0)] + e = -(x - y)*(z - y) + exp(-(x - y)*(z - y)) + substs, reduced = cse( + [e], optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)]) + assert substs == [(x0, (x - y)*(y - z))] + assert reduced == [x0 + exp(x0)] + # issue 4077 + n = -1 + 1/x + e = n/x/(-n)**2 - 1/n/x + assert cse(e, optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)]) == \ + ([], [0]) + + +def test_multiple_expressions(): + e1 = (x + y)*z + e2 = (x + y)*w + substs, reduced = cse([e1, e2]) + assert substs == [(x0, x + y)] + assert reduced == [x0*z, x0*w] + l = [w*x*y + z, w*y] + substs, reduced = cse(l) + rsubsts, _ = cse(reversed(l)) + assert substs == rsubsts + assert reduced == [z + x*x0, x0] + l = [w*x*y, w*x*y + z, w*y] + substs, reduced = cse(l) + rsubsts, _ = cse(reversed(l)) + assert substs == rsubsts + assert reduced == [x1, x1 + z, x0] + l = [(x - z)*(y - z), x - z, y - z] + substs, reduced = cse(l) + rsubsts, _ = cse(reversed(l)) + assert substs == [(x0, -z), (x1, x + x0), (x2, x0 + y)] + assert rsubsts == [(x0, -z), (x1, x0 + y), (x2, x + x0)] + assert reduced == [x1*x2, x1, x2] + l = [w*y + w + x + y + z, w*x*y] + assert cse(l) == ([(x0, w*y)], [w + x + x0 + y + z, x*x0]) + assert cse([x + y, x + y + z]) == ([(x0, x + y)], [x0, z + x0]) + assert cse([x + y, x + z]) == ([], [x + y, x + z]) + assert cse([x*y, z + x*y, x*y*z + 3]) == \ + ([(x0, x*y)], [x0, z + x0, 3 + x0*z]) + + +def test_issue_4203(): + assert cse(sin(x**x)/x**x) == ([(x0, x**x)], [sin(x0)/x0]) + + +def test_dont_cse_tuples(): + from sympy import Subs + f = Function("f") + g = Function("g") + + name_val, (expr,) = cse( + Subs(f(x, y), (x, y), (0, 1)) + + Subs(g(x, y), (x, y), (0, 1))) + + assert name_val == [] + assert expr == (Subs(f(x, y), (x, y), (0, 1)) + + Subs(g(x, y), (x, y), (0, 1))) + + name_val, (expr,) = cse( + Subs(f(x, y), (x, y), (0, x + y)) + + Subs(g(x, y), (x, y), (0, x + y))) + + assert name_val == [(x0, x + y)] + assert expr == Subs(f(x, y), (x, y), (0, x0)) + \ + Subs(g(x, y), (x, y), (0, x0)) + + +def test_pow_invpow(): + assert cse(1/x**2 + x**2) == \ + ([(x0, x**2)], [x0 + 1/x0]) + assert cse(x**2 + (1 + 1/x**2)/x**2) == \ + ([(x0, x**2), (x1, 1/x0)], [x0 + x1*(x1 + 1)]) + assert cse(1/x**2 + (1 + 1/x**2)*x**2) == \ + ([(x0, x**2), (x1, 1/x0)], [x0*(x1 + 1) + x1]) + assert cse(cos(1/x**2) + sin(1/x**2)) == \ + ([(x0, x**(-2))], [sin(x0) + cos(x0)]) + assert cse(cos(x**2) + sin(x**2)) == \ + ([(x0, x**2)], [sin(x0) + cos(x0)]) + assert cse(y/(2 + x**2) + z/x**2/y) == \ + ([(x0, x**2)], [y/(x0 + 2) + z/(x0*y)]) + assert cse(exp(x**2) + x**2*cos(1/x**2)) == \ + ([(x0, x**2)], [x0*cos(1/x0) + exp(x0)]) + assert cse((1 + 1/x**2)/x**2) == \ + ([(x0, x**(-2))], [x0*(x0 + 1)]) + assert cse(x**(2*y) + x**(-2*y)) == \ + ([(x0, x**(2*y))], [x0 + 1/x0]) + + +def test_issue_4499(): + # previously, this gave 16 constants + from sympy.abc import a, b + B = Function('B') # noqa + G = Function('G') # noqa + t = Tuple(* + (a, a + S(1)/2, 2*a, b, 2*a - b + 1, (sqrt(z)/2)**(-2*a + 1)*B(2*a - + b, sqrt(z))*B(b - 1, sqrt(z))*G(b)*G(2*a - b + 1), + sqrt(z)*(sqrt(z)/2)**(-2*a + 1)*B(b, sqrt(z))*B(2*a - b, + sqrt(z))*G(b)*G(2*a - b + 1), sqrt(z)*(sqrt(z)/2)**(-2*a + 1)*B(b - 1, + sqrt(z))*B(2*a - b + 1, sqrt(z))*G(b)*G(2*a - b + 1), + (sqrt(z)/2)**(-2*a + 1)*B(b, sqrt(z))*B(2*a - b + 1, + sqrt(z))*G(b)*G(2*a - b + 1), 1, 0, S(1)/2, z/2, -b + 1, -2*a + b, + -2*a)) # noqa + c = cse(t) + assert len(c[0]) == 13 + + +def test_issue_6169(): + r = CRootOf(x**6 - 4*x**5 - 2, 1) + assert cse(r) == ([], [r]) + # and a check that the right thing is done with the new + # mechanism + assert sub_post(sub_pre((-x - y)*z - x - y)) == -z*(x + y) - x - y + + +def test_cse_Indexed(): # noqa + len_y = 5 + y = IndexedBase('y', shape=(len_y,)) + x = IndexedBase('x', shape=(len_y,)) + Dy = IndexedBase('Dy', shape=(len_y-1,)) # noqa + i = Idx('i', len_y-1) + + expr1 = (y[i+1]-y[i])/(x[i+1]-x[i]) + expr2 = 1/(x[i+1]-x[i]) + replacements, reduced_exprs = cse([expr1, expr2]) + assert len(replacements) > 0 + + +def test_Piecewise(): # noqa + f = Piecewise((-z + x*y, Eq(y, 0)), (-z - x*y, True)) + ans = cse(f) + actual_ans = ([(x0, -z), (x1, x*y)], + [Piecewise((x0+x1, Eq(y, 0)), (x0 - x1, True))]) + assert ans == actual_ans + + +def test_name_conflict(): + z1 = x0 + y + z2 = x2 + x3 + l = [cos(z1) + z1, cos(z2) + z2, x0 + x2] + substs, reduced = cse(l) + assert [e.subs(reversed(substs)) for e in reduced] == l + + +def test_name_conflict_cust_symbols(): + z1 = x0 + y + z2 = x2 + x3 + l = [cos(z1) + z1, cos(z2) + z2, x0 + x2] + substs, reduced = cse(l, symbols("x:10")) + assert [e.subs(reversed(substs)) for e in reduced] == l + + +def test_symbols_exhausted_error(): + l = cos(x+y)+x+y+cos(w+y)+sin(w+y) + sym = [x, y, z] + with pytest.raises(ValueError): + print(cse(l, symbols=sym)) + + +def test_issue_7840(): + # daveknippers' example + C393 = sympify( # noqa + 'Piecewise((C391 - 1.65, C390 < 0.5), (Piecewise((C391 - 1.65, \ + C391 > 2.35), (C392, True)), True))' + ) + C391 = sympify( # noqa + 'Piecewise((2.05*C390**(-1.03), C390 < 0.5), (2.5*C390**(-0.625), True))' + ) + C393 = C393.subs('C391',C391) # noqa + # simple substitution + sub = {} + sub['C390'] = 0.703451854 + sub['C392'] = 1.01417794 + ss_answer = C393.subs(sub) + # cse + substitutions, new_eqn = cse(C393) + for pair in substitutions: + sub[pair[0].name] = pair[1].subs(sub) + cse_answer = new_eqn[0].subs(sub) + # both methods should be the same + assert ss_answer == cse_answer + + # GitRay's example + expr = sympify( + "Piecewise((Symbol('ON'), Equality(Symbol('mode'), Symbol('ON'))), \ + (Piecewise((Piecewise((Symbol('OFF'), StrictLessThan(Symbol('x'), \ + Symbol('threshold'))), (Symbol('ON'), S.true)), Equality(Symbol('mode'), \ + Symbol('AUTO'))), (Symbol('OFF'), S.true)), S.true))" + ) + substitutions, new_eqn = cse(expr) + # this Piecewise should be exactly the same + assert new_eqn[0] == expr + # there should not be any replacements + assert len(substitutions) < 1 + + +if __name__ == "__main__": + if len(sys.argv) > 1: + exec(sys.argv[1]) + else: + from py.test.cmdline import main + main([__file__]) + +# vim: fdm=marker