From 6fe590562286741bb296043d1dec382b52d2cfe1 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 6 Jul 2011 01:35:08 -0400
Subject: [PATCH] Rename ctx_getter -> ctx_factory. Introduce PYOPENCL_CTX env
 var.

---
 pyopencl/__init__.py     |  50 +++++++++++++----
 pyopencl/characterize.py |  47 ++++++++++++++++
 pyopencl/tools.py        |  13 ++++-
 test/test_array.py       | 118 +++++++++++++++++++--------------------
 test/test_clmath.py      |  20 +++----
 test/test_wrapper.py     |  28 +++++-----
 6 files changed, 178 insertions(+), 98 deletions(-)

diff --git a/pyopencl/__init__.py b/pyopencl/__init__.py
index a74adfcb..70567a91 100644
--- a/pyopencl/__init__.py
+++ b/pyopencl/__init__.py
@@ -457,11 +457,20 @@ class Program(object):
 
 # {{{ convenience -------------------------------------------------------------
 def create_some_context(interactive=True, answers=None):
+    import os
+    if answers is None and "PYOPENCL_CTX" in os.environ:
+        ctx_spec = os.environ["PYOPENCL_CTX"]
+        answers = ctx_spec.split(":")
+
+    user_inputs = []
+
     def get_input(prompt):
         if answers:
             return str(answers.pop(0))
         else:
-            return raw_input(prompt)
+            user_input = raw_input(prompt)
+            user_inputs.append(user_input)
+            return user_input
 
     try:
         import sys
@@ -484,14 +493,36 @@ def create_some_context(interactive=True, answers=None):
 
         answer = get_input("Choice [0]:")
         if not answer:
-            choice = 0
+            platform = platforms[0]
         else:
-            choice = int(answer)
+            try:
+                choice = int(answer)
+            except ValueError:
+                choice = answer.lower()  # not an index: match by name fragment
+                platform = None
+                for pf in platforms:
+                    if choice in pf.name.lower():
+                        platform = pf  # the last matching platform wins
+                if platform is None:
+                    raise RuntimeError("input did not match any platform")
 
-        platform = platforms[choice]
+            else:
+                platform = platforms[choice]
 
     devices = platform.get_devices()
 
+    def parse_device(choice):
+        """Return the device selected by *choice* (index or name fragment)."""
+        try:
+            return devices[int(choice)]
+        except ValueError:
+            # Not an index: fall back to case-insensitive substring matching.
+            choice = choice.lower()
+        for dev in devices:
+            if choice in dev.name.lower():
+                return dev
+        raise RuntimeError("input did not match any device")
+
     if not devices:
         raise Error("no devices found")
     elif len(devices) == 1 or not interactive:
@@ -508,19 +539,14 @@ def create_some_context(interactive=True, answers=None):
         else:
             devices = [devices[int(i)] for i in answer.split(",")]
 
+    if user_inputs:
+        print("Set the environment variable PYOPENCL_CTX='%s' to "
+                "avoid being asked again." % ":".join(user_inputs))
     return Context(devices)
 
 
 
 
-def _make_context_creator(answers):
-    def func():
-        return create_some_context(answers=answers)
-
-    return func
-
-
-
 def _mark_copy_deprecated(func):
     def new_func(*args, **kwargs):
         from warnings import warn
diff --git a/pyopencl/characterize.py b/pyopencl/characterize.py
index b3230471..469f05ae 100644
--- a/pyopencl/characterize.py
+++ b/pyopencl/characterize.py
@@ -1,5 +1,52 @@
+from __future__ import division
+
+import pyopencl as cl
+
 def has_double_support(dev):
     for ext in dev.extensions.split(" "):
         if ext == "cl_khr_fp64":
             return True
     return False
