From 7016a0b9ab5c1fc21dba930577e30fe96014bd24 Mon Sep 17 00:00:00 2001
From: Yichao Yu <yyc1992@gmail.com>
Date: Thu, 22 May 2014 10:49:03 -0400
Subject: [PATCH] python gc

---
 pyopencl/_cffi.py                 |   4 +
 pyopencl/c_wrapper/wrap_cl_core.h |   1 +
 src/c_wrapper/error.h             |  16 ++
 src/c_wrapper/wrap_cl.cpp         | 235 ++++++++++++------------------
 4 files changed, 115 insertions(+), 141 deletions(-)

diff --git a/pyopencl/_cffi.py b/pyopencl/_cffi.py
index 6c15e0ab..596be5c2 100644
--- a/pyopencl/_cffi.py
+++ b/pyopencl/_cffi.py
@@ -160,3 +160,7 @@ _lib = _import_library()
 
 if _lib.pyopencl_have_gl():
     _ffi.cdef(_get_wrap_header("wrap_cl_gl_core.h"))
+
+import gc  # stdlib garbage collector; gc.collect is handed to the C wrapper below
+_gc_collect = _ffi.callback('int(void)')(gc.collect)  # wrap gc.collect as a C `int(void)` callback; NOTE(review): the module-level name presumably keeps the callback object alive while the C side holds the pointer -- confirm against cffi's callback lifetime rules
+_lib.pyopencl_set_gc(_gc_collect)  # install the callback as the wrapper's GC hook
diff --git a/pyopencl/c_wrapper/wrap_cl_core.h b/pyopencl/c_wrapper/wrap_cl_core.h
index bc25d4a9..097b0ac2 100644
--- a/pyopencl/c_wrapper/wrap_cl_core.h
+++ b/pyopencl/c_wrapper/wrap_cl_core.h
@@ -100,3 +100,4 @@ void pyopencl_free_pointer_array(void**, uint32_t size);
 int pyopencl_have_gl();
 
 unsigned pyopencl_bitlog2(unsigned long v);
+void pyopencl_set_gc(int (*func)());  /* install func as the GC hook stored by the C++ wrapper */
diff --git a/src/c_wrapper/error.h b/src/c_wrapper/error.h
index 4137067a..66039fe1 100644
--- a/src/c_wrapper/error.h
+++ b/src/c_wrapper/error.h
@@ -10,6 +10,8 @@
 
 namespace pyopencl {
 
+extern int (*python_gc)();  // GC hook (defined in wrap_cl.cpp); retry_mem_error calls it after an out-of-memory error
+
 #ifdef PYOPENCL_TRACE
 
 template<typename FirstType, typename... ArgTypes>
@@ -155,6 +157,20 @@ c_handle_error(std::function<void()> func)
     }
 }
 
+template<typename T>  // Call func(); on an out-of-memory pyopencl::error, run the python_gc hook and call func() once more.
+static inline T
+retry_mem_error(std::function<T()> func)  // T is func's result type (callers in this patch use cl_mem and void)
+{
+    try {
+        return func();  // first attempt; on success the result is returned immediately
+    } catch (pyopencl::error &e) {
+        if (!e.is_out_of_memory() || !python_gc()) {  // rethrow if not out-of-memory, or if the hook returned 0
+            throw;
+        }
+    }
+    return func();  // second and final attempt; an exception thrown here propagates to the caller
+}
+
 // }}}
 
 template<typename T, typename CLType, typename... ArgTypes>
diff --git a/src/c_wrapper/wrap_cl.cpp b/src/c_wrapper/wrap_cl.cpp
index 32b491d9..40517a5c 100644
--- a/src/c_wrapper/wrap_cl.cpp
+++ b/src/c_wrapper/wrap_cl.cpp
@@ -83,37 +83,6 @@
       devices = devices_vec.empty( ) ? NULL : &devices_vec.front(); \
     }                                                                   \
 
-
-
-#define PYOPENCL_RETRY_IF_MEM_ERROR(OPERATION) \
-  { \
-    bool failed_with_mem_error = false;         \
-    try                                         \
-    {                                         \
-      OPERATION                               \
-    }                                     \
-    catch (pyopencl::error &e)                  \
-    {                                         \
-      failed_with_mem_error = true;           \
-      if (!e.is_out_of_memory())              \
-        throw;                                \
-    }                                         \
-                                                \
-    if (failed_with_mem_error)                  \
-    {                                         \
-      /* If we get here, we got an error from CL.
-       * We should run the Python GC to try and free up
-       * some memory references. */                       \
-run_python_gc();                                \
-\
-/* Now retry the allocation. If it fails again,
- * let it fail. */                              \
-{                                               \
-  OPERATION                                     \
-    }                                           \
-}                                               \
-}
-
 // }}}
 
 
