From a61c6c916e0faaa11c9a91e28cdcf976333456e1 Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Mon, 29 Mar 2021 16:43:45 -0500 Subject: [PATCH] support use_fft in fmm and add tests --- sumpy/expansion/local.py | 41 +++++++++++++++++++++++++++------------- sumpy/fmm.py | 30 ++++++++++++++++++----------- sumpy/tools.py | 2 +- test/test_fmm.py | 39 ++++++++++++++++++++------------------ 4 files changed, 69 insertions(+), 43 deletions(-) diff --git a/sumpy/expansion/local.py b/sumpy/expansion/local.py index a8dabbfc..ff970124 100644 --- a/sumpy/expansion/local.py +++ b/sumpy/expansion/local.py @@ -28,12 +28,19 @@ from sumpy.expansion import ( HelmholtzConformingVolumeTaylorExpansion, BiharmonicConformingVolumeTaylorExpansion) -from sumpy.tools import mi_increment_axis, matvec_toeplitz_upper_triangular +from sumpy.tools import (mi_increment_axis, matvec_toeplitz_upper_triangular, + fft_toeplitz_upper_triangular) from pytools import single_valued class LocalExpansionBase(ExpansionBase): - pass + def __init__(self, kernel, order, use_rscale=None, use_fft=False): + super().__init__(kernel, order, use_rscale) + self.use_fft = use_fft + + def with_kernel(self, kernel): + return type(self)(kernel, self.order, self.use_rscale, + use_fft=self.use_fft) import logging @@ -304,7 +311,11 @@ class VolumeTaylorLocalExpansionBase(LocalExpansionBase): add_to_sac(sac, coeff) # Do the matvec - output = matvec_toeplitz_upper_triangular(toeplitz_first_row, + if self.use_fft: + output = fft_toeplitz_upper_triangular(toeplitz_first_row, + derivatives_full) + else: + output = matvec_toeplitz_upper_triangular(toeplitz_first_row, derivatives_full) # Filter out the dummy rows and scale them for target @@ -451,8 +462,9 @@ class VolumeTaylorLocalExpansion( VolumeTaylorExpansion, VolumeTaylorLocalExpansionBase): - def __init__(self, kernel, order, use_rscale=None): - VolumeTaylorLocalExpansionBase.__init__(self, kernel, order, use_rscale) + def __init__(self, kernel, order, use_rscale=None, use_fft=False): + VolumeTaylorLocalExpansionBase.__init__(self, kernel, order, use_rscale, + use_fft) VolumeTaylorExpansion.__init__(self, kernel, order, use_rscale) @@ -460,8 +472,9 @@ class LaplaceConformingVolumeTaylorLocalExpansion( LaplaceConformingVolumeTaylorExpansion, VolumeTaylorLocalExpansionBase): - def __init__(self, kernel, order, use_rscale=None): - VolumeTaylorLocalExpansionBase.__init__(self, kernel, order, use_rscale) + def __init__(self, kernel, order, use_rscale=None, use_fft=False): + VolumeTaylorLocalExpansionBase.__init__(self, kernel, order, use_rscale, + use_fft) LaplaceConformingVolumeTaylorExpansion.__init__( self, kernel, order, use_rscale) @@ -470,8 +483,9 @@ class HelmholtzConformingVolumeTaylorLocalExpansion( HelmholtzConformingVolumeTaylorExpansion, VolumeTaylorLocalExpansionBase): - def __init__(self, kernel, order, use_rscale=None): - VolumeTaylorLocalExpansionBase.__init__(self, kernel, order, use_rscale) + def __init__(self, kernel, order, use_rscale=None, use_fft=False): + VolumeTaylorLocalExpansionBase.__init__(self, kernel, order, use_rscale, + use_fft) HelmholtzConformingVolumeTaylorExpansion.__init__( self, kernel, order, use_rscale) @@ -480,8 +494,9 @@ class BiharmonicConformingVolumeTaylorLocalExpansion( BiharmonicConformingVolumeTaylorExpansion, VolumeTaylorLocalExpansionBase): - def __init__(self, kernel, order, use_rscale=None): - VolumeTaylorLocalExpansionBase.__init__(self, kernel, order, use_rscale) + def __init__(self, kernel, order, use_rscale=None, use_fft=False): + VolumeTaylorLocalExpansionBase.__init__(self, kernel, order, use_rscale, + use_fft) BiharmonicConformingVolumeTaylorExpansion.__init__( self, kernel, order, use_rscale) @@ -581,7 +596,7 @@ class _FourierBesselLocalExpansion(LocalExpansionBase): class H2DLocalExpansion(_FourierBesselLocalExpansion): - def __init__(self, kernel, order, use_rscale=None): + def __init__(self, kernel, order, use_rscale=None, use_fft=False): from sumpy.kernel import HelmholtzKernel assert (isinstance(kernel.get_base_kernel(), HelmholtzKernel) and kernel.dim == 2) @@ -596,7 +611,7 @@ class H2DLocalExpansion(_FourierBesselLocalExpansion): class Y2DLocalExpansion(_FourierBesselLocalExpansion): - def __init__(self, kernel, order, use_rscale=None): + def __init__(self, kernel, order, use_rscale=None, use_fft=False): from sumpy.kernel import YukawaKernel assert (isinstance(kernel.get_base_kernel(), YukawaKernel) and kernel.dim == 2) diff --git a/sumpy/fmm.py b/sumpy/fmm.py index 71ea9b75..900fdad2 100644 --- a/sumpy/fmm.py +++ b/sumpy/fmm.py @@ -62,7 +62,7 @@ class SumpyExpansionWranglerCodeContainer: multipole_expansion_factory, local_expansion_factory, target_kernels, exclude_self=False, use_rscale=None, - strength_usage=None, source_kernels=None): + strength_usage=None, source_kernels=None, use_fft=False): """ :arg multipole_expansion_factory: a callable of a single argument (order) that returns a multipole expansion. @@ -80,6 +80,7 @@ class SumpyExpansionWranglerCodeContainer: self.exclude_self = exclude_self self.use_rscale = use_rscale self.strength_usage = strength_usage + self.use_fft = use_fft self.cl_context = cl_context @@ -94,7 +95,8 @@ class SumpyExpansionWranglerCodeContainer: @memoize_method def local_expansion(self, order): - return self.local_expansion_factory(order, self.use_rscale) + return self.local_expansion_factory(order, self.use_rscale, + use_fft=self.use_fft) @memoize_method def p2m(self, tgt_order): @@ -329,12 +331,18 @@ class SumpyExpansionWrangler: if base_kernel.is_translation_invariant: if translation_classes_data is None: from warnings import warn - warn( - "List 2 (multipole-to-local) translations will be " - "unoptimized. Supply a translation_classes_data argument to " - "the wrangler for optimized List 2.", - SumpyTranslationClassesDataNotSuppliedWarning, - + if self.code.use_fft: + raise NotImplementedError( + "FFT based List 2 (multipole-to-local) translations " + "without translation_classes_data argument is not " + "implemented. Supply a translation_classes_data argument " + "to the wrangler for optimized List 2.") + else: + warn( + "List 2 (multipole-to-local) translations will be " + "unoptimized. Supply a translation_classes_data argument " + "to the wrangler for optimized List 2.", + SumpyTranslationClassesDataNotSuppliedWarning, stacklevel=2) self.supports_optimized_m2l = False else: @@ -367,7 +375,7 @@ class SumpyExpansionWrangler: @memoize_method def local_expansions_level_starts(self): return self._expansions_level_starts( - lambda order: len(self.code.local_expansion_factory(order)), + lambda order: len(self.code.local_expansion(order)), level_starts=self.tree.level_start_box_nrs) @memoize_method @@ -378,8 +386,8 @@ class SumpyExpansionWrangler: @memoize_method def m2l_precomputed_exprs_level_starts(self): def order_to_size(order): - mpole_expn = self.code.multipole_expansion_factory(order) - local_expn = self.code.local_expansion_factory(order) + mpole_expn = self.code.multipole_expansion(order) + local_expn = self.code.local_expansion(order) return local_expn.m2l_global_precompute_nexpr(mpole_expn) return self._expansions_level_starts(order_to_size, diff --git a/sumpy/tools.py b/sumpy/tools.py index e2c72963..1959ee69 100644 --- a/sumpy/tools.py +++ b/sumpy/tools.py @@ -73,6 +73,7 @@ __doc__ = """ from pytools import memoize_method, memoize_in from pytools.tag import Tag, tag_dataclass import math +import numbers import numpy as np import sumpy.symbolic as sym @@ -996,7 +997,6 @@ class KernelCacheWrapper: @staticmethod def _allow_redundant_execution_of_knl_scaling(knl): from loopy.match import ObjTagged - from sumpy.tools import ScalingAssignmentTag return lp.add_inames_for_unused_hw_axes( knl, within=ObjTagged(ScalingAssignmentTag())) diff --git a/test/test_fmm.py b/test/test_fmm.py index fe159b48..5a081cc5 100644 --- a/test/test_fmm.py +++ b/test/test_fmm.py @@ -53,36 +53,39 @@ else: faulthandler.enable() -@pytest.mark.parametrize("knl, local_expn_class, mpole_expn_class, optimized_m2l", [ +@pytest.mark.parametrize( + "knl, local_expn_class, mpole_expn_class, optimized_m2l, use_fft", [ (LaplaceKernel(2), VolumeTaylorLocalExpansion, VolumeTaylorMultipoleExpansion, - False), + False, False), (LaplaceKernel(2), VolumeTaylorLocalExpansion, VolumeTaylorMultipoleExpansion, - True), + True, False), (LaplaceKernel(2), LaplaceConformingVolumeTaylorLocalExpansion, - LaplaceConformingVolumeTaylorMultipoleExpansion, True), + LaplaceConformingVolumeTaylorMultipoleExpansion, True, False), (LaplaceKernel(2), LaplaceConformingVolumeTaylorLocalExpansion, - LaplaceConformingVolumeTaylorMultipoleExpansion, False), + LaplaceConformingVolumeTaylorMultipoleExpansion, True, True), + (LaplaceKernel(2), LaplaceConformingVolumeTaylorLocalExpansion, + LaplaceConformingVolumeTaylorMultipoleExpansion, False, False), (LaplaceKernel(3), VolumeTaylorLocalExpansion, VolumeTaylorMultipoleExpansion, - False), + False, False), (LaplaceKernel(3), VolumeTaylorLocalExpansion, VolumeTaylorMultipoleExpansion, - True), + True, False), (LaplaceKernel(3), LaplaceConformingVolumeTaylorLocalExpansion, - LaplaceConformingVolumeTaylorMultipoleExpansion, True), + LaplaceConformingVolumeTaylorMultipoleExpansion, True, False), (LaplaceKernel(3), LaplaceConformingVolumeTaylorLocalExpansion, - LaplaceConformingVolumeTaylorMultipoleExpansion, False), + LaplaceConformingVolumeTaylorMultipoleExpansion, False, False), (HelmholtzKernel(2), VolumeTaylorLocalExpansion, VolumeTaylorMultipoleExpansion, - False), + False, False), (HelmholtzKernel(2), HelmholtzConformingVolumeTaylorLocalExpansion, - HelmholtzConformingVolumeTaylorMultipoleExpansion, False), - (HelmholtzKernel(2), H2DLocalExpansion, H2DMultipoleExpansion, False), + HelmholtzConformingVolumeTaylorMultipoleExpansion, False, False), + (HelmholtzKernel(2), H2DLocalExpansion, H2DMultipoleExpansion, False, False), (HelmholtzKernel(3), VolumeTaylorLocalExpansion, VolumeTaylorMultipoleExpansion, - False), + False, False), (HelmholtzKernel(3), HelmholtzConformingVolumeTaylorLocalExpansion, - HelmholtzConformingVolumeTaylorMultipoleExpansion, False), - (YukawaKernel(2), Y2DLocalExpansion, Y2DMultipoleExpansion, False), + HelmholtzConformingVolumeTaylorMultipoleExpansion, False, False), + (YukawaKernel(2), Y2DLocalExpansion, Y2DMultipoleExpansion, False, True), ]) def test_sumpy_fmm(ctx_factory, knl, local_expn_class, mpole_expn_class, - optimized_m2l): + optimized_m2l, use_fft): logging.basicConfig(level=logging.INFO) ctx = ctx_factory() @@ -192,7 +195,7 @@ def test_sumpy_fmm(ctx_factory, knl, local_expn_class, mpole_expn_class, ctx, partial(mpole_expn_class, knl), partial(local_expn_class, knl), - target_kernels) + target_kernels, use_fft=use_fft) wrangler = wcc.get_wrangler(queue, tree, dtype, fmm_level_to_order=lambda kernel, kernel_args, tree, lev: order, @@ -357,7 +360,7 @@ def test_sumpy_fmm_timing_data_collection(ctx_factory): wcc = SumpyExpansionWranglerCodeContainer( ctx, partial(mpole_expn_class, knl), - partial(local_expn_class, knl), + partial(local_expn_class, knl, use_fft=use_fft), target_kernels) wrangler = wcc.get_wrangler(queue, tree, dtype, -- GitLab