+
+
+
+
+def reasonable_work_group_size_multiple(dev, ctx=None):
+    """Return ``dev.warp_size_nv``, or else a kernel's preferred size multiple."""
+    try:
+        return dev.warp_size_nv
+    except (AttributeError, cl.Error):
+        pass  # presumably not an NVIDIA device (TODO confirm): query a kernel
+    if ctx is None:
+        ctx = cl.Context([dev])
+    # The source must be built, and only __kernel functions become kernels.
+    prg = cl.Program(ctx, """
+        __kernel void knl(__global float *a)
+        {
+            a[get_global_id(0)] = 0;
+        }
+        """).build()
+    return prg.knl.get_work_group_info(
+            cl.kernel_work_group_info.PREFERRED_WORK_GROUP_SIZE_MULTIPLE, dev)
+
+
+
+
+def usable_local_mem_size(dev, nargs=None):
+    """Estimate how much local memory on *dev* is usable by a kernel.
+
+    :arg nargs: Number of 32-bit arguments passed, or *None* for the maximum.
+    """
+    total = dev.local_mem_size
+    if "nvidia" not in dev.platform.name.lower():
+        return total
+
+    capability = (dev.compute_capability_major_nv,
+            dev.compute_capability_minor_nv)
+    if capability >= (2, 0):
+        return total
+
+    # Pre-Fermi devices use local memory for parameter passing; when *nargs*
+    # is unknown, assume the maximum (256), otherwise four bytes per argument.
+    return total - (256 if nargs is None else 4*nargs)
+
diff --git a/pyopencl/tools.py b/pyopencl/tools.py
index 2e150246..1beb7017 100644
--- a/pyopencl/tools.py
+++ b/pyopencl/tools.py
@@ -67,7 +67,7 @@ def context_dependent_memoize(func, context, *args):
 
 
 def pytest_generate_tests_for_pyopencl(metafunc):
-    class ContextGetter:
+    class ContextFactory:
         def __init__(self, device):
             self.device = device
 
@@ -81,9 +81,10 @@ def pytest_generate_tests_for_pyopencl(metafunc):
             return cl.Context([self.device])
 
         def __str__(self):
-            return "<context getter for %s>" % self.device
+            return "<context factory for %s>" % self.device
 
     if ("device" in metafunc.funcargnames
+            or "ctx_factory" in metafunc.funcargnames
             or "ctx_getter" in metafunc.funcargnames):
         arg_dict = {}
 
@@ -95,8 +96,14 @@ def pytest_generate_tests_for_pyopencl(metafunc):
                 if "device" in metafunc.funcargnames:
                     arg_dict["device"] = device
 
+                if "ctx_factory" in metafunc.funcargnames:
+                    arg_dict["ctx_factory"] = ContextFactory(device)
+
                 if "ctx_getter" in metafunc.funcargnames:
-                    arg_dict["ctx_getter"] = ContextGetter(device)
+                    from warnings import warn
+                    warn("The 'ctx_getter' arg is deprecated in favor of 'ctx_factory'.",
+                            DeprecationWarning)
+                    arg_dict["ctx_getter"] = ContextFactory(device)
 
                 metafunc.addcall(funcargs=arg_dict.copy(),
                         id=", ".join("%s=%s" % (arg, value)
diff --git a/test/test_array.py b/test/test_array.py
index a5fe8202..7bd1b97f 100644
--- a/test/test_array.py
+++ b/test/test_array.py
@@ -21,8 +21,8 @@ if have_cl():
 
 
 @pytools.test.mark_test.opencl
-def test_pow_array(ctx_getter):
-    context = ctx_getter()
+def test_pow_array(ctx_factory):
+    context = ctx_factory()
     queue = cl.CommandQueue(context)
 
     a = np.array([1, 2, 3, 4, 5]).astype(np.float32)
@@ -36,8 +36,8 @@ def test_pow_array(ctx_getter):
 
 
 @pytools.test.mark_test.opencl
-def test_pow_number(ctx_getter):
-    context = ctx_getter()
+def test_pow_number(ctx_factory):
+    context = ctx_factory()
     queue = cl.CommandQueue(context)
 
     a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).astype(np.float32)
@@ -48,8 +48,8 @@ def test_pow_number(ctx_getter):
 
 
 @pytools.test.mark_test.opencl
-def test_abs(ctx_getter):
-    context = ctx_getter()
+def test_abs(ctx_factory):
+    context = ctx_factory()
     queue = cl.CommandQueue(context)
 
     a = -cl_array.arange(queue, 111, dtype=np.float32)
@@ -68,8 +68,8 @@ def test_abs(ctx_getter):
 
 
 @pytools.test.mark_test.opencl
-def test_len(ctx_getter):
-    context = ctx_getter()
+def test_len(ctx_factory):
+    context = ctx_factory()
     queue = cl.CommandQueue(context)
 
     a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).astype(np.float32)
