From d4ecdf160d3eaba90f361944710445ae37498817 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 28 Oct 2015 02:35:10 -0500
Subject: [PATCH] Spatial tree tweaks, test

---
 pytools/spatial_btree.py | 99 ++++++++++++++++++++++------------------
 test/test_pytools.py     | 38 ++++++++++++++-
 2 files changed, 92 insertions(+), 45 deletions(-)

diff --git a/pytools/spatial_btree.py b/pytools/spatial_btree.py
index 238c8c2..00adc20 100644
--- a/pytools/spatial_btree.py
+++ b/pytools/spatial_btree.py
@@ -1,11 +1,12 @@
-from __future__ import division
-from __future__ import absolute_import
+from __future__ import division, absolute_import
 from six.moves import range
 
+import numpy as np
 
-def do_boxes_intersect(xxx_todo_changeme, xxx_todo_changeme1):
-    (bl1,tr1) = xxx_todo_changeme
-    (bl2,tr2) = xxx_todo_changeme1
+
+def do_boxes_intersect(bl, tr):
+    (bl1, tr1) = bl
+    (bl2, tr2) = tr
     (dimension,) = bl1.shape
     for i in range(0, dimension):
         if max(bl1[i], bl2[i]) > min(tr1[i], tr2[i]):
@@ -13,31 +14,16 @@ def do_boxes_intersect(xxx_todo_changeme, xxx_todo_changeme1):
     return True
 
 
-
-
-def _get_elements_bounding_box(elements):
-    import numpy
-
-    if len(elements) == 0:
-        raise RuntimeError("Cannot get the bounding box of no elements.")
-
-    bboxes = [box for el,box in elements]
-    bottom_lefts = [bl for bl,tr in bboxes]
-    top_rights = [tr for bl,tr in bboxes]
-    return numpy.minimum.reduce(bottom_lefts), numpy.minimum.reduce(top_rights)
-
-
-
-def make_buckets(bottom_left, top_right, allbuckets):
-    import numpy
-
+def make_buckets(bottom_left, top_right, allbuckets, max_elements_per_box):
     (dimensions,) = bottom_left.shape
 
     half = (top_right - bottom_left) / 2.
+
     def do(dimension, pos):
         if dimension == dimensions:
             origin = bottom_left + pos*half
-            bucket = SpatialBinaryTreeBucket(origin, origin + half)
+            bucket = SpatialBinaryTreeBucket(origin, origin + half,
+                    max_elements_per_box=max_elements_per_box)
             allbuckets.append(bucket)
             return bucket
         else:
@@ -47,9 +33,7 @@ def make_buckets(bottom_left, top_right, allbuckets):
             second = do(dimension + 1, pos)
             return [first, second]
 
-    return do(0, numpy.zeros((dimensions,), numpy.float64))
-
-
+    return do(0, np.zeros((dimensions,), np.float64))
 
 
 class SpatialBinaryTreeBucket:
@@ -57,12 +41,14 @@ class SpatialBinaryTreeBucket:
     It automatically decides whether it needs to create more subdivisions
     beneath itself or not.
 
-    :ivar elements: a list of tuples *(element, bbox)* where bbox is again
-      a tuple *(lower_left, upper_right)* of :class:`numpy.ndarray` instances
-      satisfying *(lower_right <= upper_right).all()*.
+    .. attribute:: elements
+
+        a list of tuples *(element, bbox)* where bbox is again
+        a tuple *(lower_left, upper_right)* of :class:`numpy.ndarray` instances
+        satisfying ``(lower_right <= upper_right).all()``.
     """
 
-    def __init__(self, bottom_left, top_right):
+    def __init__(self, bottom_left, top_right, max_elements_per_box=None):
         """:param bottom_left: A :mod: 'numpy' array of the minimal coordinates
         of the box being partitioned.
         :param top_right: A :mod: 'numpy' array of the maximal coordinates of
@@ -78,6 +64,12 @@ class SpatialBinaryTreeBucket:
         self.buckets = None
         self.elements = []
 
