From e29e953451a096844ac0d5a45e8362f17ebfb64e Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Thu, 20 Feb 2014 13:24:16 -0600
Subject: [PATCH] Implement a better fix for dtype pickling

---
 loopy/codegen/__init__.py | 19 ++++++++------
 loopy/kernel/array.py     | 36 +++++++++++++++++++++------
 loopy/kernel/data.py      | 35 +++++++++++++-------------
 loopy/tools.py            | 52 ++++++++++++++++++++++++++++++---------
 test/test_loopy.py        |  4 +++
 5 files changed, 103 insertions(+), 43 deletions(-)

diff --git a/loopy/codegen/__init__.py b/loopy/codegen/__init__.py
index fb4fcc95b..cbc64add7 100644
--- a/loopy/codegen/__init__.py
+++ b/loopy/codegen/__init__.py
@@ -301,9 +301,12 @@ class ImplementedDataInfo(Record):
             unvec_shape=None, unvec_strides=None,
             offset_for_name=None, stride_for_name_and_axis=None,
             allows_offset=None):
+
+        from loopy.tools import PicklableDtype
+
         Record.__init__(self,
                 name=name,
-                dtype=np.dtype(dtype),
+                picklable_dtype=PicklableDtype(dtype),
                 cgen_declarator=cgen_declarator,
                 arg_class=arg_class,
                 base_name=base_name,
@@ -315,13 +318,13 @@ class ImplementedDataInfo(Record):
                 stride_for_name_and_axis=stride_for_name_and_axis,
                 allows_offset=allows_offset)
 
-    def __setstate__(self, state):
-        Record.__setstate__(self, state)
-
-        import loopy as lp
-        if self.dtype is not None and self.dtype is not lp.auto:
-            from loopy.tools import fix_dtype_after_unpickling
-            self.dtype = fix_dtype_after_unpickling(self.dtype)
+    @property
+    def dtype(self):
+        """Return the stored dtype, unwrapping a PicklableDtype container."""
+        from loopy.tools import PicklableDtype
+        stored = self.picklable_dtype
+        is_wrapped = isinstance(stored, PicklableDtype)
+        return stored.dtype if is_wrapped else stored
 
 # }}}
 
diff --git a/loopy/kernel/array.py b/loopy/kernel/array.py
index e42f586b3..45360a926 100644
--- a/loopy/kernel/array.py
+++ b/loopy/kernel/array.py
@@ -464,13 +464,18 @@ class ArrayBase(Record):
         import loopy as lp
 
         if dtype is not None and dtype is not lp.auto:
-            dtype = np.dtype(dtype)
+            from loopy.tools import PicklableDtype
+            picklable_dtype = PicklableDtype(dtype)
+
+            if picklable_dtype.dtype == object:
+                raise TypeError("loopy does not directly support object arrays "
+                        "(object dtype encountered on array '%s') "
+                        "-- you may want to tag the relevant array axis as 'sep'"
+                        % name)
+        else:
+            picklable_dtype = dtype
 
-        if dtype == object:
-            raise TypeError("loopy does not directly support object arrays "
-                    "(object dtype encountered on array '%s') "
-                    "-- you may want to tag the relevant array axis as 'sep'"
-                    % name)
+        del dtype
 
         strides_known = strides is not None and strides is not lp.auto
         shape_known = shape is not None and shape is not lp.auto
@@ -575,13 +580,30 @@ class ArrayBase(Record):
 
         Record.__init__(self,
                 name=name,
-                dtype=dtype,
+                picklable_dtype=picklable_dtype,
                 shape=shape,
                 dim_tags=dim_tags,
                 offset=offset,
                 order=order,
                 **kwargs)
 
+    @property
+    def dtype(self):
+        from loopy.tools import PicklableDtype
+        value = self.picklable_dtype
+        if not isinstance(value, PicklableDtype):
+            return value  # None or lp.auto are stored without a wrapper
+        return value.dtype
+
+    def get_copy_kwargs(self, **kwargs):
+        """Return copy keyword arguments using 'dtype', not 'picklable_dtype'."""
+        copy_kwargs = Record.get_copy_kwargs(self, **kwargs)
+
+        if "dtype" not in copy_kwargs:
+            copy_kwargs["dtype"] = self.dtype
+        del copy_kwargs["picklable_dtype"]
+        return copy_kwargs
+
     def stringify(self, include_typename):
         import loopy as lp
 
diff --git a/loopy/kernel/data.py b/loopy/kernel/data.py
index ad173d96a..6a333966a 100644
--- a/loopy/kernel/data.py
+++ b/loopy/kernel/data.py
@@ -177,13 +177,21 @@ def parse_tag(tag):
 
 
 class KernelArgument(Record):
-    def __setstate__(self, state):
-        Record.__setstate__(self, state)
+    def __init__(self, **kwargs):
+        dtype = kwargs.pop("dtype")  # None must not become np.dtype(None)
 