@@ -78,10 +78,10 @@ def test_len(ctx_getter):
 
 
 @pytools.test.mark_test.opencl
-def test_multiply(ctx_getter):
+def test_multiply(ctx_factory):
     """Test the muliplication of an array with a scalar. """
 
-    context = ctx_getter()
+    context = ctx_factory()
     queue = cl.CommandQueue(context)
 
     for sz in [10, 50000]:
@@ -97,10 +97,10 @@ def test_multiply(ctx_getter):
 
 
 @pytools.test.mark_test.opencl
-def test_multiply_array(ctx_getter):
+def test_multiply_array(ctx_factory):
     """Test the multiplication of two arrays."""
 
-    context = ctx_getter()
+    context = ctx_factory()
     queue = cl.CommandQueue(context)
 
     a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).astype(np.float32)
@@ -114,10 +114,10 @@ def test_multiply_array(ctx_getter):
 
 
 @pytools.test.mark_test.opencl
-def test_addition_array(ctx_getter):
+def test_addition_array(ctx_factory):
     """Test the addition of two arrays."""
 
-    context = ctx_getter()
+    context = ctx_factory()
     queue = cl.CommandQueue(context)
 
     a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).astype(np.float32)
@@ -128,10 +128,10 @@ def test_addition_array(ctx_getter):
 
 
 @pytools.test.mark_test.opencl
-def test_addition_scalar(ctx_getter):
+def test_addition_scalar(ctx_factory):
     """Test the addition of an array and a scalar."""
 
-    context = ctx_getter()
+    context = ctx_factory()
     queue = cl.CommandQueue(context)
 
     a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).astype(np.float32)
@@ -142,14 +142,14 @@ def test_addition_scalar(ctx_getter):
 
 
 @pytools.test.mark_test.opencl
-def test_substract_array(ctx_getter):
+def test_substract_array(ctx_factory):
     """Test the substraction of two arrays."""
     #test data
     a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).astype(np.float32)
     b = np.array([10, 20, 30, 40, 50,
                   60, 70, 80, 90, 100]).astype(np.float32)
 
-    context = ctx_getter()
+    context = ctx_factory()
     queue = cl.CommandQueue(context)
 
     a_gpu = cl_array.to_device(queue, a)
@@ -163,10 +163,10 @@ def test_substract_array(ctx_getter):
 
 
 @pytools.test.mark_test.opencl
-def test_substract_scalar(ctx_getter):
+def test_substract_scalar(ctx_factory):
     """Test the substraction of an array and a scalar."""
 
-    context = ctx_getter()
+    context = ctx_factory()
     queue = cl.CommandQueue(context)
 
     #test data
@@ -183,10 +183,10 @@ def test_substract_scalar(ctx_getter):
 
 
 @pytools.test.mark_test.opencl
-def test_divide_scalar(ctx_getter):
+def test_divide_scalar(ctx_factory):
     """Test the division of an array and a scalar."""
 
-    context = ctx_getter()
+    context = ctx_factory()
     queue = cl.CommandQueue(context)
 
     a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).astype(np.float32)
@@ -200,10 +200,10 @@ def test_divide_scalar(ctx_getter):
 
 
 @pytools.test.mark_test.opencl
-def test_divide_array(ctx_getter):
+def test_divide_array(ctx_factory):
     """Test the division of an array and a scalar. """
 
-    context = ctx_getter()
+    context = ctx_factory()
     queue = cl.CommandQueue(context)
 
     #test data
