diff --git a/fmm.py b/fmm.py index 82c6eb23da9a617e1f78335a8a85203dfa2ccbc3..7da06aeb5239749ed54e7272c04943b69bb3faea 100644 --- a/fmm.py +++ b/fmm.py @@ -36,9 +36,12 @@ DIRECT_KERNEL = """ } """ -def main(): - target = np.random.rand(1, 4).astype(np.float32) - source = np.random.rand(1, 4).astype(np.float32) + + + +def test_direct(): + target = np.random.rand(5000, 4).astype(np.float32) + source = np.random.rand(5000, 4).astype(np.float32) ctx = cl.create_some_context() queue = cl.CommandQueue(ctx) @@ -64,9 +67,9 @@ def main(): / np.sum((target[itarg,:3] - source[:,:3])**2, axis=-1)**0.5) - print potential[:100] - print potential_host[:100] - #print la.norm(potential - potential_host) + #print potential[:100] + #print potential_host[:100] + assert la.norm(potential - potential_host)/la.norm(potential_host) < 1e-6 @@ -74,4 +77,4 @@ def main(): if __name__ == "__main__": - main() + test_direct()