diff --git a/pyopencl/__init__.py b/pyopencl/__init__.py
index 7cc3f3af7d1c2a7bff96b04eb8d1e7eca3bf5aa5..c3c9cf3f395eb567eb13fdc793c100820434332a 100644
--- a/pyopencl/__init__.py
+++ b/pyopencl/__init__.py
@@ -865,13 +865,14 @@ def _add_functionality():
         KernelWithCustomEnqueue.set_args = set_args
 
     def kernel_get_work_group_info(self, param, device):
+        cache_key = (param, device.int_ptr)
         try:
-            return self._wg_info_cache[param, device]
+            return self._wg_info_cache[cache_key]
         except KeyError:
             pass
 
         result = kernel_old_get_work_group_info(self, param, device)
-        self._wg_info_cache[param, device] = result
+        self._wg_info_cache[cache_key] = result
         return result
 
     def kernel_set_args(self, *args, **kwargs):