From 3d2a57aadc55c57c457fdb26ed1e28410ee2f54c Mon Sep 17 00:00:00 2001
From: Isuru Fernando <isuruf@gmail.com>
Date: Tue, 9 Aug 2022 18:51:16 -0500
Subject: [PATCH] Add a fallback for FFT when VkFFT is not found or broken
 (#130)

* use loopy fft

* fix inverse

* Implement broadcasting FFT

* use enum for fft backend

* Add gh-129 link

* unit test for loopy_fft

* Unit test for loopy_fft and fix warnings

* don't use vkfft only if x86 mac

* Add missing import

* Fix platform.machine()
---
 sumpy/fmm.py       |  13 ++-
 sumpy/tools.py     | 281 +++++++++++++++++++++++++++++++++++++++------
 test/test_fmm.py   |  25 +++-
 test/test_tools.py |  26 ++++-
 4 files changed, 301 insertions(+), 44 deletions(-)

diff --git a/sumpy/fmm.py b/sumpy/fmm.py
index 348c5542..30ca3838 100644
--- a/sumpy/fmm.py
+++ b/sumpy/fmm.py
@@ -42,7 +42,7 @@ from sumpy import (
         M2LGenerateTranslationClassesDependentData,
         M2LPreprocessMultipole, M2LPostprocessLocal)
 from sumpy.tools import (to_complex_dtype, AggregateProfilingEvent,
-        run_opencl_fft, get_opencl_fft_app)
+        run_opencl_fft, get_opencl_fft_app, get_native_event)
 
 from typing import TypeVar, List, Union
 
@@ -180,9 +180,9 @@ class SumpyTreeIndependentDataForWrangler(TreeIndependentDataForWrangler):
                           strength_usage=self.strength_usage)
 
     @memoize_method
-    def opencl_fft_app(self, shape, dtype):
+    def opencl_fft_app(self, shape, dtype, inverse):
         with cl.CommandQueue(self.cl_context) as queue:
-            return get_opencl_fft_app(queue, shape, dtype)
+            return get_opencl_fft_app(queue, shape, dtype, inverse)
 
 # }}}
 
@@ -554,7 +554,8 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface):
     # }}}
 
     def run_opencl_fft(self, queue, input_vec, inverse, wait_for):
-        app = self.tree_indep.opencl_fft_app(input_vec.shape, input_vec.dtype)
+        app = self.tree_indep.opencl_fft_app(input_vec.shape, input_vec.dtype,
+            inverse)
         return run_opencl_fft(app, queue, input_vec, inverse, wait_for)
 
     def form_multipoles(self,
@@ -815,7 +816,7 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface):
                         self.run_opencl_fft(queue,
                             preprocessed_mpole_exps[lev],
                             inverse=False, wait_for=wait_for)
-                    wait_for.append(evt_fft.native_event)
+                    wait_for.append(get_native_event(evt_fft))
                     evt = AggregateProfilingEvent([evt, evt_fft])
 
                 preprocess_evts.append(evt)
@@ -876,7 +877,7 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface):
                         self.run_opencl_fft(queue,
                             target_locals_before_postprocessing_view,
                             inverse=True, wait_for=wait_for)