@@ -221,8 +221,8 @@ def test_divide_array(ctx_getter):
 
 
 @pytools.test.mark_test.opencl
-def test_random(ctx_getter):
-    context = ctx_getter()
+def test_random(ctx_factory):
+    context = ctx_factory()
     queue = cl.CommandQueue(context)
 
     from pyopencl.clrandom import rand as clrand
@@ -240,8 +240,8 @@ def test_random(ctx_getter):
 
 
 @pytools.test.mark_test.opencl
-def test_nan_arithmetic(ctx_getter):
-    context = ctx_getter()
+def test_nan_arithmetic(ctx_factory):
+    context = ctx_factory()
     queue = cl.CommandQueue(context)
 
     def make_nan_contaminated_vector(size):
@@ -266,8 +266,8 @@ def test_nan_arithmetic(ctx_getter):
 
 
 @pytools.test.mark_test.opencl
-def test_elwise_kernel(ctx_getter):
-    context = ctx_getter()
+def test_elwise_kernel(ctx_factory):
+    context = ctx_factory()
     queue = cl.CommandQueue(context)
 
     from pyopencl.clrandom import rand as clrand
@@ -288,11 +288,11 @@ def test_elwise_kernel(ctx_getter):
 
 
 @pytools.test.mark_test.opencl
-def test_elwise_kernel_with_options(ctx_getter):
+def test_elwise_kernel_with_options(ctx_factory):
     from pyopencl.clrandom import rand as clrand
     from pyopencl.elementwise import ElementwiseKernel
 
-    context = ctx_getter()
+    context = ctx_factory()
     queue = cl.CommandQueue(context)
 
     in_gpu = clrand(context, queue, (50,), np.float32)
@@ -320,8 +320,8 @@ def test_elwise_kernel_with_options(ctx_getter):
 
 
 @pytools.test.mark_test.opencl
-def test_take(ctx_getter):
-    context = ctx_getter()
+def test_take(ctx_factory):
+    context = ctx_factory()
     queue = cl.CommandQueue(context)
 
     idx = cl_array.arange(queue, 0, 200000, 2, dtype=np.uint32)
@@ -331,8 +331,8 @@ def test_take(ctx_getter):
 
 
 @pytools.test.mark_test.opencl
-def test_arange(ctx_getter):
-    context = ctx_getter()
+def test_arange(ctx_factory):
+    context = ctx_factory()
     queue = cl.CommandQueue(context)
 
     n = 5000
@@ -341,8 +341,8 @@ def test_arange(ctx_getter):
 
 
 @pytools.test.mark_test.opencl
-def test_reverse(ctx_getter):
-    context = ctx_getter()
+def test_reverse(ctx_factory):
+    context = ctx_factory()
     queue = cl.CommandQueue(context)
 
     n = 5000
@@ -355,8 +355,8 @@ def test_reverse(ctx_getter):
 
 
 @pytools.test.mark_test.opencl
-def test_sum(ctx_getter):
-    context = ctx_getter()
+def test_sum(ctx_factory):
+    context = ctx_factory()
     queue = cl.CommandQueue(context)
 
     from pyopencl.clrandom import rand as clrand
@@ -371,8 +371,8 @@ def test_sum(ctx_getter):
 
 
 @pytools.test.mark_test.opencl
-def test_minmax(ctx_getter):
-    context = ctx_getter()
+def test_minmax(ctx_factory):
+    context = ctx_factory()
     queue = cl.CommandQueue(context)
 
     from pyopencl.clrandom import rand as clrand
@@ -394,8 +394,8 @@ def test_minmax(ctx_getter):
 
 
 @pytools.test.mark_test.opencl
-def test_subset_minmax(ctx_getter):
-    context = ctx_getter()
+def test_subset_minmax(ctx_factory):
+    context = ctx_factory()
     queue = cl.CommandQueue(context)
 
     from pyopencl.clrandom import rand as clrand
@@ -434,8 +434,8 @@ def test_subset_minmax(ctx_getter):
 
 
 @pytools.test.mark_test.opencl
