diff --git a/sumpy/fmm.py b/sumpy/fmm.py index 828497197a39c3576ab592fb5c610e156a80982c..52cb9a603c47f3c229097adfd62df4a56791578c 100644 --- a/sumpy/fmm.py +++ b/sumpy/fmm.py @@ -235,54 +235,6 @@ class SumpyTimingFuture: # }}} -# {{{ translation classes data - -class SumpyTranslationClassesData: - """A class for building and storing additional, optional data for - precomputation of translation classes passed to the expansion wrangler.""" - - def __init__(self, queue, trav, is_translation_per_level=True): - # FIXME: Queues should not be part of data. - self.queue = queue - self.trav = trav - self.tree = trav.tree - self.is_translation_per_level = is_translation_per_level - - @property - @memoize_method - def translation_classes_builder(self): - from boxtree.translation_classes import TranslationClassesBuilder - return TranslationClassesBuilder(self.queue.context) - - @memoize_method - def build_translation_classes_lists(self): - return self.translation_classes_builder(self.queue, self.trav, self.tree, - is_translation_per_level=self.is_translation_per_level)[0] - - @memoize_method - def m2l_translation_classes_lists(self): - return (self - .build_translation_classes_lists() - .from_sep_siblings_translation_classes) - - @memoize_method - def m2l_translation_vectors(self): - return (self - .build_translation_classes_lists() - .from_sep_siblings_translation_class_to_distance_vector) - - def m2l_translation_classes_level_starts(self): - return (self - .build_translation_classes_lists() - .from_sep_siblings_translation_classes_level_starts) - - -class SumpyTranslationClassesDataNotSuppliedWarning(UserWarning): - pass - -# }}} - - # {{{ expansion wrangler class SumpyExpansionWrangler(ExpansionWranglerInterface): @@ -315,7 +267,8 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface): kernel_extra_kwargs=None, self_extra_kwargs=None, translation_classes_data=None, - preprocessed_mpole_dtype=None): + preprocessed_mpole_dtype=None, + *, _disable_translation_classes=False): super().__init__(tree_indep, traversal) self.issued_timing_data_warning = False @@ -353,27 +306,18 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface): self.extra_kwargs = source_extra_kwargs.copy() self.extra_kwargs.update(self.kernel_extra_kwargs) - if base_kernel.is_translation_invariant: - if translation_classes_data is None: - from warnings import warn - if self.tree_indep.use_fft_for_m2l: - raise NotImplementedError( - "FFT based List 2 (multipole-to-local) translations " - "without translation_classes_data argument is not " - "implemented. Supply a translation_classes_data argument " - "to the wrangler for optimized List 2.") - else: - warn( - "List 2 (multipole-to-local) translations will be " - "unoptimized. Supply a translation_classes_data argument " - "to the wrangler for optimized List 2.", - SumpyTranslationClassesDataNotSuppliedWarning, - stacklevel=2) - self.supports_translation_classes = False - else: - self.supports_translation_classes = True - else: + if _disable_translation_classes or not base_kernel.is_translation_invariant: self.supports_translation_classes = False + else: + if translation_classes_data is None: + with cl.CommandQueue(self.tree_indep.cl_context) as queue: + from boxtree.translation_classes import TranslationClassesBuilder + translation_classes_builder = TranslationClassesBuilder( + queue.context) + translation_classes_data, _ = translation_classes_builder( + queue, traversal, self.tree, + is_translation_per_level=True) + self.supports_translation_classes = True self.translation_classes_data = translation_classes_data self.use_fft_for_m2l = self.tree_indep.use_fft_for_m2l @@ -407,7 +351,7 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface): def m2l_translation_class_level_start_box_nrs(self): with cl.CommandQueue(self.tree_indep.cl_context) as queue: data = self.translation_classes_data - return data.m2l_translation_classes_level_starts().get(queue) + return data.from_sep_siblings_translation_classes_level_starts.get(queue) @memoize_method def m2l_translation_classes_dependent_data_level_starts(self): @@ -726,8 +670,9 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface): if ntranslation_classes == 0: continue + data = self.translation_classes_data m2l_translation_vectors = ( - self.translation_classes_data.m2l_translation_vectors()) + data.from_sep_siblings_translation_class_to_distance_vector) evt, _ = precompute_kernel( queue, @@ -767,7 +712,7 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface): kwargs_for_m2l["translation_classes_level_start"] = \ translation_classes_level_start kwargs_for_m2l["m2l_translation_classes_lists"] = \ - self.translation_classes_data.m2l_translation_classes_lists() + self.translation_classes_data.from_sep_siblings_translation_classes def multipole_to_local(self, level_start_target_box_nrs, diff --git a/test/test_fmm.py b/test/test_fmm.py index 3c5ec6d7c025cd96b281b5d8f4b88d34ff7b2a24..8e7bc32994ccf7bedb57a909280550922d6c7497 100644 --- a/test/test_fmm.py +++ b/test/test_fmm.py @@ -38,12 +38,9 @@ from sumpy.expansion.local import ( LinearPDEConformingVolumeTaylorLocalExpansion) from sumpy.fmm import ( SumpyTreeIndependentDataForWrangler, - SumpyExpansionWrangler, - SumpyTranslationClassesData, - SumpyTranslationClassesDataNotSuppliedWarning) + SumpyExpansionWrangler) import pytest -import warnings import logging logger = logging.getLogger(__name__) @@ -57,7 +54,7 @@ else: faulthandler.enable() -@pytest.mark.parametrize("optimized_m2l, use_fft", +@pytest.mark.parametrize("use_translation_classes, use_fft", [(False, False), (True, False), (True, True)]) @pytest.mark.parametrize( ("knl", "local_expn_class", "mpole_expn_class", @@ -84,7 +81,7 @@ else: False), ]) def test_sumpy_fmm(ctx_factory, knl, local_expn_class, mpole_expn_class, - order_varies_with_level, optimized_m2l, use_fft): + order_varies_with_level, use_translation_classes, use_fft): logging.basicConfig(level=logging.INFO) if local_expn_class == VolumeTaylorLocalExpansion and use_fft: @@ -188,11 +185,6 @@ def test_sumpy_fmm(ctx_factory, knl, local_expn_class, mpole_expn_class, for order in order_values: target_kernels = [knl] - if optimized_m2l: - translation_classes_data = SumpyTranslationClassesData(queue, trav) - else: - translation_classes_data = None - tree_indep = SumpyTreeIndependentDataForWrangler( ctx, partial(mpole_expn_class, knl), @@ -206,14 +198,10 @@ def test_sumpy_fmm(ctx_factory, knl, local_expn_class, mpole_expn_class, def fmm_level_to_order(kernel, kernel_args, tree, lev): return order - with warnings.catch_warnings(): - if not optimized_m2l: - warnings.simplefilter("ignore", - SumpyTranslationClassesDataNotSuppliedWarning) - wrangler = SumpyExpansionWrangler(tree_indep, trav, dtype, - fmm_level_to_order=fmm_level_to_order, - kernel_extra_kwargs=extra_kwargs, - translation_classes_data=translation_classes_data) + wrangler = SumpyExpansionWrangler(tree_indep, trav, dtype, + fmm_level_to_order=fmm_level_to_order, + kernel_extra_kwargs=extra_kwargs, + _disable_translation_classes=not use_translation_classes) from boxtree.fmm import drive_fmm @@ -315,8 +303,7 @@ def test_unified_single_and_double(ctx_factory): strength_usage=strength_usage) wrangler = SumpyExpansionWrangler(tree_indep, trav, dtype, fmm_level_to_order=lambda kernel, kernel_args, tree, lev: order, - source_extra_kwargs=source_extra_kwargs, - translation_classes_data=SumpyTranslationClassesData(queue, trav)) + source_extra_kwargs=source_extra_kwargs) from boxtree.fmm import drive_fmm @@ -376,8 +363,7 @@ def test_sumpy_fmm_timing_data_collection(ctx_factory): target_kernels) wrangler = SumpyExpansionWrangler(tree_indep, trav, dtype, - fmm_level_to_order=lambda kernel, kernel_args, tree, lev: order, - translation_classes_data=SumpyTranslationClassesData(queue, trav)) + fmm_level_to_order=lambda kernel, kernel_args, tree, lev: order) from boxtree.fmm import drive_fmm timing_data = {} @@ -435,8 +421,7 @@ def test_sumpy_fmm_exclude_self(ctx_factory): wrangler = SumpyExpansionWrangler(tree_indep, trav, dtype, fmm_level_to_order=lambda kernel, kernel_args, tree, lev: order, - self_extra_kwargs=self_extra_kwargs, - translation_classes_data=SumpyTranslationClassesData(queue, trav)) + self_extra_kwargs=self_extra_kwargs) from boxtree.fmm import drive_fmm @@ -510,8 +495,7 @@ def test_sumpy_axis_source_derivative(ctx_factory): wrangler = SumpyExpansionWrangler(tree_indep, trav, dtype, fmm_level_to_order=lambda kernel, kernel_args, tree, lev: order, - self_extra_kwargs=self_extra_kwargs, - translation_classes_data=SumpyTranslationClassesData(queue, trav)) + self_extra_kwargs=self_extra_kwargs) from boxtree.fmm import drive_fmm @@ -580,8 +564,7 @@ def test_sumpy_target_point_multiplier(ctx_factory, deriv_axes): wrangler = SumpyExpansionWrangler(tree_indep, trav, dtype, fmm_level_to_order=lambda kernel, kernel_args, tree, lev: order, - self_extra_kwargs=self_extra_kwargs, - translation_classes_data=SumpyTranslationClassesData(queue, trav)) + self_extra_kwargs=self_extra_kwargs) from boxtree.fmm import drive_fmm