diff --git a/pyopencl/tools.py b/pyopencl/tools.py
index 6a72ee9796d8b728d3bc919633a1da27a891c584..7304cfc73dbcf722223e5bb2da1510dd909de6dc 100644
--- a/pyopencl/tools.py
+++ b/pyopencl/tools.py
@@ -249,38 +249,38 @@ def pytest_generate_tests_for_pyopencl(metafunc):
 
     test_plat_and_dev = get_test_platforms_and_devices()
 
-    if ("device" in metafunc.funcargnames
-            or "ctx_factory" in metafunc.funcargnames
-            or "ctx_getter" in metafunc.funcargnames):
-        arg_dict = {}
-
-        for platform, plat_devs in test_plat_and_dev:
-            if "platform" in metafunc.funcargnames:
-                arg_dict["platform"] = platform
-
-            for device in plat_devs:
-                if "device" in metafunc.funcargnames:
-                    arg_dict["device"] = device
-
-                if "ctx_factory" in metafunc.funcargnames:
-                    arg_dict["ctx_factory"] = ContextFactory(device)
-
-                if "ctx_getter" in metafunc.funcargnames:
-                    from warnings import warn
-                    warn("The 'ctx_getter' arg is deprecated in "
-                            "favor of 'ctx_factory'.",
-                            DeprecationWarning)
-                    arg_dict["ctx_getter"] = ContextFactory(device)
-
-                metafunc.addcall(funcargs=arg_dict.copy(),
-                        id=", ".join("%s=%s" % (arg, value)
-                                for arg, value in six.iteritems(arg_dict)))
-
-    elif "platform" in metafunc.funcargnames:
-        for platform, plat_devs in test_plat_and_dev:
-            metafunc.addcall(
-                    funcargs=dict(platform=platform),
-                    id=str(platform))
+    arg_names = []
+
+    for arg in ("platform", "device", "ctx_factory", "ctx_getter"):
+        if arg not in metafunc.funcargnames:
+            continue
+
+        if arg == "ctx_getter":
+            from warnings import warn
+            warn("The 'ctx_getter' arg is deprecated in "
+                    "favor of 'ctx_factory'.",
+                    DeprecationWarning)
+
+        arg_names.append(arg)
+
+    arg_values = []
+
+    for platform, plat_devs in test_plat_and_dev:
+        if arg_names == ["platform"]:
+            arg_values.append((platform,))
+            continue
+
+        arg_dict = {"platform": platform}
+
+        for device in plat_devs:
+            arg_dict["device"] = device
+            arg_dict["ctx_factory"] = ContextFactory(device)
+            arg_dict["ctx_getter"] = ContextFactory(device)
+
+            arg_values.append(tuple(arg_dict[name] for name in arg_names))
+
+    if arg_names:
+        metafunc.parametrize(arg_names, arg_values, ids=str)
 
 
 # {{{ C argument lists