diff --git a/pyopencl/__init__.py b/pyopencl/__init__.py index 7cc3f3af7d1c2a7bff96b04eb8d1e7eca3bf5aa5..c3c9cf3f395eb567eb13fdc793c100820434332a 100644 --- a/pyopencl/__init__.py +++ b/pyopencl/__init__.py @@ -865,13 +865,14 @@ def _add_functionality(): KernelWithCustomEnqueue.set_args = set_args def kernel_get_work_group_info(self, param, device): + cache_key = (param, device.int_ptr) try: - return self._wg_info_cache[param, device] + return self._wg_info_cache[cache_key] except KeyError: pass result = kernel_old_get_work_group_info(self, param, device) - self._wg_info_cache[param, device] = result + self._wg_info_cache[cache_key] = result return result def kernel_set_args(self, *args, **kwargs):