-                    wait_for.append(evt_fft.native_event)
+                    wait_for.append(get_native_event(evt_fft))
 
                 evt, _ = postprocess_local_kernel(
                     queue,
diff --git a/sumpy/tools.py b/sumpy/tools.py
index 4c6c5891..d3b61b8d 100644
--- a/sumpy/tools.py
+++ b/sumpy/tools.py
@@ -39,8 +39,14 @@ __doc__ = """
 from pytools import memoize_method
 from pytools.tag import Tag, tag_dataclass
 import numbers
+import warnings
+import os
+import sys
+import enum
+import platform
 from collections import defaultdict, namedtuple
 from pymbolic.mapper import WalkMapper
+import pymbolic
 
 import numpy as np
 import sumpy.symbolic as sym
@@ -937,6 +943,10 @@ def to_complex_dtype(dtype):
 ProfileGetter = namedtuple("ProfileGetter", "start, end")
 
 
+def get_native_event(evt):
+    return evt if isinstance(evt, cl.Event) else evt.native_event
+
+
 class AggregateProfilingEvent:
     """An object to hold a list of events and provides compatibility
     with some of the functionality of :class:`pyopencl.Event`.
@@ -944,10 +954,7 @@ class AggregateProfilingEvent:
     """
     def __init__(self, events):
         self.events = events[:]
-        if isinstance(events[-1], cl.Event):
-            self.native_event = events[-1]
-        else:
-            self.native_event = events[-1].native_event
+        self.native_event = get_native_event(events[-1])
 
     @property
     def profile(self):
@@ -976,52 +983,262 @@ class MarkerBasedProfilingEvent:
         return self.native_event.wait()
 
 
-def get_opencl_fft_app(queue, shape, dtype):
-    """Setup an object for out-of-place FFT on with given shape and dtype
-    on given queue. Only supports in-order queues.
-    """
+def loopy_fft(shape, inverse, complex_dtype, index_dtype=None,
+        name=None):
+    from pymbolic.algorithm import find_factors
+    from math import pi
+
+    sign = 1 if not inverse else -1
+    n = shape[-1]
+
+    m = n
+    factors = []
+    while m != 1:
+        N1, m = find_factors(m)  # noqa: N806
+        factors.append(N1)
+
+    nfft = n
+
+    broadcast_dims = tuple(pymbolic.var(f"j{d}") for d in range(len(shape) - 1))
+
+    domains = [
+        "{[i]: 0<=i<n}",
+        "{[i2]: 0<=i2<n}",
+    ]
+    domains += [f"{{[j{d}]: 0<=j{d}<{shape[d]} }}" for d in range(len(shape) - 1)]
+
+    x = pymbolic.var("x")
+    y = pymbolic.var("y")
+    i = pymbolic.var("i")
+    i2 = pymbolic.var("i2")
+    i3 = pymbolic.var("i3")
+
+    fixed_parameters = {"const": complex_dtype(sign*(-2j)*pi/n), "n": n}
+
+    insns = [
+        "exp_table[i] = exp(const * i) {id=exp_table}",
+        lp.Assignment(
+            assignee=x[(*broadcast_dims, i2)],
+            expression=y[(*broadcast_dims, i2)],
+            id="copy",
+            depends_on=frozenset(["exp_table"]),
+        ),
+    ]
+
+    for ilev, N1 in enumerate(list(reversed(factors))):  # noqa: N806
+        nfft //= N1
+        N2 = n // (nfft * N1)  # noqa: N806
+        if ilev == 0:
+            init_depends_on = "copy"
+        else:
+            init_depends_on = f"update_{ilev-1}"
+
+        temp = pymbolic.var("temp")
+        exp_table = pymbolic.var("exp_table")
+        i = pymbolic.var(f"i_{ilev}")
+        i2 = pymbolic.var(f"i2_{ilev}")
+        ifft = pymbolic.var(f"ifft_{ilev}")
+        iN1 = pymbolic.var(f"iN1_{ilev}")           # noqa: N806
+        iN1_sum = pymbolic.var(f"iN1_sum_{ilev}")   # noqa: N806
+        iN2 = pymbolic.var(f"iN2_{ilev}")           # noqa: N806
+        table_idx = pymbolic.var(f"table_idx_{ilev}")
+        exp = pymbolic.var(f"exp_{ilev}")
+
+        insns += [
+            lp.Assignment(
+                assignee=temp[i],
+                expression=x[(*broadcast_dims, i)],
+                id=f"copy_{ilev}",
+                depends_on=frozenset([init_depends_on]),
+            ),
+            lp.Assignment(
+                assignee=x[(*broadcast_dims, i2)],
+                expression=0,
+                id=f"reset_{ilev}",
+                depends_on=frozenset([f"copy_{ilev}"])),
+            lp.Assignment(
+                assignee=table_idx,
+                expression=nfft*iN1_sum*(iN2 + N2*iN1),
+                id=f"idx_{ilev}",
+                depends_on=frozenset([f"reset_{ilev}"]),
+                temp_var_type=lp.Optional(np.uint32)),
+            lp.Assignment(
+                assignee=exp,
+                expression=exp_table[table_idx % n],
+                id=f"exp_{ilev}",
+                depends_on=frozenset([f"idx_{ilev}"]),
+                within_inames=frozenset(map(lambda x: x.name,
+                    [*broadcast_dims, iN1_sum, iN1, iN2])),
+                temp_var_type=lp.Optional(complex_dtype)),
+            lp.Assignment(
+                assignee=x[(*broadcast_dims, ifft + nfft * (iN1*N2 + iN2))],
+                expression=(x[(*broadcast_dims, ifft + nfft*(iN1*N2 + iN2))]
+                    + exp * temp[ifft + nfft * (iN2*N1 + iN1_sum)]),
+                id=f"update_{ilev}",
+                depends_on=frozenset([f"exp_{ilev}"])),
+        ]
+
+        domains += [
+            f"[ifft_{ilev}]: 0<=ifft_{ilev}<{nfft}",
+            f"[iN1_{ilev}]: 0<=iN1_{ilev}<{N1}",
+            f"[iN1_sum_{ilev}]: 0<=iN1_sum_{ilev}<{N1}",
+            f"[iN2_{ilev}]: 0<=iN2_{ilev}<{N2}",
+            f"[i_{ilev}]: 0<=i_{ilev}<{n}",
+            f"[i2_{ilev}]: 0<=i2_{ilev}<{n}",
+        ]
+
+    for idom, dom in enumerate(domains):
+        if not dom.startswith("{"):
+            domains[idom] = "{" + dom + "}"
+
+    kernel_data = [
+        lp.GlobalArg("x", shape=shape, is_input=False, is_output=True,
+            dtype=complex_dtype),
+        lp.GlobalArg("y", shape=shape, is_input=True, is_output=False,
+            dtype=complex_dtype),
+        lp.TemporaryVariable("exp_table", shape=(n,),
+            dtype=complex_dtype),
+        lp.TemporaryVariable("temp", shape=(n,),
+            dtype=complex_dtype),
+        ...
+    ]
+
+    if n == 1:
+        domains = domains[2:]
+        insns = [
+            lp.Assignment(
+                assignee=x[(*broadcast_dims, 0)],
+                expression=y[(*broadcast_dims, 0)],
+            ),
+        ]
+        kernel_data = kernel_data[:2]
+    elif inverse:
+        domains += ["{[i3]: 0<=i3<n}"]
+        insns += [
+            lp.Assignment(
+                assignee=x[(*broadcast_dims, i3)],
+                expression=x[(*broadcast_dims, i3)] / n,
+                depends_on=frozenset([f"update_{len(factors) - 1}"]),
+            ),
+        ]
+
+    if name is None:
+        if inverse:
+            name = f"ifft_{n}"
+        else:
+            name = f"fft_{n}"
+
+    knl = lp.make_kernel(
+        domains, insns,
+        kernel_data=kernel_data,
+        name=name,
+        fixed_parameters=fixed_parameters,
+        lang_version=lp.MOST_RECENT_LANGUAGE_VERSION,
+        index_dtype=index_dtype,
+    )
+
+    if broadcast_dims:
+        knl = lp.split_iname(knl, "j0", 32, inner_tag="l.0", outer_tag="g.0")
+        knl = lp.add_inames_for_unused_hw_axes(knl)
+
+    return knl
+
+
+class FFTBackend(enum.Enum):
+    pyvkfft = 1
+    loopy = 2
+
+
+def _get_fft_backend(queue) -> FFTBackend:
+    env_val = os.environ.get("SUMPY_FFT_BACKEND", None)
+    if env_val:
+        if env_val not in ["loopy", "pyvkfft"]:
+            raise ValueError("Expected 'loopy' or 'pyvkfft' for SUMPY_FFT_BACKEND. "
+                   f"Found {env_val}.")
+        return FFTBackend[env_val]
+
+    try:
+        import pyvkfft.opencl  # noqa: F401
+    except ImportError:
+        warnings.warn("VkFFT not found. FFT runs will be slower.")
+        return FFTBackend.loopy
+
     if queue.properties & cl.command_queue_properties.OUT_OF_ORDER_EXEC_MODE_ENABLE:
-        raise RuntimeError("VkFFT does not support out of order queues yet.")
+        warnings.warn("VkFFT does not support out of order queues yet. "
+            "Falling back to slower implementation.")
+        return FFTBackend.loopy
+
+    if (sys.platform == "darwin"
+            and platform.machine() == "x86_64"
+            and queue.context.devices[0].platform.name
+            == "Portable Computing Language"):
+        warnings.warn("Pocl miscompiles some VkFFT kernels. "
+            "See https://github.com/inducer/sumpy/issues/129. "
+            "Falling back to slower implementation.")
+        return FFTBackend.loopy
+
+    return FFTBackend.pyvkfft
+
 
+def get_opencl_fft_app(queue, shape, dtype, inverse):
+    """Setup an object for out-of-place FFT on with given shape and dtype
+    on given queue.
+    """
     assert dtype.type in (np.float32, np.float64, np.complex64,
                            np.complex128)
 
-    from pyvkfft.opencl import VkFFTApp
-    app = VkFFTApp(shape=shape, dtype=dtype, queue=queue, ndim=1, inplace=False)
-    return app
+    backend = _get_fft_backend(queue)
+
+    if backend == FFTBackend.loopy:
+        return loopy_fft(shape, inverse=inverse, complex_dtype=dtype.type), backend
+    elif backend == FFTBackend.pyvkfft:
+        from pyvkfft.opencl import VkFFTApp
+        app = VkFFTApp(shape=shape, dtype=dtype, queue=queue, ndim=1, inplace=False)
+        return app, backend
+    else:
+        raise RuntimeError(f"Unsupported FFT backend {backend}")
 
 
-def run_opencl_fft(vkfft_app, queue, input_vec, inverse=False, wait_for=None):
+def run_opencl_fft(fft_app, queue, input_vec, inverse=False, wait_for=None):
     """Runs an FFT on input_vec and returns a :class:`MarkerBasedProfilingEvent`
     that indicate the end and start of the operations carried out and the output
     vector.
     Only supports in-order queues.
     """
-    if wait_for is None:
-        wait_for = []
+    app, backend = fft_app
 
-    start_evt = cl.enqueue_marker(queue, wait_for=wait_for[:])
+    if backend == FFTBackend.loopy:
+        evt, (output_vec,) = app(queue, y=input_vec, wait_for=wait_for)
+        return (evt, output_vec)
+    elif backend == FFTBackend.pyvkfft:
+        if wait_for is None:
+            wait_for = []
 
-    if vkfft_app.inplace:
-        raise RuntimeError("inplace fft is not supported")
-    else:
-        output_vec = cla.empty_like(input_vec, queue)
+        start_evt = cl.enqueue_marker(queue, wait_for=wait_for[:])
 
-    # FIXME: use the public API once https://github.com/vincefn/pyvkfft/pull/17 is in
-    from pyvkfft.opencl import _vkfft_opencl
-    if inverse:
-        meth = _vkfft_opencl.ifft
-    else:
-        meth = _vkfft_opencl.fft
+        if app.inplace:
+            raise RuntimeError("inplace fft is not supported")
+        else:
+            output_vec = cla.empty_like(input_vec, queue)
 
-    meth(vkfft_app.app, int(input_vec.data.int_ptr), int(output_vec.data.int_ptr),
-        int(queue.int_ptr))
+        # FIXME: use the public API once
+        # https://github.com/vincefn/pyvkfft/pull/17 is in
+        from pyvkfft.opencl import _vkfft_opencl
+        if inverse:
+            meth = _vkfft_opencl.ifft
+        else:
+            meth = _vkfft_opencl.fft
+
+        meth(app.app, int(input_vec.data.int_ptr),
+            int(output_vec.data.int_ptr), int(queue.int_ptr))
 
-    end_evt = cl.enqueue_marker(queue, wait_for=[start_evt])
-    output_vec.add_event(end_evt)
+        end_evt = cl.enqueue_marker(queue, wait_for=[start_evt])
+        output_vec.add_event(end_evt)
 
-    return (MarkerBasedProfilingEvent(end_event=end_evt, start_event=start_evt),
-        output_vec)
+        return (MarkerBasedProfilingEvent(end_event=end_evt, start_event=start_evt),
+            output_vec)
+    else:
+        raise RuntimeError(f"Unsupported FFT backend {backend}")
 
 # }}}
 
diff --git a/test/test_fmm.py b/test/test_fmm.py
index a63b64fd..58ec8b3a 100644
--- a/test/test_fmm.py
+++ b/test/test_fmm.py
@@ -22,6 +22,8 @@ THE SOFTWARE.
 
 
 import sys
+import os
+from unittest.mock import patch
 import numpy as np
 import numpy.linalg as la
 import pyopencl as cl
@@ -55,8 +57,9 @@ else:
     faulthandler.enable()
 
 
-@pytest.mark.parametrize("use_translation_classes, use_fft",
-    [(False, False), (True, False), (True, True)])
+@pytest.mark.parametrize("use_translation_classes, use_fft, fft_backend",
+    [(False, False, None), (True, False, None), (True, True, "loopy"),
+     (True, True, "pyvkfft")])
 @pytest.mark.parametrize(
         ("knl", "local_expn_class", "mpole_expn_class",
         "order_varies_with_level"), [
@@ -82,7 +85,8 @@ else:
                 False),
             ])
 def test_sumpy_fmm(ctx_factory, knl, local_expn_class, mpole_expn_class,
-        order_varies_with_level, use_translation_classes, use_fft):
+        order_varies_with_level, use_translation_classes, use_fft,
+        fft_backend):
     logging.basicConfig(level=logging.INFO)
 
     if local_expn_class == VolumeTaylorLocalExpansion and use_fft:
@@ -91,6 +95,21 @@ def test_sumpy_fmm(ctx_factory, knl, local_expn_class, mpole_expn_class,
     if local_expn_class in [H2DLocalExpansion, Y2DLocalExpansion] and use_fft:
         pytest.skip("Fourier/Bessel based expansions with FFT is not supported yet.")
 
+    if use_fft:
+        with patch.dict(os.environ, {"SUMPY_FFT_BACKEND": fft_backend}):
+            _test_sumpy_fmm(ctx_factory, knl, local_expn_class, mpole_expn_class,
+                order_varies_with_level, use_translation_classes, use_fft,
+                fft_backend)
+    else:
+        _test_sumpy_fmm(ctx_factory, knl, local_expn_class, mpole_expn_class,
+            order_varies_with_level, use_translation_classes, use_fft,
+            fft_backend)
+
+
+def _test_sumpy_fmm(ctx_factory, knl, local_expn_class, mpole_expn_class,
+        order_varies_with_level, use_translation_classes, use_fft,
+        fft_backend):
+
     ctx = ctx_factory()
     queue = cl.CommandQueue(ctx)
 
diff --git a/test/test_tools.py b/test/test_tools.py
index a3d2a9f1..08a90689 100644
--- a/test/test_tools.py
+++ b/test/test_tools.py
@@ -25,11 +25,18 @@ logger = logging.getLogger(__name__)
 
 import sumpy.symbolic as sym
 from sumpy.tools import (fft_toeplitz_upper_triangular,
-    matvec_toeplitz_upper_triangular)
+    matvec_toeplitz_upper_triangular, loopy_fft, fft)
 import numpy as np
 
+import pyopencl as cl
+import pyopencl.array as cla
+from pyopencl.tools import (  # noqa
+        pytest_generate_tests_for_pyopencl as pytest_generate_tests)
 
-def test_fft():
+import pytest
+
+
+def test_matvec_fft():
     k = 5
     v = np.random.rand(k)
     x = np.random.rand(k)
@@ -41,7 +48,7 @@ def test_fft():
         assert abs(fft[i] - matvec[i]) < 1e-14
 
 
-def test_fft_small_floats():
+def test_matvec_fft_small_floats():
     k = 5
     v = sym.make_sym_vector("v", k)
     x = sym.make_sym_vector("x", k)
@@ -52,3 +59,16 @@ def test_fft_small_floats():
             if f == 0:
                 continue
             assert abs(f) > 1e-10
+
+
+@pytest.mark.parametrize("size", [1, 2, 7, 10, 30, 210])
+def test_fft(ctx_factory, size):
+    ctx = ctx_factory()
+    queue = cl.CommandQueue(ctx)
+    inp = np.arange(size, dtype=np.complex64)
+    inp_dev = cla.to_device(queue, inp)
+    out = fft(inp)
+
+    fft_func = loopy_fft(inp.shape, inverse=False, complex_dtype=inp.dtype.type)
+    evt, (out_dev,) = fft_func(queue, y=inp_dev)
+    assert np.allclose(out_dev.get(), out)
-- 
GitLab