diff --git a/boxtree/traversal.py b/boxtree/traversal.py index 6f24e5b5769485cb6b9ca3fecfc6bb92397b2f3c..3448ee56a0f359160cc0bfcec17957029b56cbba 100644 --- a/boxtree/traversal.py +++ b/boxtree/traversal.py @@ -136,8 +136,8 @@ TRAVERSAL_PREAMBLE_MAKO_DEFS = r"""//CL:mako// ); %endfor - ${name}_ext_center = 0.5*(${name}_min + ${name}_max); - ${name}_radii_vec = 0.5*(${name}_max - ${name}_min); + ${name}_ext_center = ((coord_vec_t) 0.5) * (${name}_min + ${name}_max); + ${name}_radii_vec = ((coord_vec_t) 0.5) * (${name}_max - ${name}_min); } diff --git a/test/test_fmm.py b/test/test_fmm.py index 082b43176d84385098d8298fd1f50e56c7383810..d9bbc398b3a37a9b4405879dc10ffa1a2e901e57 100644 --- a/test/test_fmm.py +++ b/test/test_fmm.py @@ -714,6 +714,59 @@ def test_interaction_list_particle_count_thresholding(ctx_getter, enable_extents # }}} +# {{{ test fmm with float32 dtype + +@pytest.mark.parametrize("enable_extents", [True, False]) +def test_fmm_float32(ctx_getter, enable_extents): + ctx = ctx_getter() + queue = cl.CommandQueue(ctx) + + logging.basicConfig(level=logging.INFO) + + dims = 2 + nsources = 1000 + ntargets = 1000 + dtype = np.float32 + + from boxtree.fmm import drive_fmm + sources = p_normal(queue, nsources, dims, dtype, seed=15) + targets = p_normal(queue, ntargets, dims, dtype, seed=15) + + from pyopencl.clrandom import PhiloxGenerator + rng = PhiloxGenerator(queue.context, seed=12) + + if enable_extents: + target_radii = 2**rng.uniform(queue, ntargets, dtype=dtype, a=-10, b=0) + else: + target_radii = None + + from boxtree import TreeBuilder + tb = TreeBuilder(ctx) + + tree, _ = tb(queue, sources, targets=targets, + max_particles_in_box=30, + target_radii=target_radii, + debug=True, stick_out_factor=0.25) + + from boxtree.traversal import FMMTraversalBuilder + tbuild = FMMTraversalBuilder(ctx) + trav, _ = tbuild(queue, tree, debug=True) + + weights = np.ones(nsources) + weights_sum = np.sum(weights) + + host_trav = trav.get(queue=queue) + host_tree = host_trav.tree + + wrangler = ConstantOneExpansionWrangler(host_tree) + + pot = drive_fmm(host_trav, wrangler, weights) + + assert (pot == weights_sum).all() + +# }}} + + # You can test individual routines by typing # $ python test_fmm.py 'test_routine(cl.create_some_context)'