From 6a898631c6bab671cea00b0d550c2702f8c04a09 Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Sat, 20 Aug 2022 20:01:30 +0300
Subject: [PATCH] add array context skeleton

---
 requirements.txt             |   1 +
 setup.py                     |   1 +
 sumpy/array_context.py       |  77 +++++++++++++
 sumpy/expansion/multipole.py |   2 +-
 sumpy/qbx.py                 |   2 +-
 test/test_codegen.py         |  15 ++-
 test/test_cse.py             | 213 +++++++++++++++++++++++++----------
 7 files changed, 245 insertions(+), 66 deletions(-)
 create mode 100644 sumpy/array_context.py

diff --git a/requirements.txt b/requirements.txt
index a26c33e3..ff9ca871 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,4 +8,5 @@ git+https://github.com/inducer/islpy.git#egg=islpy
 git+https://github.com/inducer/pyopencl.git#egg=pyopencl
 git+https://github.com/inducer/boxtree.git#egg=boxtree
 git+https://github.com/inducer/loopy.git#egg=loopy
+git+https://github.com/inducer/arraycontext.git#egg=arraycontext
 git+https://github.com/inducer/pyfmmlib.git#egg=pyfmmlib
diff --git a/setup.py b/setup.py
index f9f0be9b..19ca094f 100644
--- a/setup.py
+++ b/setup.py
@@ -101,6 +101,7 @@ setup(
         "pytools>=2021.1.1",
         "loopy>=2021.1",
         "boxtree>=2018.1",
+        "arraycontext",
         "pytest>=2.3",
         "pyrsistent>=0.16.0",
         "dataclasses>=0.7;python_version<='3.6'",
diff --git a/sumpy/array_context.py b/sumpy/array_context.py
new file mode 100644
index 00000000..760f7d5d
--- /dev/null
+++ b/sumpy/array_context.py
@@ -0,0 +1,77 @@
+__copyright__ = "Copyright (C) 2022 Alexandru Fikl"
+
+__license__ = """
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+"""
+
+from boxtree.array_context import PyOpenCLArrayContext as PyOpenCLArrayContextBase
+from arraycontext.pytest import (
+        _PytestPyOpenCLArrayContextFactoryWithClass,
+        register_pytest_array_context_factory)
+
+__doc__ = """
+.. autoclass:: PyOpenCLArrayContext
+"""
+
+
+# {{{ PyOpenCLArrayContext
+
+class PyOpenCLArrayContext(PyOpenCLArrayContextBase):
+    def transform_loopy_program(self, t_unit):
+        default_ep = t_unit.default_entrypoint
+        options = default_ep.options
+
+        if not (options.return_dict and options.no_numpy):
+            raise ValueError("Loopy kernel passed to call_loopy must "
+                    "have return_dict and no_numpy options set. "
+                    "Did you use arraycontext.make_loopy_program "
+                    "to create this kernel?")
+
+        return super().transform_loopy_program(t_unit)
+
+# }}}
+
+
+# {{{ pytest
+
+def _acf():
+    import pyopencl as cl
+    ctx = cl.create_some_context()
+    queue = cl.CommandQueue(ctx)
+
+    return PyOpenCLArrayContext(queue, force_device_scalars=True)
+
+
+class PytestPyOpenCLArrayContextFactory(
+        _PytestPyOpenCLArrayContextFactoryWithClass):
+    actx_class = PyOpenCLArrayContext
+
+    def __call__(self):
+        # NOTE: prevent any cache explosions during testing!
+        from sympy.core.cache import clear_cache
+        clear_cache()
+
+        return super().__call__()
+
+
+register_pytest_array_context_factory(
+    "sumpy.pyopencl",
+    PytestPyOpenCLArrayContextFactory)
+
+# }}}
diff --git a/sumpy/expansion/multipole.py b/sumpy/expansion/multipole.py
index fd5a3b1d..1924e6c7 100644
--- a/sumpy/expansion/multipole.py
+++ b/sumpy/expansion/multipole.py
@@ -203,7 +203,7 @@ class VolumeTaylorMultipoleExpansionBase(MultipoleExpansionBase):
         #    └───⬏ ↑
         #    └─────┘
         #
-        # For the second hyperplane, data is propogated rightwards first
+        # For the second hyperplane, data is propagated rightwards first
         # and then upwards second which is opposite to that of the first
         # hyperplane.
         #
diff --git a/sumpy/qbx.py b/sumpy/qbx.py
index b479e1db..de259888 100644
--- a/sumpy/qbx.py
+++ b/sumpy/qbx.py
@@ -104,7 +104,7 @@ class LayerPotentialBase(KernelComputation, KernelCacheWrapper):
             # In LineTaylorLocalExpansion.evaluate, we can't run
             # postprocess_at_target because the coefficients are assigned
             # symbols and postprocess with a derivative will make them zero.
-            # Instead run postprocess here before the coeffients are assigned.
+            # Instead run postprocess here before the coefficients are assigned.
             coefficients = [tgt_knl.postprocess_at_target(coeff, bvec) for
                     coeff in coefficients]
 
diff --git a/test/test_codegen.py b/test/test_codegen.py
index 815b92dd..9c092e58 100644
--- a/test/test_codegen.py
+++ b/test/test_codegen.py
@@ -20,13 +20,15 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-
+import pytest
 import sys
 
 import logging
 logger = logging.getLogger(__name__)
 
 
+# {{{ test_symbolic_assignment_name_uniqueness
+
 def test_symbolic_assignment_name_uniqueness():
     # https://gitlab.tiker.net/inducer/sumpy/issues/13
     from sumpy.assignment_collection import SymbolicAssignmentCollection
@@ -43,6 +45,10 @@ def test_symbolic_assignment_name_uniqueness():
 
     assert len(sac.assignments) == 3
 
+# }}}
+
+
+# {{{ test_line_taylor_coeff_growth
 
 def test_line_taylor_coeff_growth():
     # Regression test for LineTaylorLocalExpansion.
@@ -70,15 +76,16 @@ def test_line_taylor_coeff_growth():
     max_order = 2
     assert np.polyfit(np.log(indices), np.log(counts), deg=1)[0] < max_order
 
+# }}}
+
 
 # You can test individual routines by typing
-# $ python test_fmm.py "test_sumpy_fmm(cl.create_some_context)"
+# $ python test_codegen.py 'test_line_taylor_coeff_growth()'
 
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])
     else:
-        from pytest import main
-        main([__file__])
+        pytest.main([__file__])
 
 # vim: fdm=marker
diff --git a/test/test_cse.py b/test/test_cse.py
index ba9986be..68352c90 100644
--- a/test/test_cse.py
+++ b/test/test_cse.py
@@ -67,26 +67,22 @@ DAMAGE.
 import pytest
 import sys
 
-from sumpy.symbolic import (
-    Add, Pow, exp, sqrt, symbols, sympify, cos, sin, Function, USE_SYMENGINE)
+import sumpy.symbolic as sym
+from sumpy.cse import cse, preprocess_for_cse, postprocess_for_cse
 
-if not USE_SYMENGINE:
+if not sym.USE_SYMENGINE:
     from sympy.simplify.cse_opts import sub_pre, sub_post
     from sympy.functions.special.hyper import meijerg
     from sympy.simplify import cse_opts
 
-from sumpy.cse import (
-    cse, preprocess_for_cse, postprocess_for_cse)
+import logging
+logger = logging.getLogger(__name__)
 
-
-w, x, y, z = symbols("w,x,y,z")
-x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12 = symbols("x:13")
+w, x, y, z = sym.symbols("w,x,y,z")
+x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12 = sym.symbols("x:13")
 
 sympyonly = (
-    pytest.mark.skipif(USE_SYMENGINE, reason="uses a sympy-only feature"))
-
-
-# Dummy "optimization" functions for testing.
+    pytest.mark.skipif(sym.USE_SYMENGINE, reason="uses a sympy-only feature"))
 
 
 def opt1(expr):
