Skip to content
Snippets Groups Projects
Commit dcb396c1 authored by Matt Wala's avatar Matt Wala
Browse files

Placate flake8.

parent fb32f55b
No related branches found
No related tags found
1 merge request!66Add support for unoptimized kernels
Pipeline #
...@@ -173,25 +173,25 @@ def test_multiple_expressions(): ...@@ -173,25 +173,25 @@ def test_multiple_expressions():
substs, reduced = cse([e1, e2]) substs, reduced = cse([e1, e2])
assert substs == [(x0, x + y)] assert substs == [(x0, x + y)]
assert reduced == [x0*z, x0*w] assert reduced == [x0*z, x0*w]
l = [w*x*y + z, w*y] l_ = [w*x*y + z, w*y]
substs, reduced = cse(l) substs, reduced = cse(l_)
rsubsts, _ = cse(reversed(l)) rsubsts, _ = cse(reversed(l_))
assert substs == rsubsts assert substs == rsubsts
assert reduced == [z + x*x0, x0] assert reduced == [z + x*x0, x0]
l = [w*x*y, w*x*y + z, w*y] l_ = [w*x*y, w*x*y + z, w*y]
substs, reduced = cse(l) substs, reduced = cse(l_)
rsubsts, _ = cse(reversed(l)) rsubsts, _ = cse(reversed(l_))
assert substs == rsubsts assert substs == rsubsts
assert reduced == [x1, x1 + z, x0] assert reduced == [x1, x1 + z, x0]
f = Function("f") f = Function("f")
l = [f(x - z, y - z), x - z, y - z] l_ = [f(x - z, y - z), x - z, y - z]
substs, reduced = cse(l) substs, reduced = cse(l_)
rsubsts, _ = cse(reversed(l)) rsubsts, _ = cse(reversed(l_))
assert substs == [(x0, -z), (x1, x + x0), (x2, x0 + y)] assert substs == [(x0, -z), (x1, x + x0), (x2, x0 + y)]
assert rsubsts == [(x0, -z), (x1, x0 + y), (x2, x + x0)] assert rsubsts == [(x0, -z), (x1, x0 + y), (x2, x + x0)]
assert reduced == [f(x1, x2), x1, x2] assert reduced == [f(x1, x2), x1, x2]
l = [w*y + w + x + y + z, w*x*y] l_ = [w*y + w + x + y + z, w*x*y]
assert cse(l) == ([(x0, w*y)], [w + x + x0 + y + z, x*x0]) assert cse(l_) == ([(x0, w*y)], [w + x + x0 + y + z, x*x0])
assert cse([x + y, x + y + z]) == ([(x0, x + y)], [x0, z + x0]) assert cse([x + y, x + y + z]) == ([(x0, x + y)], [x0, z + x0])
assert cse([x + y, x + z]) == ([], [x + y, x + z]) assert cse([x + y, x + z]) == ([], [x + y, x + z])
assert cse([x*y, z + x*y, x*y*z + 3]) == \ assert cse([x*y, z + x*y, x*y*z + 3]) == \
...@@ -302,24 +302,24 @@ def test_Piecewise(): # noqa ...@@ -302,24 +302,24 @@ def test_Piecewise(): # noqa
def test_name_conflict(): def test_name_conflict():
z1 = x0 + y z1 = x0 + y
z2 = x2 + x3 z2 = x2 + x3
l = [cos(z1) + z1, cos(z2) + z2, x0 + x2] l_ = [cos(z1) + z1, cos(z2) + z2, x0 + x2]
substs, reduced = cse(l) substs, reduced = cse(l_)
assert [e.subs(dict(substs)) for e in reduced] == l assert [e.subs(dict(substs)) for e in reduced] == l_
def test_name_conflict_cust_symbols(): def test_name_conflict_cust_symbols():
z1 = x0 + y z1 = x0 + y
z2 = x2 + x3 z2 = x2 + x3
l = [cos(z1) + z1, cos(z2) + z2, x0 + x2] l_ = [cos(z1) + z1, cos(z2) + z2, x0 + x2]
substs, reduced = cse(l, symbols("x:10")) substs, reduced = cse(l_, symbols("x:10"))
assert [e.subs(dict(substs)) for e in reduced] == l assert [e.subs(dict(substs)) for e in reduced] == l_
def test_symbols_exhausted_error(): def test_symbols_exhausted_error():
l = cos(x+y)+x+y+cos(w+y)+sin(w+y) l_ = cos(x+y)+x+y+cos(w+y)+sin(w+y)
sym = [x, y, z] sym = [x, y, z]
with pytest.raises(ValueError): with pytest.raises(ValueError):
print(cse(l, symbols=sym)) print(cse(l_, symbols=sym))
@sympyonly @sympyonly
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment