From 24ca6dc626d67dad6f293c937e1879517ecdfd2c Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Fri, 22 Dec 2017 23:33:32 -0600 Subject: [PATCH] Fix the interaction list building for a float32 tree. Closes #21. --- boxtree/traversal.py | 4 ++-- test/test_fmm.py | 53 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 2 deletions(-) diff --git a/boxtree/traversal.py b/boxtree/traversal.py index 6f24e5b..3448ee5 100644 --- a/boxtree/traversal.py +++ b/boxtree/traversal.py @@ -136,8 +136,8 @@ TRAVERSAL_PREAMBLE_MAKO_DEFS = r"""//CL:mako// ); %endfor - ${name}_ext_center = 0.5*(${name}_min + ${name}_max); - ${name}_radii_vec = 0.5*(${name}_max - ${name}_min); + ${name}_ext_center = ((coord_vec_t) 0.5) * (${name}_min + ${name}_max); + ${name}_radii_vec = ((coord_vec_t) 0.5) * (${name}_max - ${name}_min); } diff --git a/test/test_fmm.py b/test/test_fmm.py index 082b431..d9bbc39 100644 --- a/test/test_fmm.py +++ b/test/test_fmm.py @@ -714,6 +714,59 @@ def test_interaction_list_particle_count_thresholding(ctx_getter, enable_extents # }}} +# {{{ test fmm with float32 dtype + +@pytest.mark.parametrize("enable_extents", [True, False]) +def test_fmm_float32(ctx_getter, enable_extents): + ctx = ctx_getter() + queue = cl.CommandQueue(ctx) + + logging.basicConfig(level=logging.INFO) + + dims = 2 + nsources = 1000 + ntargets = 1000 + dtype = np.float32 + + from boxtree.fmm import drive_fmm + sources = p_normal(queue, nsources, dims, dtype, seed=15) + targets = p_normal(queue, ntargets, dims, dtype, seed=15) + + from pyopencl.clrandom import PhiloxGenerator + rng = PhiloxGenerator(queue.context, seed=12) + + if enable_extents: + target_radii = 2**rng.uniform(queue, ntargets, dtype=dtype, a=-10, b=0) + else: + target_radii = None + + from boxtree import TreeBuilder + tb = TreeBuilder(ctx) + + tree, _ = tb(queue, sources, targets=targets, + max_particles_in_box=30, + target_radii=target_radii, + debug=True, stick_out_factor=0.25) + + from boxtree.traversal import FMMTraversalBuilder + tbuild = FMMTraversalBuilder(ctx) + trav, _ = tbuild(queue, tree, debug=True) + + weights = np.ones(nsources) + weights_sum = np.sum(weights) + + host_trav = trav.get(queue=queue) + host_tree = host_trav.tree + + wrangler = ConstantOneExpansionWrangler(host_tree) + + pot = drive_fmm(host_trav, wrangler, weights) + + assert (pot == weights_sum).all() + +# }}} + + # You can test individual routines by typing # $ python test_fmm.py 'test_routine(cl.create_some_context)' -- GitLab