@@ -97,6 +93,8 @@ def opt2(expr):
     return expr*z
 
 
+# {{{ test_preprocess_for_cse
+
 def test_preprocess_for_cse():
     assert preprocess_for_cse(x, [(opt1, None)]) == x + y
     assert preprocess_for_cse(x, [(None, opt1)]) == x
@@ -105,6 +103,10 @@ def test_preprocess_for_cse():
     assert preprocess_for_cse(
         x, [(opt1, None), (opt2, None)]) == (x + y)*z
 
+# }}}
+
+
+# {{{ test_postprocess_for_cse
 
 def test_postprocess_for_cse():
     assert postprocess_for_cse(x, [(opt1, None)]) == x
@@ -115,19 +117,27 @@ def test_postprocess_for_cse():
     assert postprocess_for_cse(
         x, [(None, opt1), (None, opt2)]) == x*z + y
 
+# }}}
+
+
+# {{{ test_cse_single
 
 def test_cse_single():
     # Simple substitution.
-    e = Add(Pow(x + y, 2), sqrt(x + y))
+    e = sym.Add(sym.Pow(x + y, 2), sym.sqrt(x + y))
     substs, reduced = cse([e])
     assert substs == [(x0, x + y)]
-    assert reduced == [sqrt(x0) + x0**2]
+    assert reduced == [sym.sqrt(x0) + x0**2]
+
+# }}}
+
 
+# {{{
 
 @sympyonly
 def test_cse_not_possible():
     # No substitution possible.
-    e = Add(x, y)
+    e = sym.Add(x, y)
     substs, reduced = cse([e])
     assert substs == []
     assert reduced == [x + y]
@@ -136,34 +146,46 @@ def test_cse_not_possible():
           + meijerg((1, 3), (y, 4), (5,), [], x))
     assert cse(eq) == ([], [eq])
 
+# }}}
+
+
+# {{{ test_nested_substitution
 
 def test_nested_substitution():
     # Substitution within a substitution.
-    e = Add(Pow(w*x + y, 2), sqrt(w*x + y))
+    e = sym.Add(sym.Pow(w*x + y, 2), sym.sqrt(w*x + y))
     substs, reduced = cse([e])
     assert substs == [(x0, w*x + y)]
-    assert reduced == [sqrt(x0) + x0**2]
+    assert reduced == [sym.sqrt(x0) + x0**2]
 
+# }}}
+
+
+# {{{ test_subtraction_opt
 
 @sympyonly
 def test_subtraction_opt():
     # Make sure subtraction is optimized.
-    e = (x - y)*(z - y) + exp((x - y)*(z - y))
+    e = (x - y)*(z - y) + sym.exp((x - y)*(z - y))
     substs, reduced = cse(
         [e], optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)])
     assert substs == [(x0, (x - y)*(y - z))]
-    assert reduced == [-x0 + exp(-x0)]
-    e = -(x - y)*(z - y) + exp(-(x - y)*(z - y))
+    assert reduced == [-x0 + sym.exp(-x0)]
+    e = -(x - y)*(z - y) + sym.exp(-(x - y)*(z - y))
     substs, reduced = cse(
         [e], optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)])
     assert substs == [(x0, (x - y)*(y - z))]
-    assert reduced == [x0 + exp(x0)]
+    assert reduced == [x0 + sym.exp(x0)]
     # issue 4077
     n = -1 + 1/x
     e = n/x/(-n)**2 - 1/n/x
     assert cse(e, optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)]) == \
         ([], [0])
 
+# }}}
+
+
+# {{{ test_multiple_expressions
 
 def test_multiple_expressions():
     e1 = (x + y)*z