-        import loopy as lp
-        if self.dtype is not None and self.dtype is not lp.auto:
-            from loopy.tools import fix_dtype_after_unpickling
-            self.dtype = fix_dtype_after_unpickling(self.dtype)
+        from loopy.tools import PicklableDtype
+        wrapped = None if dtype is None else PicklableDtype(dtype)
+        Record.__init__(self, **dict(kwargs, picklable_dtype=wrapped))
+
+    def get_copy_kwargs(self, **kwargs):
+        # Expose the dtype under the constructor's 'dtype' keyword rather
+        # than under the internal 'picklable_dtype' field.
+        new_kwargs = Record.get_copy_kwargs(self, **kwargs)
+        if "dtype" not in new_kwargs:
+            new_kwargs["dtype"] = self.dtype
+        del new_kwargs["picklable_dtype"]
+        return new_kwargs
 
 
 class GlobalArg(ArrayBase, KernelArgument):
@@ -245,7 +253,7 @@ class ValueArg(KernelArgument):
         if dtype is not None:
             dtype = np.dtype(dtype)
 
-        Record.__init__(self, name=name, dtype=dtype,
+        KernelArgument.__init__(self, name=name, dtype=dtype,
                 approximately=approximately)
 
     def __str__(self):
@@ -295,7 +303,7 @@ class TemporaryVariable(ArrayBase):
             "is_local"
             ]
 
-    def __init__(self, name, dtype, shape=(), is_local=auto,
+    def __init__(self, name, dtype=None, shape=(), is_local=auto,
             dim_tags=None, offset=0, strides=None, order=None,
             base_indices=None, storage_shape=None):
         """
@@ -311,7 +319,8 @@ class TemporaryVariable(ArrayBase):
         if base_indices is None:
             base_indices = (0,) * len(shape)
 
-        ArrayBase.__init__(self, name=name, dtype=dtype, shape=shape,
+        ArrayBase.__init__(self, name=name,
+                dtype=dtype, shape=shape,
                 dim_tags=dim_tags, order="C",
                 base_indices=base_indices, is_local=is_local,
                 storage_shape=storage_shape)
@@ -343,14 +352,6 @@ class TemporaryVariable(ArrayBase):
     def __str__(self):
         return self.stringify(include_typename=False)
 
-    def __setstate__(self, state):
-        ArrayBase.__setstate__(self, state)
-
-        import loopy as lp
-        if self.dtype is not None and self.dtype is not lp.auto:
-            from loopy.tools import fix_dtype_after_unpickling
-            self.dtype = fix_dtype_after_unpickling(self.dtype)
-
 # }}}
 
 
diff --git a/loopy/tools.py b/loopy/tools.py
index f894379d2..56cdd3ed8 100644
--- a/loopy/tools.py
+++ b/loopy/tools.py
@@ -91,16 +91,46 @@ class LoopyKeyBuilder(KeyBuilderBase):
 # }}}
 
 
-def fix_dtype_after_unpickling(dtype):
-    # Work around https://github.com/numpy/numpy/issues/4317
-    from pyopencl.compyte.dtypes import DTYPE_TO_NAME
-    for other_dtype in DTYPE_TO_NAME:
-        # Incredibly, DTYPE_TO_NAME contains strings...
-        if isinstance(other_dtype, np.dtype) and dtype == other_dtype:
-            return other_dtype
-
-    raise RuntimeError(
-            "don't know what to do with (likely broken) unpickled dtype '%s'"
-            % dtype)
+class PicklableDtype(object):
+    """This object works around several issues with pickling :class:`numpy.dtype`
+    objects. It does so by serving as a picklable wrapper around the original
+    dtype.
+
+    The issues are the following:
+
+    - :class:`numpy.dtype` objects for custom types in :mod:`loopy` are usually
+      registered in the :mod:`pyopencl` dtype registry. This registration may
+      have been lost after unpickling. This container restores it implicitly,
+      as part of unpickling.
+
+    - There is a `numpy bug <https://github.com/numpy/numpy/issues/4317>`_
+      that prevents unpickled dtypes from hashing properly. This is solved
+      by retrieving the 'canonical' type from the dtype registry.
+    """
+
+    def __init__(self, dtype):
+        self.dtype = np.dtype(dtype)
+
+    def __hash__(self):
+        return hash(self.dtype)
+
+    def __eq__(self, other):
+        return (
+                type(self) == type(other)
+                and self.dtype == other.dtype)
+
+    def __ne__(self, other):
+        return not self.__eq__(other)  # bound call: do not pass self again
+
+    def __getstate__(self):
+        from pyopencl.compyte.dtypes import DTYPE_TO_NAME
+        c_name = DTYPE_TO_NAME[self.dtype]
+
+        return (c_name, self.dtype)
+
+    def __setstate__(self, state):
+        name, dtype = state
+        from pyopencl.tools import get_or_register_dtype
+        self.dtype = get_or_register_dtype([name], dtype)
 
 # vim: foldmethod=marker
diff --git a/test/test_loopy.py b/test/test_loopy.py
index a8b7ff74d..fe17ba3e7 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -365,6 +365,8 @@ def test_stencil_with_overfetch(ctx_factory):
 
 
 def test_eq_constraint(ctx_factory):
+    logging.basicConfig(level=logging.INFO)
+
     ctx = ctx_factory()
 
     knl = lp.make_kernel(
@@ -388,6 +390,8 @@ def test_eq_constraint(ctx_factory):
 
 
 def test_argmax(ctx_factory):
+    logging.basicConfig(level=logging.INFO)
+
     dtype = np.dtype(np.float32)
     ctx = ctx_factory()
     queue = cl.CommandQueue(ctx)
-- 
GitLab