From c2815afc14bda389ecd1d0774d56ba6b2a37d453 Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Tue, 10 Jan 2017 00:17:20 -0600 Subject: [PATCH] [wip --- examples/m2l-timing.py | 12 +- sumpy/assignment_collection.py | 6 +- sumpy/cse.py | 319 ++++++++++++++------------------- test/test_cse.py | 105 +++-------- 4 files changed, 178 insertions(+), 264 deletions(-) diff --git a/examples/m2l-timing.py b/examples/m2l-timing.py index f2fd3a26..8596f891 100644 --- a/examples/m2l-timing.py +++ b/examples/m2l-timing.py @@ -14,15 +14,17 @@ def test_m2l_creation(ctx, mpole_expn_class, local_expn_class, knl, order): if __name__ == "__main__": import logging logging.basicConfig(level=logging.INFO) - from sumpy.kernel import LaplaceKernel + from sumpy.kernel import HelmholtzKernel, LaplaceKernel import pyopencl as cl ctx = cl._csc() - from sumpy.expansion.local import LaplaceConformingVolumeTaylorLocalExpansion as LExpn - from sumpy.expansion.multipole import LaplaceConformingVolumeTaylorMultipoleExpansion as MExpn + from sumpy.expansion.local import HelmholtzConformingVolumeTaylorLocalExpansion as LExpn + from sumpy.expansion.multipole import HelmholtzConformingVolumeTaylorMultipoleExpansion as MExpn + #from sumpy.expansion.local import H2DLocalExpansion as LExpn + #from sumpy.expansion.multipole import H2DMultipoleExpansion as MExpn results = [] - for order in range(30, 31): - results.append((order, test_m2l_creation(ctx, MExpn, LExpn, LaplaceKernel(2), order))) + for order in range(1, 2): + results.append((order, test_m2l_creation(ctx, MExpn, LExpn, HelmholtzKernel(2, 'k'), order))) print("order\ttime (s)") for order, time in results: print("{}\t{:.2f}".format(order, time)) diff --git a/sumpy/assignment_collection.py b/sumpy/assignment_collection.py index f7c96b0b..ce19f418 100644 --- a/sumpy/assignment_collection.py +++ b/sumpy/assignment_collection.py @@ -103,10 +103,12 @@ def cached_cse(exprs, symbols): s2p = SympyToPymbolicMapper() p2s = PymbolicToSympyMapper() + print(exprs) key_exprs = tuple(s2p(expr) for expr in exprs) - key = (key_exprs, frozenset(symbols.taken_symbols), - frozenset(symbols.generated_names)) + key = (key_exprs) + + print(key) try: result = cache_dict[key] diff --git a/sumpy/cse.py b/sumpy/cse.py index a7b6496e..9bea6e56 100644 --- a/sumpy/cse.py +++ b/sumpy/cse.py @@ -1,5 +1,3 @@ -""" Tools for doing common subexpression elimination. -""" from __future__ import print_function, division __copyright__ = """ @@ -64,45 +62,23 @@ DAMAGE. # }}} -from sympy.core import Basic, Mul, Add, Pow, sympify, Symbol, Tuple -#from symengine.sympy_compat import (Basic, Mul, Add, Pow, sympify, Symbol) -from sympy.core.singleton import S +from sympy.core import Basic, Mul, Add, Pow, Symbol from sympy.core.function import _coeff_isneg -from sympy.core.exprtools import factor_terms -from sympy.core.compatibility import iterable, range -from sympy.utilities.iterables import filter_symbols, \ - numbered_symbols, sift, topological_sort, ordered - -#from . import cse_opts - -# (preprocessor, postprocessor) pairs which are commonly useful. They should -# each take a sympy expression and return a possibly transformed expression. -# When used in the function ``cse()``, the target expressions will be transformed -# by each of the preprocessor functions in order. After the common -# subexpressions are eliminated, each resulting expression will have the -# postprocessor functions transform them in *reverse* order in order to undo the -# transformation if necessary. This allows the algorithm to operate on -# a representation of the expressions that allows for more optimization -# opportunities. -# ``None`` can be used to specify no transformation for either the preprocessor or -# postprocessor. +from sympy.core.compatibility import iterable +from sympy.utilities.iterables import numbered_symbols + +# {{{ cse pre/postprocessing def preprocess_for_cse(expr, optimizations): - """ Preprocess an expression to optimize for common subexpression - elimination. + """ + Preprocess an expression to optimize for common subexpression elimination. - Parameters - ---------- - expr : sympy expression - The target expression to optimize. - optimizations : list of (callable, callable) pairs - The (preprocessor, postprocessor) pairs. + :arg expr: A sympy expression, the target expression to optimize. + :arg optimizations: A list of (callable, callable) pairs, + the (preprocessor, postprocessor) pairs. - Returns - ------- - expr : sympy expression - The transformed expression. + :return: The transformed expression. """ for pre, post in optimizations: if pre is not None: @@ -111,30 +87,32 @@ def preprocess_for_cse(expr, optimizations): def postprocess_for_cse(expr, optimizations): - """ Postprocess an expression after common subexpression elimination to + """ + Postprocess an expression after common subexpression elimination to return the expression to canonical sympy form. - Parameters - ---------- - expr : sympy expression - The target expression to transform. - optimizations : list of (callable, callable) pairs, optional - The (preprocessor, postprocessor) pairs. The postprocessors will be + :arg expr: sympy expression, the target expression to transform. + :arg optimizations: A list of (callable, callable) pairs (optional), + the (preprocessor, postprocessor) pairs. The postprocessors will be applied in reversed order to undo the effects of the preprocessors correctly. - Returns - ------- - expr : sympy expression - The transformed expression. + :return: The transformed expression. """ for pre, post in reversed(optimizations): if post is not None: expr = post(expr) return expr +# }}} + + +# {{{ opt cse class FuncArgTracker(object): + """ + A class which manages an inverse mapping from arguments to functions. + """ def __init__(self, funcs): # To minimize the number of symbolic comparisons, all function arguments @@ -157,9 +135,16 @@ class FuncArgTracker(object): self.func_to_argset.append(func_argset) def get_args_in_value_order(self, argset): + """ + Return the list of arguments in sorted order according to their value + numbers. + """ return [self.value_number_to_value[argn] for argn in sorted(argset)] def get_or_add_value_number(self, value): + """ + Return the value number for the given argument. + """ nvalues = len(self.value_numbers) value_number = self.value_numbers.setdefault(value, nvalues) if value_number == nvalues: @@ -168,10 +153,17 @@ class FuncArgTracker(object): return value_number def stop_arg_tracking(self, func_i): + """ + Remove the function func_i from the argument to function mapping. + """ for arg in self.func_to_argset[func_i]: self.arg_to_funcset[arg].remove(func_i) def gen_common_arg_candidates(self, argset, min_func_i, threshold=2): + """ + Generate the list of functions which have at least `threshold` arguments in + common from `argset`. + """ from collections import defaultdict count_map = defaultdict(lambda: 0) @@ -188,6 +180,9 @@ class FuncArgTracker(object): yield item def gen_subset_candidates(self, argset, min_func_i): + """ + Generate the list of functions whose set of arguments contains `argset`. + """ iarg = iter(argset) indices = set( @@ -201,6 +196,9 @@ class FuncArgTracker(object): yield item def update_func_argset(self, func_i, new_argset): + """ + Update a function with a new set of arguments. + """ new_args = set(new_argset) old_args = self.func_to_argset[func_i] @@ -213,15 +211,33 @@ class FuncArgTracker(object): self.func_to_argset[func_i].update(new_args) -def match_common_args(Func, funcs, kind, opt_subs): - funcs = list(funcs) +def match_common_args(func_class, funcs, opt_subs): + """ + Recognize and extract common subexpressions of function arguments within a + set of function calls. For instance, for the following function calls:: + + x + z + y + sin(x + y) + + this will extract a common subexpression of `x + y`:: + + w = x + y + w + z + sin(w) + + The function we work with is assumed to be associative and commutative. + + :arg func_class: The function class (e.g. Add, Mul) + :arg funcs: A list of function calls + :arg opt_subs: A dictionary of substitutions which this function may update + """ arg_tracker = FuncArgTracker(funcs) changed = set() for i in range(len(funcs)): for j in arg_tracker.gen_common_arg_candidates( - arg_tracker.func_to_argset[i], i + 1, threshold=2): + arg_tracker.func_to_argset[i], i + 1, threshold=2): com_args = arg_tracker.func_to_argset[i].intersection( arg_tracker.func_to_argset[j]) @@ -231,11 +247,11 @@ def match_common_args(Func, funcs, kind, opt_subs): # combined in a previous iteration. continue - com_func = Func(*arg_tracker.get_args_in_value_order(com_args)) + com_func = func_class(*arg_tracker.get_args_in_value_order(com_args)) com_func_number = arg_tracker.get_or_add_value_number(com_func) - # for all sets, replace the common symbols by the function - # over them, to allow recursive matches + # For all sets, replace the common symbols by the function + # over them, to allow recursive matches. diff_i = arg_tracker.func_to_argset[i].difference(com_args) arg_tracker.update_func_argset(i, diff_i | set([com_func_number])) @@ -251,7 +267,7 @@ def match_common_args(Func, funcs, kind, opt_subs): changed.add(k) if i in changed: - opt_subs[funcs[i]] = Func( + opt_subs[funcs[i]] = func_class( *arg_tracker.get_args_in_value_order(arg_tracker.func_to_argset[i]), evaluate=False) @@ -259,115 +275,102 @@ def match_common_args(Func, funcs, kind, opt_subs): def opt_cse(exprs): - """Find optimization opportunities in Adds, Muls, Pows and negative - coefficient Muls - - Parameters - ---------- - exprs : list of sympy expressions - The expressions to optimize. + """ + Find optimization opportunities in Adds, Muls, Pows and negative coefficient + Muls - Returns - ------- - opt_subs : dictionary of expression substitutions - The expression substitutions which can be useful to optimize CSE. + :arg exprs: A list of sympy expressions: the expressions to optimize. + :return: A dictionary of expression substitutions """ opt_subs = dict() - adds = set() - muls = set() + adds = [] + muls = [] seen_subexp = set() - def _find_opts(expr): + # {{{ look for optimization opportunities, clean up minus signs + + def find_opts(expr): if not isinstance(expr, Basic): return - if expr.is_Atom: # or expr.is_Order: + if expr.is_Atom: return if iterable(expr): - list(map(_find_opts, expr)) + for item in expr: + find_opts(item) return if expr in seen_subexp: return expr - seen_subexp.add(expr) - list(map(_find_opts, expr.args)) + seen_subexp.add(expr) - """ - def _coeff_isneg(expr): - return expr.is_Number and expr < 0 - """ + for arg in expr.args: + find_opts(arg) if _coeff_isneg(expr): neg_expr = -expr if not neg_expr.is_Atom: - opt_subs[expr] = Mul(S.NegativeOne, neg_expr, evaluate=False) + opt_subs[expr] = Mul(-1, neg_expr, evaluate=False) seen_subexp.add(neg_expr) expr = neg_expr if isinstance(expr, Mul): - muls.add(expr) + muls.append(expr) elif isinstance(expr, Add): - adds.add(expr) + adds.append(expr) elif isinstance(expr, Pow): - try: - # symengine - base, exp = expr.args - except ValueError: - # sympy - base = expr.base - exp = expr.exp + base, exp = expr.args if _coeff_isneg(exp): opt_subs[expr] = Pow(Pow(base, -exp), -1, evaluate=False) + # }}} + for e in exprs: if isinstance(e, Basic): - _find_opts(e) + find_opts(e) - ## Process Adds and Muls - - match_common_args(Add, adds, "add", opt_subs) - match_common_args(Mul, muls, "mul", opt_subs) + match_common_args(Add, adds, opt_subs) + match_common_args(Mul, muls, opt_subs) return opt_subs +# }}} -def tree_cse(exprs, symbols, opt_subs=None, order='none'): - """Perform raw CSE on expression tree, taking opt_subs into account. - Parameters - ========== +# {{{ tree cse + +def tree_cse(exprs, symbols, opt_subs=None): + """ + Perform raw CSE on an expression tree, taking opt_subs into account. + + :arg exprs: A list of sympy expressions to reduce + :arg symbols: An infinite iterator yielding unique Symbols used to label + the common subexpressions which are pulled out. + :arg opt_subs: A dictionary of expression substitutions to be + substituted before any CSE action is performed. - exprs : list of sympy expressions - The expressions to reduce. - symbols : infinite iterator yielding unique Symbols - The symbols used to label the common subexpressions which are pulled - out. - opt_subs : dictionary of expression substitutions - The expressions to be substituted before any CSE action is performed. - order : string, 'none' or 'canonical' - The order by which Mul and Add arguments are processed. For large - expressions where speed is a concern, use the setting order='none'. + :return: A pair (replacements, reduced exprs) """ if opt_subs is None: opt_subs = dict() - ## Find repeated sub-expressions + # {{{ find repeated sub-expressions to_eliminate = set() seen_subexp = set() - def _find_repeated(expr): + def find_repeated(expr): if not isinstance(expr, Basic): return - if expr.is_Atom: #or expr.is_Order: + if expr.is_Atom: return if iterable(expr): @@ -385,19 +388,22 @@ def tree_cse(exprs, symbols, opt_subs=None, order='none'): args = expr.args - list(map(_find_repeated, args)) + for arg in args: + find_repeated(arg) + + # }}} for e in exprs: if isinstance(e, Basic): - _find_repeated(e) + find_repeated(e) - ## Rebuild tree + # {{{ rebuild tree replacements = [] subs = dict() - def _rebuild(expr): + def rebuild(expr): if not isinstance(expr, Basic): return expr @@ -405,7 +411,7 @@ def tree_cse(exprs, symbols, opt_subs=None, order='none'): return expr if iterable(expr): - new_args = [_rebuild(arg) for arg in expr] + new_args = [rebuild(arg) for arg in expr] return expr.func(*new_args) if expr in subs: @@ -415,25 +421,8 @@ def tree_cse(exprs, symbols, opt_subs=None, order='none'): if expr in opt_subs: expr = opt_subs[expr] - # If enabled, parse Muls and Adds arguments by order to ensure - # replacement order independent from hashes - if order != 'none': - if isinstance(expr, (Mul, MatMul)): - #c, nc = expr.args_cnc() - #if c == [1]: - # args = nc - #else: - # args = list(ordered(c)) + nc - args = expr.args - elif isinstance(expr, (Add, MatAdd)): - args = list(ordered(expr.args)) - else: - args = expr.args - else: - args = expr.args - - new_args = list(map(_rebuild, args)) - if new_args != args: + new_args = [rebuild(arg) for arg in expr.args] + if new_args != expr.args: new_expr = expr.func(*new_args) else: new_expr = expr @@ -448,50 +437,33 @@ def tree_cse(exprs, symbols, opt_subs=None, order='none'): replacements.append((sym, new_expr)) return sym - else: - return new_expr + return new_expr + + # }}} reduced_exprs = [] for e in exprs: if isinstance(e, Basic): - reduced_e = _rebuild(e) + reduced_e = rebuild(e) else: reduced_e = e reduced_exprs.append(reduced_e) return replacements, reduced_exprs +# }}} + + +def cse(exprs, symbols=None, optimizations=None): + """Perform common subexpression elimination on an expression. -def cse(exprs, symbols=None, optimizations=None, postprocess=None, - order='none'): - """ Perform common subexpression elimination on an expression. - - Parameters - ========== - - exprs : list of sympy expressions, or a single sympy expression - The expressions to reduce. - symbols : infinite iterator yielding unique Symbols - The symbols used to label the common subexpressions which are pulled - out. The ``numbered_symbols`` generator is useful. The default is a - stream of symbols of the form "x0", "x1", etc. This must be an - infinite iterator. - optimizations : list of (callable, callable) pairs - The (preprocessor, postprocessor) pairs of external optimization - functions. Optionally 'basic' can be passed for a set of predefined - basic optimizations. Such 'basic' optimizations were used by default - in old implementation, however they can be really slow on larger - expressions. Now, no pre or post optimizations are made by default. - postprocess : a function which accepts the two return values of cse and - returns the desired form of output from cse, e.g. if you want the - replacements reversed the function might be the following lambda: - lambda r, e: return reversed(r), e - order : string, 'none' or 'canonical' - The order by which Mul and Add arguments are processed. If set to - 'canonical', arguments will be canonically ordered. If set to 'none', - ordering will be faster but dependent on expressions hashes, thus - machine dependent and variable. For large expressions where speed is a - concern, use the setting order='none'. + :arg exprs: A list of sympy expressions, or a single sympy expression to reduce + :arg symbols: An iterator yielding unique Symbols used to label the + common subexpressions which are pulled out. The ``numbered_symbols`` + generator is useful. The default is a stream of symbols of the form + "x0", "x1", etc. This must be an infinite iterator. + :arg optimizations: A list of (callable, callable) pairs consisting of + (preprocessor, postprocessor) pairs of external optimization functions. Returns ======= @@ -503,18 +475,6 @@ def cse(exprs, symbols=None, optimizations=None, postprocess=None, reduced_exprs : list of sympy expressions The reduced expressions with all of the replacements above. - Examples - ======== - - >>> from sympy import cse, SparseMatrix - >>> from sympy.abc import x, y, z, w - >>> cse(((w + x + y + z)*(w + y + z))/(w + x)**3) - ([(x0, y + z), (x1, w + x)], [(w + x0)*(x0 + x1)/x1**3]) - - Note that currently, y + z will not get substituted if -y - z is used. - - >>> cse(((w + x + y + z)*(w - y - z))/(w + x)**3) - ([(x0, w + x)], [(w - y - z)*(x0 + y + z)/x0**3]) """ if isinstance(exprs, Basic): exprs = [exprs] @@ -553,7 +513,7 @@ def cse(exprs, symbols=None, optimizations=None, postprocess=None, logger.info("cse: done after {}".format(time.time() - start)) # Main CSE algorithm. - replacements, reduced_exprs = tree_cse(reduced_exprs, symbols, opt_subs, order) + replacements, reduced_exprs = tree_cse(reduced_exprs, symbols, opt_subs) # Postprocess the expressions to return the expressions to canonical form. for i, (sym, subtree) in enumerate(replacements): @@ -561,7 +521,4 @@ def cse(exprs, symbols=None, optimizations=None, postprocess=None, replacements[i] = (sym, subtree) reduced_exprs = [postprocess_for_cse(e, optimizations) for e in reduced_exprs] - if postprocess is None: - return replacements, reduced_exprs - - return postprocess(replacements, reduced_exprs) + return replacements, reduced_exprs diff --git a/test/test_cse.py b/test/test_cse.py index 185359cb..b6fefdbf 100644 --- a/test/test_cse.py +++ b/test/test_cse.py @@ -67,38 +67,24 @@ import pytest import itertools import sys -from sympy import (Add, Pow, Symbol, exp, sqrt, symbols, sympify, cse, - Matrix, S, cos, sin, Eq, Function, Tuple, CRootOf, - IndexedBase, Idx, Piecewise, O) +from sympy import (Add, Pow, Symbol, exp, sqrt, symbols, sympify, S, cos, + sin, Eq, Function, Tuple, CRootOf, IndexedBase, Idx, + Piecewise) from sympy.simplify.cse_opts import sub_pre, sub_post from sympy.functions.special.hyper import meijerg -from sympy.simplify import cse_main, cse_opts -from sympy.utilities.pytest import XFAIL -from sympy.matrices import (eye, SparseMatrix, MutableDenseMatrix, - MutableSparseMatrix, ImmutableDenseMatrix, ImmutableSparseMatrix) -from sympy.matrices.expressions import MatrixSymbol +from sympy.simplify import cse_opts from sympy.core.compatibility import range -from sumpy.cse import cse +from sumpy.cse import ( + cse, preprocess_for_cse, postprocess_for_cse) w, x, y, z = symbols('w,x,y,z') x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12 = symbols('x:13') -def test_numbered_symbols(): - ns = cse_main.numbered_symbols(prefix='y') - assert list(itertools.islice( - ns, 0, 10)) == [Symbol('y%s' % i) for i in range(0, 10)] - ns = cse_main.numbered_symbols(prefix='y') - assert list(itertools.islice( - ns, 10, 20)) == [Symbol('y%s' % i) for i in range(10, 20)] - ns = cse_main.numbered_symbols() - assert list(itertools.islice( - ns, 0, 10)) == [Symbol('x%s' % i) for i in range(0, 10)] - # Dummy "optimization" functions for testing. @@ -111,21 +97,21 @@ def opt2(expr): def test_preprocess_for_cse(): - assert cse_main.preprocess_for_cse(x, [(opt1, None)]) == x + y - assert cse_main.preprocess_for_cse(x, [(None, opt1)]) == x - assert cse_main.preprocess_for_cse(x, [(None, None)]) == x - assert cse_main.preprocess_for_cse(x, [(opt1, opt2)]) == x + y - assert cse_main.preprocess_for_cse( + assert preprocess_for_cse(x, [(opt1, None)]) == x + y + assert preprocess_for_cse(x, [(None, opt1)]) == x + assert preprocess_for_cse(x, [(None, None)]) == x + assert preprocess_for_cse(x, [(opt1, opt2)]) == x + y + assert preprocess_for_cse( x, [(opt1, None), (opt2, None)]) == (x + y)*z def test_postprocess_for_cse(): - assert cse_main.postprocess_for_cse(x, [(opt1, None)]) == x - assert cse_main.postprocess_for_cse(x, [(None, opt1)]) == x + y - assert cse_main.postprocess_for_cse(x, [(None, None)]) == x - assert cse_main.postprocess_for_cse(x, [(opt1, opt2)]) == x*z + assert postprocess_for_cse(x, [(opt1, None)]) == x + assert postprocess_for_cse(x, [(None, opt1)]) == x + y + assert postprocess_for_cse(x, [(None, None)]) == x + assert postprocess_for_cse(x, [(opt1, opt2)]) == x*z # Note the reverse order of application. - assert cse_main.postprocess_for_cse( + assert postprocess_for_cse( x, [(None, opt1), (None, opt2)]) == x*z + y @@ -206,27 +192,10 @@ def test_multiple_expressions(): ([(x0, x*y)], [x0, z + x0, 3 + x0*z]) -""" -def test_issue_4498(): - assert cse(w/(x - y) + z/(y - x), optimizations='basic') == \ - ([], [(w - z)/(x - y)]) - - -def test_issue_4020(): - assert cse(x**5 + x**4 + x**3 + x**2, optimizations='basic') \ - == ([(x0, x**2)], [x0*(x**3 + x + x0 + 1)]) - - def test_issue_4203(): assert cse(sin(x**x)/x**x) == ([(x0, x**x)], [sin(x0)/x0]) -def test_issue_6263(): - e = Eq(x*(-x + 1) + x*(x - 1), 0) - assert cse(e, optimizations='basic') == ([], [True]) -""" - - def test_dont_cse_tuples(): from sympy import Subs f = Function("f") @@ -270,22 +239,11 @@ def test_pow_invpow(): ([(x0, x**(2*y))], [x0 + 1/x0]) -""" -fixme -def test_postprocess(): - eq = (x + 1 + exp((x + 1)/(y + 1)) + cos(y + 1)) - assert cse([eq, Eq(x, z + 1), z - 2, (z + 1)*(x + 1)], - postprocess=cse_main.cse_separate) == \ - [[(x0, y + 1), (x2, z + 1), (x, x2), (x1, x + 1)], - [x1 + exp(x1/x0) + cos(x0), z - 2, x1*x2]] -""" - - def test_issue_4499(): # previously, this gave 16 constants from sympy.abc import a, b - B = Function('B') - G = Function('G') + B = Function('B') # noqa + G = Function('G') # noqa t = Tuple(* (a, a + S(1)/2, 2*a, b, 2*a - b + 1, (sqrt(z)/2)**(-2*a + 1)*B(2*a - b, sqrt(z))*B(b - 1, sqrt(z))*G(b)*G(2*a - b + 1), @@ -294,14 +252,8 @@ def test_issue_4499(): sqrt(z))*B(2*a - b + 1, sqrt(z))*G(b)*G(2*a - b + 1), (sqrt(z)/2)**(-2*a + 1)*B(b, sqrt(z))*B(2*a - b + 1, sqrt(z))*G(b)*G(2*a - b + 1), 1, 0, S(1)/2, z/2, -b + 1, -2*a + b, - -2*a)) + -2*a)) # noqa c = cse(t) - ans = ( - [(x0, 2*a), (x1, -b), (x2, x1 + 1), (x3, x0 + x2), (x4, sqrt(z)), (x5, - B(x0 + x1, x4)), (x6, G(b)), (x7, G(x3)), (x8, -x0), (x9, - (x4/2)**(x8 + 1)), (x10, x6*x7*x9*B(b - 1, x4)), (x11, x6*x7*x9*B(b, - x4)), (x12, B(x3, x4))], [(a, a + S(1)/2, x0, b, x3, x10*x5, - x11*x4*x5, x10*x12*x4, x11*x12, 1, 0, S(1)/2, z/2, x2, b + x8, x8)]) assert len(c[0]) == 13 @@ -313,11 +265,11 @@ def test_issue_6169(): assert sub_post(sub_pre((-x - y)*z - x - y)) == -z*(x + y) - x - y -def test_cse_Indexed(): +def test_cse_Indexed(): # noqa len_y = 5 y = IndexedBase('y', shape=(len_y,)) x = IndexedBase('x', shape=(len_y,)) - Dy = IndexedBase('Dy', shape=(len_y-1,)) + Dy = IndexedBase('Dy', shape=(len_y-1,)) # noqa i = Idx('i', len_y-1) expr1 = (y[i+1]-y[i])/(x[i+1]-x[i]) @@ -326,10 +278,11 @@ def test_cse_Indexed(): assert len(replacements) > 0 -def test_Piecewise(): +def test_Piecewise(): # noqa f = Piecewise((-z + x*y, Eq(y, 0)), (-z - x*y, True)) ans = cse(f) - actual_ans = ([(x0, -z), (x1, x*y)], [Piecewise((x0+x1, Eq(y, 0)), (x0 - x1, True))]) + actual_ans = ([(x0, -z), (x1, x*y)], + [Piecewise((x0+x1, Eq(y, 0)), (x0 - x1, True))]) assert ans == actual_ans @@ -352,27 +305,27 @@ def test_name_conflict_cust_symbols(): def test_symbols_exhausted_error(): l = cos(x+y)+x+y+cos(w+y)+sin(w+y) sym = [x, y, z] - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ValueError): print(cse(l, symbols=sym)) def test_issue_7840(): # daveknippers' example - C393 = sympify( \ + C393 = sympify( # noqa 'Piecewise((C391 - 1.65, C390 < 0.5), (Piecewise((C391 - 1.65, \ C391 > 2.35), (C392, True)), True))' ) - C391 = sympify( \ + C391 = sympify( # noqa 'Piecewise((2.05*C390**(-1.03), C390 < 0.5), (2.5*C390**(-0.625), True))' ) - C393 = C393.subs('C391',C391) + C393 = C393.subs('C391',C391) # noqa # simple substitution sub = {} sub['C390'] = 0.703451854 sub['C392'] = 1.01417794 ss_answer = C393.subs(sub) # cse - substitutions,new_eqn = cse(C393) + substitutions, new_eqn = cse(C393) for pair in substitutions: sub[pair[0].name] = pair[1].subs(sub) cse_answer = new_eqn[0].subs(sub) -- GitLab