From 5cd4939dd228f99f38500b3351dcdad33fe903c5 Mon Sep 17 00:00:00 2001
From: Isuru Fernando <isuruf@gmail.com>
Date: Fri, 6 Jan 2023 04:43:58 +0530
Subject: [PATCH] Optimize M2L for GPU (#138)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Optimize M2L for GPU

* Move icoeff_tgt to top level iname

* Fix substitution

* use loopy branch

* remove unused imports

* go back to loopy main

* Reduce diff

* move all optimizations to m2l_translation

* Remove extraneous FIXME

Co-authored-by: Andreas Klöckner <inform@tiker.net>
---
 sumpy/e2e.py           |  9 +++++----
 sumpy/expansion/m2l.py | 25 +++++++++++++++++++++++--
 test/test_fmm.py       | 11 +++++++----
 3 files changed, 35 insertions(+), 10 deletions(-)

diff --git a/sumpy/e2e.py b/sumpy/e2e.py
index 9c96ac56..89df2789 100644
--- a/sumpy/e2e.py
+++ b/sumpy/e2e.py
@@ -458,7 +458,7 @@ class M2LUsingTranslationClassesDependentData(E2EFromCSR):
                     end
                     tgt_expansions[tgt_ibox - tgt_base_ibox, icoeff_tgt] = \
                             tgt_expansion[icoeff_tgt] \
-                            {dep=update_coeffs, dup=icoeff_tgt}
+                            {dep=update_coeffs, dup=icoeff_tgt,id=write_e2e}
                 end
                 """],
                 [
@@ -497,7 +497,8 @@ class M2LUsingTranslationClassesDependentData(E2EFromCSR):
                             m2l_translation_classes_dependent_ndata),
                         ncoeff_tgt=ncoeff_tgt,
                         ncoeff_src=ncoeff_src),
-                lang_version=MOST_RECENT_LANGUAGE_VERSION
+                lang_version=MOST_RECENT_LANGUAGE_VERSION,
+                silenced_warnings="write_race(write_e2e*)",
                 )
 
         loopy_knl = lp.merge([translation_knl, loopy_knl])
@@ -514,8 +515,8 @@ class M2LUsingTranslationClassesDependentData(E2EFromCSR):
 
     def get_optimized_kernel(self, result_dtype):
         knl = self.get_kernel(result_dtype)
-        # FIXME
-        knl = lp.split_iname(knl, "itgt_box", 16, outer_tag="g.0")
+        knl = self.tgt_expansion.m2l_translation.optimize_loopy_kernel(
+                knl, self.tgt_expansion, self.src_expansion)
 
         return knl
 
diff --git a/sumpy/expansion/m2l.py b/sumpy/expansion/m2l.py
index 3d4b89e7..1845f1d3 100644
--- a/sumpy/expansion/m2l.py
+++ b/sumpy/expansion/m2l.py
@@ -246,8 +246,11 @@ class M2LTranslationBase(ABC):
     def update_persistent_hash(self, key_hash, key_builder):
         key_hash.update(type(self).__name__.encode("utf8"))
 
-# }}} M2LTranslationBase
+    def optimize_loopy_kernel(self, knl, tgt_expansion, src_expansion):
+        return lp.tag_inames(knl, {"itgt_box": "g.0"})
+
 
+# }}} M2LTranslationBase
 
 # {{{ VolumeTaylorM2LTranslation
 
@@ -740,8 +743,26 @@ class VolumeTaylorM2LWithFFT(VolumeTaylorM2LWithPreprocessedMultipoles):
         return super().postprocess_local_exprs(tgt_expansion,
             src_expansion, m2l_result, src_rscale, tgt_rscale, sac)
 
-# }}} VolumeTaylorM2LWithFFT
+    def optimize_loopy_kernel(self, knl, tgt_expansion, src_expansion):
+        # Merge icoeff_tgt and all of its duplicates back into a single
+        # iname and add it to the instructions so that it ends up outermost
+        coeff_inames = [
+            name for name in knl.default_entrypoint.all_inames()
+            if "icoeff_tgt" in name]
+        knl = lp.rename_inames(knl, coeff_inames, "icoeff_tgt", existing_ok=True)
+        knl = lp.add_inames_to_insn(knl, "icoeff_tgt", None)

+        # unprivatize tgt_expansion w.r.t. icoeff_tgt, now the outermost iname
+        knl = lp.unprivatize_temporaries_with_inames(
+                knl, {"icoeff_tgt"}, {"tgt_expansion"})
+
+        knl = lp.split_iname(
+                knl, "icoeff_tgt", 32, inner_iname="inner", inner_tag="l.0")
+        knl = lp.tag_inames(knl, {"itgt_box": "g.0"})
+        return knl
+
+
+# }}} VolumeTaylorM2LWithFFT
 
 # {{{ FourierBesselM2LTranslation
 
diff --git a/test/test_fmm.py b/test/test_fmm.py
index f447999c..d00b5b45 100644
--- a/test/test_fmm.py
+++ b/test/test_fmm.py
@@ -704,10 +704,13 @@ def test_sumpy_target_point_multiplier(actx_factory, deriv_axes, visualize=False
 # }}}
 
 
-# You can test individual routines by typing
-# $ python test_fmm.py 'test_sumpy_fmm(_acf, LaplaceKernel(2),
-#       VolumeTaylorLocalExpansion, VolumeTaylorMultipoleExpansion, False, False,
-#       visualize=True)'
+"""
+You can test individual routines by typing
+$ python test/test_fmm.py 'test_sumpy_fmm(_acf, LaplaceKernel(2),
+      VolumeTaylorLocalExpansion, VolumeTaylorMultipoleExpansion,
+      order_varies_with_level=False, use_translation_classes=True, use_fft=True,
+      fft_backend="pyvkfft", visualize=True)'
+"""
 
 if __name__ == "__main__":
     if len(sys.argv) > 1:
-- 
GitLab