@@ -162,6 +131,13 @@ run_python_gc();                                \
 
 namespace pyopencl
 {
+static int  // Default GC hook: does no work.
+dummy_python_gc()
+{
+    return 0;  // a 0 result makes retry_mem_error rethrow instead of retrying
+}
+
+int (*python_gc)() = dummy_python_gc;  // active GC hook; starts as the no-op hook and is replaced by pyopencl_set_gc()
 
   // {{{ platform
 
@@ -1179,10 +1155,10 @@ create_image_2d(context const &ctx, cl_mem_flags flags,
                 cl_image_format const &fmt, size_t width, size_t height,
                 size_t pitch, void *buffer, size_t size)
 {
-    // PYOPENCL_RETRY_IF_MEM_ERROR(
-    cl_mem mem = pyopencl_call_guarded(clCreateImage2D, ctx.data(), flags,
-                                       &fmt, width, height, pitch, buffer);
-    //);
+    auto mem = retry_mem_error<cl_mem>([&] {
+            return pyopencl_call_guarded(clCreateImage2D, ctx.data(), flags,
+                                         &fmt, width, height, pitch, buffer);
+        });
     return new_image(mem, flags & CL_MEM_USE_HOST_PTR ? buffer : NULL);
 }
 
@@ -1192,11 +1168,11 @@ create_image_3d(context const &ctx, cl_mem_flags flags,
                 size_t depth, size_t pitch_x, size_t pitch_y,
                 void *buffer, size_t size)
 {
-    //PYOPENCL_RETRY_IF_MEM_ERROR(
-    cl_mem mem = pyopencl_call_guarded(clCreateImage3D, ctx.data(), flags,
-                                       &fmt, width, height, depth, pitch_x,
-                                       pitch_y, buffer);
-    //);
+    auto mem = retry_mem_error<cl_mem>([&] {
+            return pyopencl_call_guarded(clCreateImage3D, ctx.data(), flags,
+                                         &fmt, width, height, depth, pitch_x,
+                                         pitch_y, buffer);
+        });
     return new_image(mem, flags & CL_MEM_USE_HOST_PTR ? buffer : NULL);
 }
 
@@ -1261,31 +1237,23 @@ create_image_3d(context const &ctx, cl_mem_flags flags,
 
   // {{{ image transfers
 
-  inline
-  event *enqueue_read_image(command_queue &cq,
-                            image &img,
-                            size_t *origin, size_t *region,
-                            void *buffer, size_t size,
-                            size_t row_pitch, size_t slice_pitch,
-                            void **wait_for, uint32_t num_wait_for,
-                            bool is_blocking)
-  {
+inline event*  // Enqueue clEnqueueReadImage on cq to read img into buffer; returns new_event(evt).
+enqueue_read_image(command_queue &cq, image &img, size_t *origin,
+                   size_t *region, void *buffer, size_t size, size_t row_pitch,
+                   size_t slice_pitch, void **wait_for, uint32_t num_wait_for,
+                   bool is_blocking)
+{  // note: the `size` parameter is not referenced directly in this body
     PYOPENCL_PARSE_WAIT_FOR;
-
     cl_event evt;
-    // TODO
-    //PYOPENCL_RETRY_IF_MEM_ERROR(
-    pyopencl_call_guarded(clEnqueueReadImage,
-                          cq.data(), img.data(), cast_bool(is_blocking),
-                          origin, region, row_pitch, slice_pitch, buffer,
-                          PYOPENCL_WAITLIST_ARGS, &evt);
-    //);
+    retry_mem_error<void>([&] {  // retry_mem_error may rerun this lambda once after an out-of-memory error
+            pyopencl_call_guarded(clEnqueueReadImage, cq.data(), img.data(),
+                                  cast_bool(is_blocking), origin, region,
+                                  row_pitch, slice_pitch, buffer,
+                                  PYOPENCL_WAITLIST_ARGS, &evt);
+        });  // &evt is passed to the CL call through the by-reference capture
     return new_event(evt);
     //PYOPENCL_RETURN_NEW_NANNY_EVENT(evt, buffer);
-  }
-
-
-
+}  // enqueue_read_image
 
   //   inline
   //   event *enqueue_write_image(
@@ -1708,10 +1676,10 @@ inline cl_mem
 create_buffer_gc(cl_context ctx, cl_mem_flags flags,
                  size_t size, void *host_ptr)
 {
-    // TODO
-    //PYOPENCL_RETRY_RETURN_IF_MEM_ERROR(
-    return pyopencl_call_guarded(clCreateBuffer, ctx, flags, size, host_ptr);
-    // );
+    return retry_mem_error<cl_mem>([&] {  // guarded clCreateBuffer; retry_mem_error may rerun this lambda once after an out-of-memory error
+            return pyopencl_call_guarded(clCreateBuffer, ctx,
+                                         flags, size, host_ptr);
+        });  // result of the (possibly retried) guarded call is returned to the caller
 }
 
 #if PYOPENCL_CL_VERSION >= 0x1010
@@ -1719,11 +1687,10 @@ inline cl_mem
 create_sub_buffer_gc(cl_mem buffer, cl_mem_flags flags,
                      cl_buffer_create_type bct, const void *buffer_create_info)
 {
-    // TODO
-    //PYOPENCL_RETRY_RETURN_IF_MEM_ERROR(
-    return pyopencl_call_guarded(clCreateSubBuffer, buffer, flags,
-                                 bct, buffer_create_info);
-    //);
+    return retry_mem_error<cl_mem>([&] {  // guarded clCreateSubBuffer; retry_mem_error may rerun this lambda once after an out-of-memory error
+            return pyopencl_call_guarded(clCreateSubBuffer, buffer, flags,
+                                         bct, buffer_create_info);
+        });  // result of the (possibly retried) guarded call is returned to the caller
 }
 #endif
 
@@ -2274,27 +2241,21 @@ public:
 
   // {{{ buffer transfers
 
-  inline
-  event *enqueue_read_buffer(
-                             command_queue &cq,
-                             memory_object_holder &mem,
-                             void *buffer, size_t size,
-                             size_t device_offset,
-                             void **wait_for, uint32_t num_wait_for,
-                             bool is_blocking)
-  {
+inline event*  // Enqueue clEnqueueReadBuffer on cq: size bytes of mem at device_offset into buffer; returns new_event(evt).
+enqueue_read_buffer(command_queue &cq, memory_object_holder &mem, void *buffer,
+                    size_t size, size_t device_offset, void **wait_for,
+                    uint32_t num_wait_for, bool is_blocking)
+{
     PYOPENCL_PARSE_WAIT_FOR;
 
     cl_event evt;
-    // TODO
-    //PYOPENCL_RETRY_IF_MEM_ERROR(
-    pyopencl_call_guarded(clEnqueueReadBuffer,
-                          cq.data(), mem.data(), cast_bool(is_blocking),
-                          device_offset, size, buffer,
-                          PYOPENCL_WAITLIST_ARGS, &evt);
-    //);
+    retry_mem_error<void>([&] {  // retry_mem_error may rerun this lambda once after an out-of-memory error
+            pyopencl_call_guarded(clEnqueueReadBuffer, cq.data(), mem.data(),
+                                  cast_bool(is_blocking), device_offset, size,
+                                  buffer, PYOPENCL_WAITLIST_ARGS, &evt);
+        });  // &evt is passed to the CL call through the by-reference capture
     return new_event(evt);
-  }
+}  // enqueue_read_buffer
 
 
 
@@ -2321,63 +2282,51 @@ public:
       }
 
     cl_event evt;
-    // TODO
-    //PYOPENCL_RETRY_IF_MEM_ERROR(
-    pyopencl_call_guarded(clEnqueueCopyBuffer, cq.data(), src.data(),
-                          dst.data(), src_offset, dst_offset,
-                          byte_count, PYOPENCL_WAITLIST_ARGS, &evt);
-    // );
+    retry_mem_error<void>([&] {
+            pyopencl_call_guarded(clEnqueueCopyBuffer, cq.data(), src.data(),
+                                  dst.data(), src_offset, dst_offset,
+                                  byte_count, PYOPENCL_WAITLIST_ARGS, &evt);
+        });
     return new_event(evt);
   }
 
