diff --git a/pyopencl/tools.py b/pyopencl/tools.py
index 461b2138e58fbb228798b2a157d75388e38838b4..ae1609a5fd9ee1ecde74f826aac6f9c087884b56 100644
--- a/pyopencl/tools.py
+++ b/pyopencl/tools.py
@@ -172,6 +172,31 @@ atexit.register(clear_first_arg_caches)
 # }}}
 
 
+# {{{ pytest fixtures
+
+class _ContextFactory:
+    def __init__(self, device):
+        self.device = device
+
+    def __call__(self):
+        # Get rid of leftovers from past tests.
+        # CL implementations are surprisingly limited in how many
+        # simultaneous contexts they allow...
+        clear_first_arg_caches()
+
+        from gc import collect
+        collect()
+
+        import pyopencl as cl
+        return cl.Context([self.device])
+
+    def __str__(self):
+        # Don't show address, so that parallel test collection works
+        return ("<context factory for <pyopencl.Device '%s' on '%s'>" %
+                (self.device.name.strip(),
+                 self.device.platform.name.strip()))
+
+
 def get_test_platforms_and_devices(plat_dev_string=None):
     """Parse a string of the form 'PYOPENCL_TEST=0:0,1;intel:i5'.
 
@@ -229,36 +254,17 @@ def get_test_platforms_and_devices(plat_dev_string=None):
                 for platform in cl.get_platforms()]
 
 
-def pytest_generate_tests_for_pyopencl(metafunc):
-    import pyopencl as cl
-
-    class ContextFactory:
-        def __init__(self, device):
-            self.device = device
-
-        def __call__(self):
-            # Get rid of leftovers from past tests.
-            # CL implementations are surprisingly limited in how many
-            # simultaneous contexts they allow...
-
-            clear_first_arg_caches()
-
-            from gc import collect
-            collect()
+def get_pyopencl_fixture_arg_names(metafunc, extra_arg_names=None):
+    if extra_arg_names is None:
+        extra_arg_names = []
 
-            return cl.Context([self.device])
-
-        def __str__(self):
-            # Don't show address, so that parallel test collection works
-            return ("<context factory for <pyopencl.Device '%s' on '%s'>" %
-                    (self.device.name.strip(),
-                     self.device.platform.name.strip()))
-
-    test_plat_and_dev = get_test_platforms_and_devices()
+    supported_arg_names = [
+            "platform", "device",
+            "ctx_factory", "ctx_getter",
+            ] + extra_arg_names
 
     arg_names = []
-
-    for arg in ("platform", "device", "ctx_factory", "ctx_getter"):
+    for arg in supported_arg_names:
         if arg not in metafunc.fixturenames:
             continue
 
@@ -270,21 +276,22 @@ def pytest_generate_tests_for_pyopencl(metafunc):
 
         arg_names.append(arg)
 
-    arg_values = []
+    return arg_names
 
-    for platform, plat_devs in test_plat_and_dev:
-        if arg_names == ["platform"]:
-            arg_values.append((platform,))
-            continue
 
+def get_pyopencl_fixture_arg_values():
+    import pyopencl as cl
+
+    arg_values = []
+    for platform, devices in get_test_platforms_and_devices():
         arg_dict = {"platform": platform}
 
-        for device in plat_devs:
+        for device in devices:
             arg_dict["device"] = device
-            arg_dict["ctx_factory"] = ContextFactory(device)
-            arg_dict["ctx_getter"] = ContextFactory(device)
+            arg_dict["ctx_factory"] = _ContextFactory(device)
+            arg_dict["ctx_getter"] = _ContextFactory(device)
 
-            arg_values.append(tuple(arg_dict[name] for name in arg_names))
+        arg_values.append(arg_dict)
 
     def idfn(val):
         if isinstance(val, cl.Platform):
@@ -293,8 +300,23 @@ def pytest_generate_tests_for_pyopencl(metafunc):
         else:
             return str(val)
 
-    if arg_names:
-        metafunc.parametrize(arg_names, arg_values, ids=idfn)
+    return arg_values, idfn
+
+
+def pytest_generate_tests_for_pyopencl(metafunc):
+    arg_names = get_pyopencl_fixture_arg_names(metafunc)
+    if not arg_names:
+        return
+
+    arg_values, ids = get_pyopencl_fixture_arg_values()
+    arg_values = [
+            tuple(arg_dict[name] for name in arg_names)
+            for arg_dict in arg_values
+            ]
+
+    metafunc.parametrize(arg_names, arg_values, ids=ids)
+
+# }}}
 
 
 # {{{ C argument lists