From 6a728f2780305f1c9f28fbcec043359c4b7973d0 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Thu, 5 Oct 2017 00:56:43 -0500 Subject: [PATCH] Add kernel+args arguments to fmm_level_to_order --- sumpy/fmm.py | 19 ++++++++++++++----- test/test_fmm.py | 4 ++-- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/sumpy/fmm.py b/sumpy/fmm.py index 82fb0271..79830085 100644 --- a/sumpy/fmm.py +++ b/sumpy/fmm.py @@ -77,6 +77,11 @@ class SumpyExpansionWranglerCodeContainer(object): self.cl_context = cl_context + @memoize_method + def get_base_kernel(self): + from pytools import single_valued + return single_valued(k.get_base_kernel() for k in self.out_kernels) + @memoize_method def multipole_expansion(self, order): return self.multipole_expansion_factory(order, self.use_rscale) @@ -183,17 +188,21 @@ class SumpyExpansionWrangler(object): self.dtype = dtype - if not callable(fmm_level_to_order): - raise TypeError("fmm_level_to_order not passed") - self.level_orders = [ - fmm_level_to_order(tree, lev) for lev in range(tree.nlevels)] - if kernel_extra_kwargs is None: kernel_extra_kwargs = {} if self_extra_kwargs is None: self_extra_kwargs = {} + if not callable(fmm_level_to_order): + raise TypeError("fmm_level_to_order not passed") + + base_kernel = code_container.get_base_kernel() + kernel_arg_set = frozenset(kernel_extra_kwargs.items()) + self.level_orders = [ + fmm_level_to_order(base_kernel, kernel_arg_set, tree, lev) + for lev in range(tree.nlevels)] + self.source_extra_kwargs = source_extra_kwargs self.kernel_extra_kwargs = kernel_extra_kwargs self.self_extra_kwargs = self_extra_kwargs diff --git a/test/test_fmm.py b/test/test_fmm.py index 179994d5..bd1fe848 100644 --- a/test/test_fmm.py +++ b/test/test_fmm.py @@ -210,7 +210,7 @@ def test_sumpy_fmm(ctx_getter, knl, local_expn_class, mpole_expn_class): partial(local_expn_class, knl), out_kernels) wrangler = wcc.get_wrangler(queue, tree, dtype, - fmm_level_to_order=lambda tree, lev: order, + fmm_level_to_order=lambda kernel, kernel_args, tree, lev: order, kernel_extra_kwargs=extra_kwargs) from boxtree.fmm import drive_fmm @@ -283,7 +283,7 @@ def test_sumpy_fmm_exclude_self(ctx_getter): exclude_self=True) wrangler = wcc.get_wrangler(queue, tree, dtype, - fmm_level_to_order=lambda tree, lev: order, + fmm_level_to_order=lambda kernel, kernel_args, tree, lev: order, self_extra_kwargs=self_extra_kwargs) from boxtree.fmm import drive_fmm -- GitLab