-  inline
-  event *enqueue_write_buffer(
-                              command_queue &cq,
-                              memory_object_holder &mem,
-                              const void *buffer,
-                              size_t size,
-                              size_t device_offset,
-                              void **wait_for, uint32_t num_wait_for,
-                              bool is_blocking)
-  {
+inline event*  // Enqueue clEnqueueWriteBuffer on cq: size bytes from buffer into mem at device_offset; returns new_event(evt).
+enqueue_write_buffer(command_queue &cq,memory_object_holder &mem,
+                     const void *buffer, size_t size, size_t device_offset,
+                     void **wait_for, uint32_t num_wait_for, bool is_blocking)
+{
     PYOPENCL_PARSE_WAIT_FOR;
 
     cl_event evt;
-    // TODO
-    //PYOPENCL_RETRY_IF_MEM_ERROR(
-    pyopencl_call_guarded(clEnqueueWriteBuffer, cq.data(), mem.data(),
-                          cast_bool(is_blocking), device_offset,
-                          size, buffer, PYOPENCL_WAITLIST_ARGS, &evt);
-    //);
-    // TODO
+    retry_mem_error<void>([&] {  // retry_mem_error may rerun this lambda once after an out-of-memory error
+            pyopencl_call_guarded(clEnqueueWriteBuffer, cq.data(), mem.data(),
+                                  cast_bool(is_blocking), device_offset,
+                                  size, buffer, PYOPENCL_WAITLIST_ARGS, &evt);
+        });  // &evt is passed to the CL call through the by-reference capture
     return new_event(evt);
     //PYOPENCL_RETURN_NEW_NANNY_EVENT(evt, buffer);
-  }
-
-
-
-  // }}}
+}  // enqueue_write_buffer
 