+        if max_elements_per_box is None:
+            dimensions, = self.bottom_left.shape
+            max_elements_per_box = 8 * 2**dimensions
+
+        self.max_elements_per_box = max_elements_per_box
+
     def insert(self, element, bbox):
         """Insert an element into the spatial tree.
 
@@ -97,14 +89,15 @@ class SpatialBinaryTreeBucket:
                 if do_boxes_intersect((bucket.bottom_left, bucket.top_right), bbox):
                     bucket.insert(element, bbox)
 
-        (dimensions,) = self.bottom_left.shape
         if self.buckets is None:
             # No subdivisions yet.
-            if len(self.elements) > 8 * 2**dimensions:
+            if len(self.elements) > self.max_elements_per_box:
                 # Too many elements. Need to subdivide.
                 self.all_buckets = []
-                self.buckets = make_buckets(self.bottom_left, self.top_right,
-                                            self.all_buckets)
+                self.buckets = make_buckets(
+                        self.bottom_left, self.top_right,
+                        self.all_buckets,
+                        max_elements_per_box=self.max_elements_per_box)
 
                 # Move all elements from the full bucket into the new finer ones
                 for el, el_bbox in self.elements:
@@ -122,7 +115,6 @@ class SpatialBinaryTreeBucket:
             # Go find which sudivision to place element
             insert_into_subdivision(element, bbox)
 
-
     def generate_matches(self, point):
         if self.buckets:
             # We have subdivisions. Use them.
@@ -142,16 +134,35 @@ class SpatialBinaryTreeBucket:
                 yield el
 
     def visualize(self, file):
-        file.write("%f %f\n" % (self.bottom_left[0], self.bottom_left[1]));
-        file.write("%f %f\n" % (self.top_right[0], self.bottom_left[1]));
-        file.write("%f %f\n" % (self.top_right[0], self.top_right[1]));
-        file.write("%f %f\n" % (self.bottom_left[0], self.top_right[1]));
-        file.write("%f %f\n\n" % (self.bottom_left[0], self.bottom_left[1]));
+        file.write("%f %f\n" % (self.bottom_left[0], self.bottom_left[1]))
+        file.write("%f %f\n" % (self.top_right[0], self.bottom_left[1]))
+        file.write("%f %f\n" % (self.top_right[0], self.top_right[1]))
+        file.write("%f %f\n" % (self.bottom_left[0], self.top_right[1]))
+        file.write("%f %f\n\n" % (self.bottom_left[0], self.bottom_left[1]))
         if self.buckets:
             for i in self.all_buckets:
                 i.visualize(file)
 
+    def plot(self, **kwargs):
+        import matplotlib.pyplot as pt
+        import matplotlib.patches as mpatches
+        from matplotlib.path import Path
+
+        el = self.bottom_left
+        eh = self.top_right
+        pathdata = [
+            (Path.MOVETO, (el[0], el[1])),
+            (Path.LINETO, (eh[0], el[1])),
+            (Path.LINETO, (eh[0], eh[1])),
+            (Path.LINETO, (el[0], eh[1])),
+            (Path.CLOSEPOLY, (el[0], el[1])),
+            ]
+
+        codes, verts = zip(*pathdata)
+        path = Path(verts, codes)
+        patch = mpatches.PathPatch(path, **kwargs)
+        pt.gca().add_patch(patch)
 
-
-
-
+        if self.buckets:
+            for i in self.all_buckets:
+                i.plot(**kwargs)
diff --git a/test/test_pytools.py b/test/test_pytools.py
index 8279f4f..3773243 100644
--- a/test/test_pytools.py
+++ b/test/test_pytools.py
@@ -2,7 +2,6 @@ from __future__ import division, with_statement
 from __future__ import absolute_import
 
 import pytest
-import sys  # noqa
 
 
 @pytest.mark.skipif("sys.version_info < (2, 5)")
@@ -128,3 +127,40 @@ def test_memoize_keyfunc():
     assert f(1, (2,)) == 2
     assert f(2, j=(2, 3)) == 4
     assert count[0] == 2
+
+
+@pytest.mark.parametrize("dims", [2, 3])
+def test_spatial_btree(dims, do_plot=False):
+    import numpy as np
+    nparticles = 2000
+    x = -1 + 2*np.random.rand(dims, nparticles)
+    x = np.sign(x)*np.abs(x)**1.9
+    x = (1.4 + x) % 2 - 1
+
+    bl = np.min(x, axis=-1)
+    tr = np.max(x, axis=-1)
+    print(bl, tr)
+
+    from pytools.spatial_btree import SpatialBinaryTreeBucket
+    tree = SpatialBinaryTreeBucket(bl, tr, max_elements_per_box=10)
+    for i in range(nparticles):
+        tree.insert(i, (x[:, i], x[:, i]))
+
+    if do_plot:
+        import matplotlib.pyplot as pt
+        pt.gca().set_aspect("equal")
+        pt.plot(x[0], x[1], "x")
+        tree.plot(fill=None)
+        pt.show()
+
+
+if __name__ == "__main__":
+    # make sure that import failures get reported, instead of skipping the tests.
+    import pyopencl  # noqa
+
+    import sys
+    if len(sys.argv) > 1:
+        exec(sys.argv[1])
+    else:
+        from py.test.cmdline import main
+        main([__file__])
-- 
GitLab