From e29e953451a096844ac0d5a45e8362f17ebfb64e Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Thu, 20 Feb 2014 13:24:16 -0600 Subject: [PATCH] Implement a better fix for dtype pickling --- loopy/codegen/__init__.py | 19 ++++++++------ loopy/kernel/array.py | 36 +++++++++++++++++++++------ loopy/kernel/data.py | 35 +++++++++++++------------- loopy/tools.py | 52 ++++++++++++++++++++++++++++++--------- test/test_loopy.py | 4 +++ 5 files changed, 103 insertions(+), 43 deletions(-) diff --git a/loopy/codegen/__init__.py b/loopy/codegen/__init__.py index fb4fcc95b..cbc64add7 100644 --- a/loopy/codegen/__init__.py +++ b/loopy/codegen/__init__.py @@ -301,9 +301,12 @@ class ImplementedDataInfo(Record): unvec_shape=None, unvec_strides=None, offset_for_name=None, stride_for_name_and_axis=None, allows_offset=None): + + from loopy.tools import PicklableDtype + Record.__init__(self, name=name, - dtype=np.dtype(dtype), + picklable_dtype=PicklableDtype(dtype), cgen_declarator=cgen_declarator, arg_class=arg_class, base_name=base_name, @@ -315,13 +318,13 @@ class ImplementedDataInfo(Record): stride_for_name_and_axis=stride_for_name_and_axis, allows_offset=allows_offset) - def __setstate__(self, state): - Record.__setstate__(self, state) - - import loopy as lp - if self.dtype is not None and self.dtype is not lp.auto: - from loopy.tools import fix_dtype_after_unpickling - self.dtype = fix_dtype_after_unpickling(self.dtype) + @property + def dtype(self): + from loopy.tools import PicklableDtype + if isinstance(self.picklable_dtype, PicklableDtype): + return self.picklable_dtype.dtype + else: + return self.picklable_dtype # }}} diff --git a/loopy/kernel/array.py b/loopy/kernel/array.py index e42f586b3..45360a926 100644 --- a/loopy/kernel/array.py +++ b/loopy/kernel/array.py @@ -464,13 +464,18 @@ class ArrayBase(Record): import loopy as lp if dtype is not None and dtype is not lp.auto: - dtype = np.dtype(dtype) + from loopy.tools import PicklableDtype + picklable_dtype = PicklableDtype(dtype) + + if picklable_dtype.dtype == object: + raise TypeError("loopy does not directly support object arrays " + "(object dtype encountered on array '%s') " + "-- you may want to tag the relevant array axis as 'sep'" + % name) + else: + picklable_dtype = dtype - if dtype == object: - raise TypeError("loopy does not directly support object arrays " - "(object dtype encountered on array '%s') " - "-- you may want to tag the relevant array axis as 'sep'" - % name) + del dtype strides_known = strides is not None and strides is not lp.auto shape_known = shape is not None and shape is not lp.auto @@ -575,13 +580,30 @@ class ArrayBase(Record): Record.__init__(self, name=name, - dtype=dtype, + picklable_dtype=picklable_dtype, shape=shape, dim_tags=dim_tags, offset=offset, order=order, **kwargs) + @property + def dtype(self): + from loopy.tools import PicklableDtype + if isinstance(self.picklable_dtype, PicklableDtype): + return self.picklable_dtype.dtype + else: + return self.picklable_dtype + + def get_copy_kwargs(self, **kwargs): + result = Record.get_copy_kwargs(self, **kwargs) + if "dtype" not in result: + result["dtype"] = self.dtype + + del result["picklable_dtype"] + + return result + def stringify(self, include_typename): import loopy as lp diff --git a/loopy/kernel/data.py b/loopy/kernel/data.py index ad173d96a..6a333966a 100644 --- a/loopy/kernel/data.py +++ b/loopy/kernel/data.py @@ -177,13 +177,21 @@ def parse_tag(tag): class KernelArgument(Record): - def __setstate__(self, state): - Record.__setstate__(self, state) + def __init__(self, **kwargs): + dtype = kwargs.pop("dtype") - import loopy as lp - if self.dtype is not None and self.dtype is not lp.auto: - from loopy.tools import fix_dtype_after_unpickling - self.dtype = fix_dtype_after_unpickling(self.dtype) + from loopy.tools import PicklableDtype + kwargs["picklable_dtype"] = PicklableDtype(dtype) + Record.__init__(self, **kwargs) + + def get_copy_kwargs(self, **kwargs): + result = Record.get_copy_kwargs(self, **kwargs) + if "dtype" not in result: + result["dtype"] = self.dtype + + del result["picklable_dtype"] + + return result class GlobalArg(ArrayBase, KernelArgument): @@ -245,7 +253,7 @@ class ValueArg(KernelArgument): if dtype is not None: dtype = np.dtype(dtype) - Record.__init__(self, name=name, dtype=dtype, + KernelArgument.__init__(self, name=name, dtype=dtype, approximately=approximately) def __str__(self): @@ -295,7 +303,7 @@ class TemporaryVariable(ArrayBase): "is_local" ] - def __init__(self, name, dtype, shape=(), is_local=auto, + def __init__(self, name, dtype=None, shape=(), is_local=auto, dim_tags=None, offset=0, strides=None, order=None, base_indices=None, storage_shape=None): """ @@ -311,7 +319,8 @@ class TemporaryVariable(ArrayBase): if base_indices is None: base_indices = (0,) * len(shape) - ArrayBase.__init__(self, name=name, dtype=dtype, shape=shape, + ArrayBase.__init__(self, name=name, + dtype=dtype, shape=shape, dim_tags=dim_tags, order="C", base_indices=base_indices, is_local=is_local, storage_shape=storage_shape) @@ -343,14 +352,6 @@ class TemporaryVariable(ArrayBase): def __str__(self): return self.stringify(include_typename=False) - def __setstate__(self, state): - ArrayBase.__setstate__(self, state) - - import loopy as lp - if self.dtype is not None and self.dtype is not lp.auto: - from loopy.tools import fix_dtype_after_unpickling - self.dtype = fix_dtype_after_unpickling(self.dtype) - # }}} diff --git a/loopy/tools.py b/loopy/tools.py index f894379d2..56cdd3ed8 100644 --- a/loopy/tools.py +++ b/loopy/tools.py @@ -91,16 +91,46 @@ class LoopyKeyBuilder(KeyBuilderBase): # }}} -def fix_dtype_after_unpickling(dtype): - # Work around https://github.com/numpy/numpy/issues/4317 - from pyopencl.compyte.dtypes import DTYPE_TO_NAME - for other_dtype in DTYPE_TO_NAME: - # Incredibly, DTYPE_TO_NAME contains strings... - if isinstance(other_dtype, np.dtype) and dtype == other_dtype: - return other_dtype - - raise RuntimeError( - "don't know what to do with (likely broken) unpickled dtype '%s'" - % dtype) +class PicklableDtype(object): + """This object works around several issues with pickling :class:`numpy.dtype` + objects. It does so by serving as a picklable wrapper around the original + dtype. + + The issues are the following + + - :class:`numpy.dtype` objects for custom types in :mod:`loopy` are usually + registered in the :mod:`pyopencl` dtype registry. This registration may + have been lost after unpickling. This container restores it implicitly, + as part of unpickling. + + - There is a`numpy bug <https://github.com/numpy/numpy/issues/4317>`_ + that prevents unpickled dtypes from hashing properly. This is solved + by retrieving the 'canonical' type from the dtype registry. + """ + + def __init__(self, dtype): + self.dtype = np.dtype(dtype) + + def __hash__(self): + return hash(self.dtype) + + def __eq__(self, other): + return ( + type(self) == type(other) + and self.dtype == other.dtype) + + def __ne__(self, other): + return not self.__eq__(self, other) + + def __getstate__(self): + from pyopencl.compyte.dtypes import DTYPE_TO_NAME + c_name = DTYPE_TO_NAME[self.dtype] + + return (c_name, self.dtype) + + def __setstate__(self, state): + name, dtype = state + from pyopencl.tools import get_or_register_dtype + self.dtype = get_or_register_dtype([name], dtype) # vim: foldmethod=marker diff --git a/test/test_loopy.py b/test/test_loopy.py index a8b7ff74d..fe17ba3e7 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -365,6 +365,8 @@ def test_stencil_with_overfetch(ctx_factory): def test_eq_constraint(ctx_factory): + logging.basicConfig(level=logging.INFO) + ctx = ctx_factory() knl = lp.make_kernel( @@ -388,6 +390,8 @@ def test_eq_constraint(ctx_factory): def test_argmax(ctx_factory): + logging.basicConfig(level=logging.INFO) + dtype = np.dtype(np.float32) ctx = ctx_factory() queue = cl.CommandQueue(ctx) -- GitLab