+// }}}
 
-  inline event *enqueue_nd_range_kernel(
-      command_queue &cq,
-      kernel &knl,
-      cl_uint work_dim,
-      const size_t *global_work_offset,
-      const size_t *global_work_size,
-      const size_t *local_work_size,
-      void **wait_for, uint32_t num_wait_for)
-  {
+inline event*  // Enqueue knl on cq via clEnqueueNDRangeKernel with the given work sizes; returns new_event(evt).
+enqueue_nd_range_kernel(command_queue &cq, kernel &knl, cl_uint work_dim,
+                        const size_t *global_work_offset,
+                        const size_t *global_work_size,
+                        const size_t *local_work_size,
+                        void **wait_for, uint32_t num_wait_for)
+{
     PYOPENCL_PARSE_WAIT_FOR;
 
     cl_event evt;
-
-    // TODO: PYOPENCL_RETRY_RETURN_IF_MEM_ERROR
-    pyopencl_call_guarded(clEnqueueNDRangeKernel, cq.data(), knl.data(),
-                          work_dim, global_work_offset, global_work_size,
-                          local_work_size, PYOPENCL_WAITLIST_ARGS, &evt);
+    retry_mem_error<void>([&] {  // retry_mem_error may rerun this lambda once after an out-of-memory error
+            pyopencl_call_guarded(clEnqueueNDRangeKernel, cq.data(),
+                                  knl.data(), work_dim, global_work_offset,
+                                  global_work_size, local_work_size,
+                                  PYOPENCL_WAITLIST_ARGS, &evt);
+        });  // &evt is passed to the CL call through the by-reference capture
     return new_event(evt);
-  }
+}  // enqueue_nd_range_kernel
 
 #if PYOPENCL_CL_VERSION >= 0x1020
   inline
@@ -2407,22 +2356,19 @@ public:
   }
 #endif
 
-  inline
-  event *enqueue_marker(command_queue &cq)
-  {
+inline event*  // Call clEnqueueMarker on cq and return new_event(evt).
+enqueue_marker(command_queue &cq)
+{  // unlike the transfer functions in this patch, the call below is not wrapped in retry_mem_error
     cl_event evt;
-    // TODO
-    //PYOPENCL_RETRY_IF_MEM_ERROR(
     pyopencl_call_guarded(clEnqueueMarker, cq.data(), &evt);
-    //);
     return new_event(evt);
-  }
+}  // enqueue_marker
 
-  inline
-  void enqueue_barrier(command_queue &cq)
-  {
+inline void  // Call clEnqueueBarrier on cq through pyopencl_call_guarded; returns nothing.
+enqueue_barrier(command_queue &cq)
+{
     pyopencl_call_guarded(clEnqueueBarrier, cq.data());
-  }
+}  // enqueue_barrier
 }
 
 
@@ -2439,6 +2385,13 @@ void pyopencl_free_pointer_array(void **p, uint32_t size)
     pyopencl_free_pointer(p[i]);
 }
 
+void  // Install func as the GC hook (pyopencl::python_gc) consulted by retry_mem_error.
+pyopencl_set_gc(int (*func)())
+{
+    if (!func)  // a null pointer restores the no-op hook, so python_gc is not left null by this call
+        func = pyopencl::dummy_python_gc;
+    pyopencl::python_gc = func;
+}
 
 ::error*
 get_platforms(void **ptr_platforms, uint32_t *num_platforms)
-- 
GitLab