@@ -181,7 +203,7 @@ def test_multiple_expressions():
     rsubsts, _ = cse(reversed(l_))
     assert substs == rsubsts
     assert reduced == [x1, x1 + z, x0]
-    f = Function("f")
+    f = sym.Function("f")
     l_ = [f(x - z, y - z), x - z, y - z]
     substs, reduced = cse(l_)
     rsubsts, _ = cse(reversed(l_))
@@ -195,31 +217,46 @@ def test_multiple_expressions():
     assert cse([x*y, z + x*y, x*y*z + 3]) == \
         ([(x0, x*y)], [x0, z + x0, 3 + x0*z])
 
+# }}}
+
+
+# {{{ test_issue_4203
 
 def test_issue_4203():
-    assert cse(sin(x**x)/x**x) == ([(x0, x**x)], [sin(x0)/x0])
+    assert cse(sym.sin(x**x)/x**x) == ([(x0, x**x)], [sym.sin(x0)/x0])
+
+# }}}
 
 
+# {{{ test_dont_cse_subs
+
 def test_dont_cse_subs():
-    f = Function("f")
-    g = Function("g")
+    f = sym.Function("f")
+    g = sym.Function("g")
 
     name_val, (expr,) = cse(f(x+y).diff(x) + g(x+y).diff(x))
 
     assert name_val == []
 
+# }}}
+
+
+# {{{ test_dont_cse_derivative
 
 def test_dont_cse_derivative():
-    from sumpy.symbolic import Derivative
-    f = Function("f")
+    f = sym.Function("f")
 
-    deriv = Derivative(f(x+y), x)
+    deriv = sym.Derivative(f(x+y), x)
 
     name_val, (expr,) = cse(x + y + deriv)
 
     assert name_val == []
     assert expr == x + y + deriv
 
+# }}}
+
+
+# {{{ test_pow_invpow
 
 def test_pow_invpow():
     assert cse(1/x**2 + x**2) == \
@@ -228,39 +265,64 @@ def test_pow_invpow():
         ([(x0, x**2), (x1, 1/x0)], [x0 + x1*(x1 + 1)])
     assert cse(1/x**2 + (1 + 1/x**2)*x**2) == \
         ([(x0, x**2), (x1, 1/x0)], [x0*(x1 + 1) + x1])
-    assert cse(cos(1/x**2) + sin(1/x**2)) == \
-        ([(x0, x**(-2))], [sin(x0) + cos(x0)])
-    assert cse(cos(x**2) + sin(x**2)) == \
-        ([(x0, x**2)], [sin(x0) + cos(x0)])
+    assert cse(sym.cos(1/x**2) + sym.sin(1/x**2)) == \
+        ([(x0, x**(-2))], [sym.sin(x0) + sym.cos(x0)])
+    assert cse(sym.cos(x**2) + sym.sin(x**2)) == \
+        ([(x0, x**2)], [sym.sin(x0) + sym.cos(x0)])
     assert cse(y/(2 + x**2) + z/x**2/y) == \
         ([(x0, x**2)], [y/(x0 + 2) + z/(x0*y)])
-    assert cse(exp(x**2) + x**2*cos(1/x**2)) == \
-        ([(x0, x**2)], [x0*cos(1/x0) + exp(x0)])
+    assert cse(sym.exp(x**2) + x**2*sym.cos(1/x**2)) == \
+        ([(x0, x**2)], [x0*sym.cos(1/x0) + sym.exp(x0)])
     assert cse((1 + 1/x**2)/x**2) == \
         ([(x0, x**(-2))], [x0*(x0 + 1)])
     assert cse(x**(2*y) + x**(-2*y)) == \
         ([(x0, x**(2*y))], [x0 + 1/x0])
 
+# }}}
+
+
+# {{{ test_issue_4499
 
 @sympyonly
 def test_issue_4499():
     # previously, this gave 16 constants
     from sympy.abc import a, b
     from sympy import Tuple, S