-def test_dot(ctx_getter):
-    context = ctx_getter()
+def test_dot(ctx_factory):
+    context = ctx_factory()
     queue = cl.CommandQueue(context)
 
     from pyopencl.clrandom import rand as clrand
@@ -453,7 +453,7 @@ def test_dot(ctx_getter):
 
 if False:
     @pytools.test.mark_test.opencl
-    def test_slice(ctx_getter):
+    def test_slice(ctx_factory):
         from pyopencl.clrandom import rand as clrand
 
         l = 20000
@@ -472,8 +472,8 @@ if False:
 
 
 @pytools.test.mark_test.opencl
-def test_if_positive(ctx_getter):
-    context = ctx_getter()
+def test_if_positive(ctx_factory):
+    context = ctx_factory()
     queue = cl.CommandQueue(context)
 
     from pyopencl.clrandom import rand as clrand
@@ -495,8 +495,8 @@ def test_if_positive(ctx_getter):
 
 
 @pytools.test.mark_test.opencl
-def test_take_put(ctx_getter):
-    context = ctx_getter()
+def test_take_put(ctx_factory):
+    context = ctx_factory()
     queue = cl.CommandQueue(context)
 
     for n in [5, 17, 333]:
@@ -517,8 +517,8 @@ def test_take_put(ctx_getter):
 
 
 @pytools.test.mark_test.opencl
-def test_astype(ctx_getter):
-    context = ctx_getter()
+def test_astype(ctx_factory):
+    context = ctx_factory()
     queue = cl.CommandQueue(context)
 
     from pyopencl.clrandom import rand as clrand
@@ -544,8 +544,8 @@ def test_astype(ctx_getter):
 
 
 @pytools.test.mark_test.opencl
-def test_scan(ctx_getter):
-    context = ctx_getter()
+def test_scan(ctx_factory):
+    context = ctx_factory()
     queue = cl.CommandQueue(context)
 
     from pyopencl.scan import InclusiveScanKernel, ExclusiveScanKernel
@@ -576,8 +576,8 @@ def test_scan(ctx_getter):
 
 
 @pytools.test.mark_test.opencl
-def test_stride_preservation(ctx_getter):
-    context = ctx_getter()
+def test_stride_preservation(ctx_factory):
+    context = ctx_factory()
     queue = cl.CommandQueue(context)
 
     A = np.random.rand(3, 3)
@@ -589,8 +589,8 @@ def test_stride_preservation(ctx_getter):
 
 
 @pytools.test.mark_test.opencl
-def test_vector_fill(ctx_getter):
-    context = ctx_getter()
+def test_vector_fill(ctx_factory):
+    context = ctx_factory()
     queue = cl.CommandQueue(context)
 
     a_gpu = cl_array.Array(queue, 100, dtype=cl_array.vec.float4)
diff --git a/test/test_clmath.py b/test/test_clmath.py
index bdeb2d8d..ab64448c 100644
--- a/test/test_clmath.py
+++ b/test/test_clmath.py
@@ -41,8 +41,8 @@ def make_unary_function_test(name, limits=(0, 1), threshold=0):
     a = float(a)
     b = float(b)
 
-    def test(ctx_getter):
-        context = ctx_getter()
+    def test(ctx_factory):
+        context = ctx_factory()
         queue = cl.CommandQueue(context)
 
         gpu_func = getattr(clmath, name)
@@ -94,8 +94,8 @@ if have_cl():
 
 
 @pytools.test.mark_test.opencl
-def test_fmod(ctx_getter):
-    context = ctx_getter()
+def test_fmod(ctx_factory):
+    context = ctx_factory()
     queue = cl.CommandQueue(context)
 
     for s in sizes:
@@ -111,8 +111,8 @@ def test_fmod(ctx_getter):
             assert math.fmod(a[i], a2[i]) == b[i]
 
 @pytools.test.mark_test.opencl
