From 5cd4939dd228f99f38500b3351dcdad33fe903c5 Mon Sep 17 00:00:00 2001 From: Isuru Fernando <isuruf@gmail.com> Date: Fri, 6 Jan 2023 04:43:58 +0530 Subject: [PATCH] Optimize M2L for GPU (#138) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Optimize M2L for GPU * Move icoeff_tgt to top level iname * Fix substitution * use loopy branch * remove unused imports * go back to loopy main * Reduce diff * move all optimizations to m2l_translation * Remove extraneous FIXME Co-authored-by: Andreas Klöckner <inform@tiker.net> --- sumpy/e2e.py | 9 +++++---- sumpy/expansion/m2l.py | 25 +++++++++++++++++++++++-- test/test_fmm.py | 11 +++++++---- 3 files changed, 35 insertions(+), 10 deletions(-) diff --git a/sumpy/e2e.py b/sumpy/e2e.py index 9c96ac56..89df2789 100644 --- a/sumpy/e2e.py +++ b/sumpy/e2e.py @@ -458,7 +458,7 @@ class M2LUsingTranslationClassesDependentData(E2EFromCSR): end tgt_expansions[tgt_ibox - tgt_base_ibox, icoeff_tgt] = \ tgt_expansion[icoeff_tgt] \ - {dep=update_coeffs, dup=icoeff_tgt} + {dep=update_coeffs, dup=icoeff_tgt,id=write_e2e} end """], [ @@ -497,7 +497,8 @@ class M2LUsingTranslationClassesDependentData(E2EFromCSR): m2l_translation_classes_dependent_ndata), ncoeff_tgt=ncoeff_tgt, ncoeff_src=ncoeff_src), - lang_version=MOST_RECENT_LANGUAGE_VERSION + lang_version=MOST_RECENT_LANGUAGE_VERSION, + silenced_warnings="write_race(write_e2e*)", ) loopy_knl = lp.merge([translation_knl, loopy_knl]) @@ -514,8 +515,8 @@ class M2LUsingTranslationClassesDependentData(E2EFromCSR): def get_optimized_kernel(self, result_dtype): knl = self.get_kernel(result_dtype) - # FIXME - knl = lp.split_iname(knl, "itgt_box", 16, outer_tag="g.0") + knl = self.tgt_expansion.m2l_translation.optimize_loopy_kernel( + knl, self.tgt_expansion, self.src_expansion) return knl diff --git a/sumpy/expansion/m2l.py b/sumpy/expansion/m2l.py index 3d4b89e7..1845f1d3 100644 --- a/sumpy/expansion/m2l.py +++ b/sumpy/expansion/m2l.py @@ -246,8 +246,11 @@ class M2LTranslationBase(ABC): def update_persistent_hash(self, key_hash, key_builder): key_hash.update(type(self).__name__.encode("utf8")) -# }}} M2LTranslationBase + def optimize_loopy_kernel(self, knl, tgt_expansion, src_expansion): + return lp.tag_inames(knl, dict(itgt_box="g.0")) + +# }}} M2LTranslationBase # {{{ VolumeTaylorM2LTranslation @@ -740,8 +743,26 @@ class VolumeTaylorM2LWithFFT(VolumeTaylorM2LWithPreprocessedMultipoles): return super().postprocess_local_exprs(tgt_expansion, src_expansion, m2l_result, src_rscale, tgt_rscale, sac) -# }}} VolumeTaylorM2LWithFFT + def optimize_loopy_kernel(self, knl, tgt_expansion, src_expansion): + # Transform the kernel so that icoeff_tgt and its duplicates + # become the outermost iname + inames = knl.default_entrypoint.all_inames() + knl = lp.rename_inames(knl, + [iname for iname in inames if "icoeff_tgt" in iname], + "icoeff_tgt", existing_ok=True) + knl = lp.add_inames_to_insn(knl, "icoeff_tgt", None) + # unprivatize icoeff_tgt because it is the outermost iname + knl = lp.unprivatize_temporaries_with_inames(knl, + {"icoeff_tgt"}, {"tgt_expansion"}) + + knl = lp.split_iname(knl, "icoeff_tgt", 32, inner_iname="inner", + inner_tag="l.0") + knl = lp.tag_inames(knl, dict(itgt_box="g.0")) + return knl + + +# }}} VolumeTaylorM2LWithFFT # {{{ FourierBesselM2LTranslation diff --git a/test/test_fmm.py b/test/test_fmm.py index f447999c..d00b5b45 100644 --- a/test/test_fmm.py +++ b/test/test_fmm.py @@ -704,10 +704,13 @@ def test_sumpy_target_point_multiplier(actx_factory, deriv_axes, visualize=False # }}} -# You can test individual routines by typing -# $ python test_fmm.py 'test_sumpy_fmm(_acf, LaplaceKernel(2), -# VolumeTaylorLocalExpansion, VolumeTaylorMultipoleExpansion, False, False, -# visualize=True)' +""" +You can test individual routines by typing +$ python test/test_fmm.py 'test_sumpy_fmm(_acf, LaplaceKernel(2), + VolumeTaylorLocalExpansion, VolumeTaylorMultipoleExpansion, + order_varies_with_level=False, use_translation_classes=True, use_fft=True, + fft_backend="pyvkfft", visualize=True)' +""" if __name__ == "__main__": if len(sys.argv) > 1: -- GitLab