-    B = Function("B")  # noqa
-    G = Function("G")  # noqa
-    t = Tuple(
-        *(a, a + S(1)/2, 2*a, b, 2*a - b + 1, (sqrt(z)/2)**(-2*a + 1)*B(2*a
-        - b, sqrt(z))*B(b - 1, sqrt(z))*G(b)*G(2*a - b + 1),  # noqa
-        sqrt(z)*(sqrt(z)/2)**(-2*a + 1)*B(b, sqrt(z))*B(2*a - b,  # noqa
-        sqrt(z))*G(b)*G(2*a - b + 1), sqrt(z)*(sqrt(z)/2)**(-2*a + 1)*B(b - 1,
-        sqrt(z))*B(2*a - b + 1, sqrt(z))*G(b)*G(2*a - b + 1),
-        (sqrt(z)/2)**(-2*a + 1)*B(b, sqrt(z))*B(2*a - b + 1,  # noqa
-        sqrt(z))*G(b)*G(2*a - b + 1), 1, 0, S(1)/2, z/2, -b + 1, -2*a + b, # noqa
-        -2*a))  # noqa
+
+    B = sym.Function("B")   # noqa: N806
+    G = sym.Function("G")   # noqa: N806
+    t = Tuple(*(
+        a,
+        a + S(1)/2,
+        2*a,
+        b,
+        2*a - b + 1,
+        (sym.sqrt(z)/2)**(-2*a + 1)
+        * B(2*a-b, sym.sqrt(z))
+        * B(b - 1, sym.sqrt(z))*G(b)*G(2*a - b + 1),
+        sym.sqrt(z)*(sym.sqrt(z)/2)**(-2*a + 1)
+        * B(b, sym.sqrt(z))
+        * B(2*a - b, sym.sqrt(z))*G(b)*G(2*a - b + 1),
+        sym.sqrt(z)*(sym.sqrt(z)/2)**(-2*a + 1)
+        * B(b - 1, sym.sqrt(z))
+        * B(2*a - b + 1, sym.sqrt(z))*G(b)*G(2*a - b + 1),
+        (sym.sqrt(z)/2)**(-2*a + 1)
+        * B(b, sym.sqrt(z))
+        * B(2*a - b + 1, sym.sqrt(z))*G(b)*G(2*a - b + 1),
+        1,
+        0,
+        S(1)/2,
+        z/2,
+        -b + 1,
+        -2*a + b,
+        -2*a))
     c = cse(t)
     assert len(c[0]) == 11
 
+# }}}
+
+
+# {{{ test_issue_6169
 
 @sympyonly
 def test_issue_6169():
@@ -271,14 +333,17 @@ def test_issue_6169():
     # mechanism
     assert sub_post(sub_pre((-x - y)*z - x - y)) == -z*(x + y) - x - y
 
+# }}}
+
+
+# {{{ test_cse_Indexed
 
 @sympyonly
-def test_cse_Indexed():  # noqa
+def test_cse_indexed():
     from sympy import IndexedBase, Idx
     len_y = 5
     y = IndexedBase("y", shape=(len_y,))
     x = IndexedBase("x", shape=(len_y,))
-    Dy = IndexedBase("Dy", shape=(len_y-1,))  # noqa
     i = Idx("i", len_y-1)
 
     expr1 = (y[i+1]-y[i])/(x[i+1]-x[i])
@@ -286,9 +351,13 @@ def test_cse_Indexed():  # noqa
     replacements, reduced_exprs = cse([expr1, expr2])
     assert len(replacements) > 0
 
+# }}}
+
+
+# {{{ test_Piecewise
 
 @sympyonly
-def test_Piecewise():  # noqa
+def test_piecewise():
     from sympy import Piecewise, Eq
     f = Piecewise((-z + x*y, Eq(y, 0)), (-z - x*y, True))
     ans = cse(f)
@@ -296,41 +365,57 @@ def test_Piecewise():  # noqa
                   [Piecewise((x0+x1, Eq(y, 0)), (x0 - x1, True))])
     assert ans == actual_ans
 