-def test_ldexp(ctx_getter):
-    context = ctx_getter()
+def test_ldexp(ctx_factory):
+    context = ctx_factory()
     queue = cl.CommandQueue(context)
 
     for s in sizes:
@@ -128,8 +128,8 @@ def test_ldexp(ctx_getter):
             assert math.ldexp(a[i], int(a2[i])) == b[i]
 
 @pytools.test.mark_test.opencl
-def test_modf(ctx_getter):
-    context = ctx_getter()
+def test_modf(ctx_factory):
+    context = ctx_factory()
     queue = cl.CommandQueue(context)
 
     for s in sizes:
@@ -147,8 +147,8 @@ def test_modf(ctx_getter):
             assert abs(fracpart_true - fracpart[i]) < 1e-4
 
 @pytools.test.mark_test.opencl
-def test_frexp(ctx_getter):
-    context = ctx_getter()
+def test_frexp(ctx_factory):
+    context = ctx_factory()
     queue = cl.CommandQueue(context)
 
     for s in sizes:
diff --git a/test/test_wrapper.py b/test/test_wrapper.py
index 80c77457..ed8ae617 100644
--- a/test/test_wrapper.py
+++ b/test/test_wrapper.py
@@ -178,8 +178,8 @@ class TestCL:
         assert not iform.__dict__
 
     @pytools.test.mark_test.opencl
-    def test_nonempty_supported_image_formats(self, device, ctx_getter):
-        context = ctx_getter()
+    def test_nonempty_supported_image_formats(self, device, ctx_factory):
+        context = ctx_factory()
 
         if device.image_support:
             assert len(cl.get_supported_image_formats(
@@ -189,8 +189,8 @@ class TestCL:
             skip("images not supported on %s" % device.name)
 
     @pytools.test.mark_test.opencl
-    def test_that_python_args_fail(self, ctx_getter):
-        context = ctx_getter()
+    def test_that_python_args_fail(self, ctx_factory):
+        context = ctx_factory()
 
         prg = cl.Program(context, """
             __kernel void mult(__global float *a, float b, int c)
@@ -220,8 +220,8 @@ class TestCL:
         cl.enqueue_read_buffer(queue, a_buf, a_result).wait()
 
     @pytools.test.mark_test.opencl
-    def test_image_2d(self, device, ctx_getter):
-        context = ctx_getter()
+    def test_image_2d(self, device, ctx_factory):
+        context = ctx_factory()
 
         if not device.image_support:
             from py.test import skip
@@ -267,8 +267,8 @@ class TestCL:
         assert la.norm(a_result - a) == 0
 
     @pytools.test.mark_test.opencl
-    def test_copy_buffer(self, ctx_getter):
-        context = ctx_getter()
+    def test_copy_buffer(self, ctx_factory):
+        context = ctx_factory()
 
         queue = cl.CommandQueue(context)
         mf = cl.mem_flags
@@ -285,10 +285,10 @@ class TestCL:
         assert la.norm(a - b) == 0
 
     @pytools.test.mark_test.opencl
-    def test_mempool(self, ctx_getter):
+    def test_mempool(self, ctx_factory):
         from pyopencl.tools import MemoryPool, CLAllocator
 
-        context = ctx_getter()
+        context = ctx_factory()
 
         pool = MemoryPool(CLAllocator(context))
         maxlen = 10
@@ -319,8 +319,8 @@ class TestCL:
             assert asize < asize*(1+1/8)
 
     @pytools.test.mark_test.opencl
-    def test_vector_args(self, ctx_getter):
-        context = ctx_getter()
+    def test_vector_args(self, ctx_factory):
+        context = ctx_factory()
         queue = cl.CommandQueue(context)
 
         prg = cl.Program(context, """
@@ -340,8 +340,8 @@ class TestCL:
         assert (dest == x).all()
 
     @pytools.test.mark_test.opencl
-    def test_header_dep_handling(self, ctx_getter):
-        context = ctx_getter()
+    def test_header_dep_handling(self, ctx_factory):
+        context = ctx_factory()
         queue = cl.CommandQueue(context)
 
         kernel_src = """
-- 
GitLab