+# }}}
+
+
+# {{{ test_name_conflict
 
 def test_name_conflict():
     z1 = x0 + y
     z2 = x2 + x3
-    l_ = [cos(z1) + z1, cos(z2) + z2, x0 + x2]
+    l_ = [sym.cos(z1) + z1, sym.cos(z2) + z2, x0 + x2]
     substs, reduced = cse(l_)
     assert [e.subs(dict(substs)) for e in reduced] == l_
 
+# }}}
+
+
+# {{{ test_name_conflict_cust_symbols
 
 def test_name_conflict_cust_symbols():
     z1 = x0 + y
     z2 = x2 + x3
-    l_ = [cos(z1) + z1, cos(z2) + z2, x0 + x2]
-    substs, reduced = cse(l_, symbols("x:10"))
+    l_ = [sym.cos(z1) + z1, sym.cos(z2) + z2, x0 + x2]
+    substs, reduced = cse(l_, sym.symbols("x:10"))
     assert [e.subs(dict(substs)) for e in reduced] == l_
 
+# }}}
+
+
+# {{{ test_symbols_exhausted_error
 
 def test_symbols_exhausted_error():
-    l_ = cos(x+y)+x+y+cos(w+y)+sin(w+y)
-    sym = [x, y, z]
+    l_ = sym.cos(x+y)+x+y+sym.cos(w+y)+sym.sin(w+y)
+    s = [x, y, z]
     with pytest.raises(ValueError):
-        print(cse(l_, symbols=sym))
+        logger.info("cse:\n%s", cse(l_, symbols=s))
+
+# }}}
+
 
+# {{{ test_issue_7840
 
 @sympyonly
 def test_issue_7840():
     # daveknippers' example
-    C393 = sympify(  # noqa
+    C393 = sym.sympify(     # noqa: N806
         "Piecewise((C391 - 1.65, C390 < 0.5), (Piecewise((C391 - 1.65, \
         C391 > 2.35), (C392, True)), True))"
     )
-    C391 = sympify(  # noqa
+    C391 = sym.sympify(     # noqa: N806
         "Piecewise((2.05*C390**(-1.03), C390 < 0.5), (2.5*C390**(-0.625), True))"
     )
-    C393 = C393.subs("C391",C391)  # noqa
+    C393 = C393.subs("C391", C391)   # noqa: N806
     # simple substitution
     sub = {}
     sub["C390"] = 0.703451854
@@ -345,7 +430,7 @@ def test_issue_7840():
     assert ss_answer == cse_answer
 
     # GitRay's example
-    expr = sympify(
+    expr = sym.sympify(
         "Piecewise((Symbol('ON'), Equality(Symbol('mode'), Symbol('ON'))), \
         (Piecewise((Piecewise((Symbol('OFF'), StrictLessThan(Symbol('x'), \
         Symbol('threshold'))), (Symbol('ON'), S.true)), Equality(Symbol('mode'), \
@@ -357,6 +442,10 @@ def test_issue_7840():
     # there should not be any replacements
     assert len(substitutions) < 1
 
+# }}}
+
+
+# {{{ test_recursive_matching
 
 def test_recursive_matching():
     assert cse([x+y, 2+x+y, x+y+z, 3+x+y+z]) == \
@@ -372,12 +461,16 @@ def test_recursive_matching():
     assert cse([2*x*x, x*x*y, x*x*y*w, x*x*y*w*x0, x*x*y*w*x2]) == \
         ([(x1, x**2), (x3, x1*y), (x4, w*x3)], [2*x1, x3, x4, x0*x4, x2*x4])
 
+# }}}
+
+
+# You can test individual routines by typing
+# $ python test_cse.py 'test_recursive_matching()'
 
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])
     else:
-        from pytest import main
-        main([__file__])
+        pytest.main([__file__])
 
 # vim: fdm=marker
-- 
GitLab