From 709c5cdb91eca0266cd9635d2b190234ce9f999d Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Tue, 18 Feb 2014 11:59:54 -0600
Subject: [PATCH] Add caching to preprocessing, scheduling, and code generation

---
 loopy/codegen/__init__.py |  73 ++++++++++++++++++++--
 loopy/kernel/__init__.py  |  83 ++++++++++++++++++++++++-
 loopy/kernel/array.py     |  27 +++++++-
 loopy/kernel/data.py      | 125 +++++++++++++++++++++++++++++++++++---
 loopy/options.py          |   7 +++
 loopy/preprocess.py       |  27 +++++++-
 loopy/schedule.py         |  66 ++++++++++++++------
 loopy/tools.py            |  85 ++++++++++++++++++++++++++
 8 files changed, 458 insertions(+), 35 deletions(-)

diff --git a/loopy/codegen/__init__.py b/loopy/codegen/__init__.py
index c33f336cb..1c9eb386a 100644
--- a/loopy/codegen/__init__.py
+++ b/loopy/codegen/__init__.py
@@ -29,6 +29,10 @@ import islpy as isl
 
 import numpy as np
 
+from pytools.persistent_dict import PersistentDict
+from loopy.tools import LoopyKeyBuilder
+from loopy.version import VERSION_TEXT
+
 import logging
 logger = logging.getLogger(__name__)
 
@@ -206,13 +210,41 @@ class CodeGenerationState(object):
 
 # {{{ cgen overrides
 
-from cgen import POD as PODBase
+from cgen import Declarator
 
 
-class POD(PODBase):
-    def get_decl_pair(self):
+class POD(Declarator):
+    """A simple declarator: The type is given as a :class:`numpy.dtype`
+    and the *name* is given as a string.
+    """
+
+    def __init__(self, dtype, name):
+        dtype = np.dtype(dtype)
+
         from pyopencl.tools import dtype_to_ctype
-        return [dtype_to_ctype(self.dtype)], self.name
+        self.ctype = dtype_to_ctype(dtype)
+        self.name = name
+
+    def get_decl_pair(self):
+        return [self.ctype], self.name
+
+    def struct_maker_code(self, name):
+        return name
+
+    @property
+    def dtype(self):
+        from pyopencl.tools import NAME_TO_DTYPE
+        return NAME_TO_DTYPE[self.ctype]
+
+    def struct_format(self):
+        return self.dtype.char
+
+    def alignment_requirement(self):
+        import pyopencl._pvt_struct as _struct
+        return _struct.calcsize(self.struct_format())
+
+    def default_value(self):
+        return 0
 
 # }}}
 
@@ -283,9 +315,21 @@ class ImplementedDataInfo(Record):
                 stride_for_name_and_axis=stride_for_name_and_axis,
                 allows_offset=allows_offset)
 
+    def __setstate__(self, state):
+        Record.__setstate__(self, state)
+
+        import loopy as lp
+        if self.dtype is not None and self.dtype is not lp.auto:
+            from loopy.tools import fix_dtype_after_unpickling
+            self.dtype = fix_dtype_after_unpickling(self.dtype)
+
 # }}}
 
 
+code_gen_cache = PersistentDict("loopy-code-gen-cache-"+VERSION_TEXT,
+        key_builder=LoopyKeyBuilder())
+
+
 # {{{ main code generation entrypoint
 
 def generate_code(kernel, device=None):
@@ -297,6 +341,19 @@ def generate_code(kernel, device=None):
         raise LoopyError("cannot generate code for a kernel that has not been "
                 "scheduled")
 
+    if device is not None:
+        device_id = device.persistent_unique_id
+    else:
+        device_id = None
+
+    code_gen_cache_key = (kernel, device_id)
+    try:
+        result = code_gen_cache[code_gen_cache_key]
+        logger.info("%s: code generation cache hit" % kernel.name)
+        return result
+    except KeyError:
+        pass
+
     from loopy.preprocess import infer_unknown_types
     kernel = infer_unknown_types(kernel, expect_completion=True)
 
@@ -435,7 +492,13 @@ def generate_code(kernel, device=None):
 
     logger.info("%s: generate code: done" % kernel.name)
 
-    return result, impl_arg_info
+    result = result, impl_arg_info
+
+    for arg in impl_arg_info:
+        logger.debug("%s: arg '%s': %s", kernel.name, arg.name, arg.dtype)
+
+    code_gen_cache[code_gen_cache_key] = result
+    return result
 
 # }}}
 
diff --git a/loopy/kernel/__init__.py b/loopy/kernel/__init__.py
index 98e496cfc..c70fffa07 100644
--- a/loopy/kernel/__init__.py
+++ b/loopy/kernel/__init__.py
@@ -1030,7 +1030,9 @@ class LoopKernel(RecordWithoutPickling):
 
     # }}}
 
-    def __getinitargs__(self):
+    # {{{ pickling
+
+    def __getstate__(self):
         result = dict(
                 (key, getattr(self, key))
                 for key in self.__class__.fields
@@ -1040,6 +1042,85 @@ class LoopKernel(RecordWithoutPickling):
 
         return result
 
+    def __setstate__(self, state):
+        for k, v in state.iteritems():
+            setattr(self, k, v)
+
+        from loopy.kernel.tools import SetOperationCacheManager
+        self.cache_manager = SetOperationCacheManager()
+
+    # }}}
+
+    # {{{ persistent hash key generation / comparison
+
+    hash_fields = [
+            "domains",
+            "instructions",
+            "args",
+            "schedule",
+            "name",
+            "preambles",
+            "assumptions",
+            "local_sizes",
+            "temporary_variables",
+            "iname_to_tag",
+            "substitutions",
+            "iname_slab_increments",
+            "loop_priority",
+            "silenced_warnings",
+            "options",
+            "state",
+            ]
+
+    comparison_fields = hash_fields + [
+            # Contains pymbolic expressions, hence a (small) headache to hash.
+            # Likely not needed for hash uniqueness => headache avoided.
+            "applied_iname_rewrites",
+
+            # These are lists of functions. It's not clear how to
+            # hash these correctly, so let's not attempt it. We'll
+            # just assume that the rest of the hash is specific enough
+            # that we won't have to rely on differences in these to
+            # resolve hash conflicts.
+
+            "preamble_generators",
+            "function_manglers",
+            "symbol_manglers",
+            ]
+
+    def update_persistent_hash(self, key_hash, key_builder):
+        """Custom hash computation function for use with
+        :class:`pytools.persistent_dict.PersistentDict`.
+
+        Only works in conjunction with :class:`loopy.tools.LoopyKeyBuilder`.
+        """
+        for field_name in self.hash_fields:
+            key_builder.rec(key_hash, getattr(self, field_name))
+
+    def __eq__(self, other):
+        """Compare kernels field by field over :attr:`comparison_fields`."""
+        if not isinstance(other, LoopKernel):
+            return False
+        for field_name in self.comparison_fields:
+            if field_name == "domains":
+                # zip() truncates, so differing domain counts must fail first.
+                if len(self.domains) != len(other.domains) or not all(
+                        set_a.plain_is_equal(set_b) for set_a, set_b
+                        in zip(self.domains, other.domains)):
+                    return False
+            elif field_name == "assumptions":
+                if not self.assumptions.plain_is_equal(other.assumptions):
+                    return False
+            elif getattr(self, field_name) != getattr(other, field_name):
+                return False
+
+        return True
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+    # }}}
+
 # }}}
 
 # vim: foldmethod=marker
diff --git a/loopy/kernel/array.py b/loopy/kernel/array.py
index 9b47636c2..e42f586b3 100644
--- a/loopy/kernel/array.py
+++ b/loopy/kernel/array.py
@@ -39,7 +39,12 @@ from loopy.tools import is_integer
 # {{{ array dimension tags
 
 class ArrayDimImplementationTag(Record):
-    pass
+    def update_persistent_hash(self, key_hash, key_builder):
+        """Custom hash computation function for use with
+        :class:`pytools.persistent_dict.PersistentDict`.
+        """
+
+        key_builder.rec(key_hash, self.stringify(True).encode("utf8"))
 
 
 class _StrideArrayDimTagBase(ArrayDimImplementationTag):
@@ -93,10 +98,15 @@ class FixedStrideArrayDimTag(_StrideArrayDimTagBase):
 
 class ComputedStrideArrayDimTag(_StrideArrayDimTagBase):
     """
-    :arg order: "C" or "F", indicating whether this argument dimension will be added
+    .. attribute:: order
+
+        "C" or "F", indicating whether this argument dimension will be added
         as faster-moving ("C") or more-slowly-moving ("F") than the previous
         argument.
-    :arg pad_to: :attr:`ArrayBase.dtype` granularity to which to pad this dimension
+
+    .. attribute:: pad_to
+
+        :attr:`ArrayBase.dtype` granularity to which to pad this dimension
 
     This type of stride arg dim gets converted to :class:`FixedStrideArrayDimTag`
     on input to :class:`ArrayBase` subclasses.
@@ -614,6 +624,17 @@ class ArrayBase(Record):
     def __repr__(self):
         return "<%s>" % self.__str__()
 
+    def update_persistent_hash(self, key_hash, key_builder):
+        """Custom hash computation function for use with
+        :class:`pytools.persistent_dict.PersistentDict`.
+        """
+
+        key_builder.rec(key_hash, self.name)
+        key_builder.rec(key_hash, self.dtype)
+        key_builder.update_for_pymbolic_expression(key_hash, self.shape)
+        key_builder.rec(key_hash, self.dim_tags)
+        key_builder.rec(key_hash, self.offset)
+
     @property
     @memoize_method
     def numpy_strides(self):
diff --git a/loopy/kernel/data.py b/loopy/kernel/data.py
index 07ddd6bc4..ad173d96a 100644
--- a/loopy/kernel/data.py
+++ b/loopy/kernel/data.py
@@ -28,10 +28,10 @@ THE SOFTWARE.
 import numpy as np
 from pytools import Record, memoize_method
 from loopy.kernel.array import ArrayBase
-from loopy.diagnostic import LoopyError
+from loopy.diagnostic import LoopyError  # noqa
 
 
-class auto:
+class auto(object):
     """A generic placeholder object for something that should be automatically
     detected.  See, for example, the *shape* or *strides* argument of
     :class:`GlobalArg`.
@@ -46,6 +46,13 @@ class IndexTag(Record):
     def __hash__(self):
         raise RuntimeError("use .key to hash index tags")
 
+    def update_persistent_hash(self, key_hash, key_builder):
+        """Custom hash computation function for use with
+        :class:`pytools.persistent_dict.PersistentDict`.
+        """
+
+        return key_builder.rec(key_hash, self.key)
+
 
 class ParallelTag(IndexTag):
     pass
@@ -58,7 +65,7 @@ class HardwareParallelTag(ParallelTag):
 class UniqueTag(IndexTag):
     @property
     def key(self):
-        return type(self)
+        return type(self).__name__
 
 
 class AxisTag(UniqueTag):
@@ -70,7 +77,7 @@ class AxisTag(UniqueTag):
 
     @property
     def key(self):
-        return (type(self), self.axis)
+        return (type(self).__name__, self.axis)
 
     def __str__(self):
         return "%s.%d" % (
@@ -90,7 +97,9 @@ class LocalIndexTag(LocalIndexTagBase, AxisTag):
 
 
 class AutoLocalIndexTagBase(LocalIndexTagBase):
-    pass
+    @property
+    def key(self):
+        return type(self).__name__
 
 
 class AutoFitLocalIndexTag(AutoLocalIndexTagBase):
@@ -99,7 +108,9 @@ class AutoFitLocalIndexTag(AutoLocalIndexTagBase):
 
 
 class IlpBaseTag(ParallelTag):
-    pass
+    @property
+    def key(self):
+        return type(self).__name__
 
 
 class UnrolledIlpTag(IlpBaseTag):
@@ -116,11 +127,19 @@ class UnrollTag(IndexTag):
     def __str__(self):
         return "unr"
 
+    @property
+    def key(self):
+        return type(self).__name__
+
 
 class ForceSequentialTag(IndexTag):
     def __str__(self):
         return "forceseq"
 
+    @property
+    def key(self):
+        return type(self).__name__
+
 
 def parse_tag(tag):
     if tag is None:
@@ -158,7 +177,13 @@ def parse_tag(tag):
 
 
 class KernelArgument(Record):
-    pass
+    def __setstate__(self, state):
+        Record.__setstate__(self, state)
+
+        import loopy as lp
+        if self.dtype is not None and self.dtype is not lp.auto:
+            from loopy.tools import fix_dtype_after_unpickling
+            self.dtype = fix_dtype_after_unpickling(self.dtype)
 
 
 class GlobalArg(ArrayBase, KernelArgument):
@@ -237,6 +262,14 @@ class ValueArg(KernelArgument):
     def __repr__(self):
         return "<%s>" % self.__str__()
 
+    def update_persistent_hash(self, key_hash, key_builder):
+        """Custom hash computation function for use with
+        :class:`pytools.persistent_dict.PersistentDict`.
+        """
+
+        key_builder.rec(key_hash, self.name)
+        key_builder.rec(key_hash, self.dtype)
+
 # }}}
 
 
@@ -310,6 +343,14 @@ class TemporaryVariable(ArrayBase):
     def __str__(self):
         return self.stringify(include_typename=False)
 
+    def __setstate__(self, state):
+        ArrayBase.__setstate__(self, state)
+
+        import loopy as lp
+        if self.dtype is not None and self.dtype is not lp.auto:
+            from loopy.tools import fix_dtype_after_unpickling
+            self.dtype = fix_dtype_after_unpickling(self.dtype)
+
 # }}}
 
 
@@ -332,6 +373,15 @@ class SubstitutionRule(Record):
         return "%s(%s) := %s" % (
                 self.name, ", ".join(self.arguments), self.expression)
 
+    def update_persistent_hash(self, key_hash, key_builder):
+        """Custom hash computation function for use with
+        :class:`pytools.persistent_dict.PersistentDict`.
+        """
+
+        key_builder.rec(key_hash, self.name)
+        key_builder.rec(key_hash, self.arguments)
+        key_builder.update_for_pymbolic_expression(key_hash, self.expression)
+
 # }}}
 
 
@@ -491,6 +541,34 @@ class InstructionBase(Record):
 
         return result
 
+    # {{{ comparison, hashing
+
+    def __eq__(self, other):
+        if not type(self) == type(other):
+            return False
+
+        for field_name in self.fields:
+            if getattr(self, field_name) != getattr(other, field_name):
+                return False
+
+        return True
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+    def update_persistent_hash(self, key_hash, key_builder):
+        """Custom hash computation function for use with
+        :class:`pytools.persistent_dict.PersistentDict`.
+
+        Only works in conjunction with :class:`loopy.tools.LoopyKeyBuilder`.
+        """
+
+        # Order matters for hash forming--sort the field names
+        for field_name in sorted(self.fields):
+            key_builder.rec(key_hash, getattr(self, field_name))
+
+    # }}}
+
 # }}}
 
 
@@ -613,6 +691,21 @@ class ExpressionInstruction(InstructionBase):
             result += "\n" + 10*" " + "if (%s)" % " && ".join(self.predicates)
         return result
 
+    def update_persistent_hash(self, key_hash, key_builder):
+        """Custom hash computation function for use with
+        :class:`pytools.persistent_dict.PersistentDict`.
+
+        Only works in conjunction with :class:`loopy.tools.LoopyKeyBuilder`.
+        """
+
+        # Order matters for hash forming--sort the fields.
+        for field_name in sorted(self.fields):
+            if field_name in ["assignee", "expression"]:
+                key_builder.update_for_pymbolic_expression(
+                        key_hash, getattr(self, field_name))
+            else:
+                key_builder.rec(key_hash, getattr(self, field_name))
+
 # }}}
 
 
@@ -790,6 +883,24 @@ class CInstruction(InstructionBase):
         return first_line + "\n    " + "\n    ".join(
                 self.code.split("\n"))
 
+    def update_persistent_hash(self, key_hash, key_builder):
+        """Custom hash computation function for use with
+        :class:`pytools.persistent_dict.PersistentDict`.
+
+        Only works in conjunction with :class:`loopy.tools.LoopyKeyBuilder`.
+        """
+
+        # Order matters for hash forming--sort the fields.
+        for field_name in sorted(self.fields):
+            if field_name == "assignees":
+                for a in self.assignees:
+                    key_builder.update_for_pymbolic_expression(key_hash, a)
+            elif field_name == "iname_exprs":
+                for name, val in self.iname_exprs:
+                    key_builder.rec(key_hash, name)
+                    key_builder.update_for_pymbolic_expression(key_hash, val)
+            else:
+                key_builder.rec(key_hash, getattr(self, field_name))
 # }}}
 
 # }}}
diff --git a/loopy/options.py b/loopy/options.py
index 9e656ad63..042eba6ca 100644
--- a/loopy/options.py
+++ b/loopy/options.py
@@ -143,6 +143,13 @@ class Options(Record):
         for f in self.__class__.fields:
             setattr(self, f, getattr(self, f) or getattr(other, f))
 
+    def update_persistent_hash(self, key_hash, key_builder):
+        """Custom hash computation function for use with
+        :class:`pytools.persistent_dict.PersistentDict`.
+        """
+        for field_name in sorted(self.__class__.fields):
+            key_builder.rec(key_hash, getattr(self, field_name))
+
 
 KEY_VAL_RE = re.compile("^([a-zA-Z0-9]+)=(.*)$")
 
diff --git a/loopy/preprocess.py b/loopy/preprocess.py
index 600fe983e..e32b04eed 100644
--- a/loopy/preprocess.py
+++ b/loopy/preprocess.py
@@ -29,6 +29,10 @@ from loopy.diagnostic import (
         LoopyError, LoopyWarning, WriteRaceConditionWarning, warn,
         LoopyAdvisory)
 
+from pytools.persistent_dict import PersistentDict
+from loopy.tools import LoopyKeyBuilder
+from loopy.version import VERSION_TEXT
+
 import logging
 logger = logging.getLogger(__name__)
 
@@ -1062,12 +1066,30 @@ def adjust_local_temp_var_storage(kernel, device):
 # }}}
 
 
+preprocess_cache = PersistentDict("loopy-preprocess-cache-"+VERSION_TEXT,
+        key_builder=LoopyKeyBuilder())
+
+
 def preprocess_kernel(kernel, device=None):
     from loopy.kernel import kernel_state
     if kernel.state != kernel_state.INITIAL:
         raise LoopyError("cannot re-preprocess an already preprocessed "
                 "kernel")
 
+    if device is not None:
+        device_id = device.persistent_unique_id
+    else:
+        device_id = None
+
+    pp_cache_key = (kernel, device_id)
+    try:
+        result = preprocess_cache[pp_cache_key]
+        logger.info("%s: preprocess cache hit" % kernel.name)
+        return result
+    except KeyError:
+        pass
+
+    logger.info("%s: preprocess cache miss" % kernel.name)
     logger.info("%s: preprocess start" % kernel.name)
 
     from loopy.subst import expand_subst
@@ -1111,10 +1133,11 @@ def preprocess_kernel(kernel, device=None):
 
     logger.info("%s: preprocess done" % kernel.name)
 
-    return kernel.copy(
+    kernel = kernel.copy(
             state=kernel_state.PREPROCESSED)
 
+    preprocess_cache[pp_cache_key] = kernel
 
-
+    return kernel
 
 # vim: foldmethod=marker
diff --git a/loopy/schedule.py b/loopy/schedule.py
index c884a2bfd..fa81ae756 100644
--- a/loopy/schedule.py
+++ b/loopy/schedule.py
@@ -28,6 +28,10 @@ import sys
 import islpy as isl
 from loopy.diagnostic import LoopyError  # noqa
 
+from pytools.persistent_dict import PersistentDict
+from loopy.tools import LoopyKeyBuilder
+from loopy.version import VERSION_TEXT
+
 import logging
 logger = logging.getLogger(__name__)
 
@@ -37,17 +41,24 @@ logger = logging.getLogger(__name__)
 class ScheduleItem(Record):
     __slots__ = []
 
+    def update_persistent_hash(self, key_hash, key_builder):
+        """Custom hash computation function for use with
+        :class:`pytools.persistent_dict.PersistentDict`.
+        """
+        for field_name in self.hash_fields:
+            key_builder.rec(key_hash, getattr(self, field_name))
+
 
 class EnterLoop(ScheduleItem):
-    __slots__ = ["iname"]
+    hash_fields = __slots__ = ["iname"]
 
 
 class LeaveLoop(ScheduleItem):
-    __slots__ = ["iname"]
+    hash_fields = __slots__ = ["iname"]
 
 
 class RunInstruction(ScheduleItem):
-    __slots__ = ["insn_id"]
+    hash_fields = __slots__ = ["insn_id"]
 
 
 class Barrier(ScheduleItem):
@@ -60,7 +71,7 @@ class Barrier(ScheduleItem):
 
         ``"local"`` or ``"global"``
     """
-    __slots__ = ["comment", "kind"]
+    hash_fields = __slots__ = ["comment", "kind"]
 
 # }}}
 
@@ -1095,23 +1106,44 @@ def generate_loop_schedules(kernel, debug_args={}):
 # }}}
 
 
+schedule_cache = PersistentDict("loopy-schedule-cache-v2-"+VERSION_TEXT,
+        key_builder=LoopyKeyBuilder())
+
+
 def get_one_scheduled_kernel(kernel):
-    kernel_count = 0
 
-    for scheduled_kernel in generate_loop_schedules(kernel):
-        kernel_count += 1
+    sched_cache_key = kernel
+    try:
+        result, ambiguous = schedule_cache[sched_cache_key]
 
-        if kernel_count == 1:
-            # use the first schedule
-            result = scheduled_kernel
+        logger.info("%s: schedule cache hit" % kernel.name)
+        from_cache = True
+    except KeyError:
+        from_cache = False
+        ambiguous = False
 
-        if kernel_count == 2:
-            from warnings import warn
-            from loopy.diagnostic import LoopyWarning
-            warn("kernel scheduling was ambiguous--more than one "
-                    "schedule found, ignoring", LoopyWarning,
-                    stacklevel=2)
-            break
+        kernel_count = 0
+
+        for scheduled_kernel in generate_loop_schedules(kernel):
+            kernel_count += 1
+
+            if kernel_count == 1:
+                # use the first schedule
+                result = scheduled_kernel
+
+            if kernel_count == 2:
+                ambiguous = True
+                break
+
+    if ambiguous:
+        from warnings import warn
+        from loopy.diagnostic import LoopyWarning
+        warn("kernel scheduling was ambiguous--more than one "
+                "schedule found, ignoring", LoopyWarning,
+                stacklevel=2)
+
+    if not from_cache:
+        schedule_cache[sched_cache_key] = result, ambiguous
 
     return result
 
diff --git a/loopy/tools.py b/loopy/tools.py
index 452a2ca9b..62b893edf 100644
--- a/loopy/tools.py
+++ b/loopy/tools.py
@@ -23,7 +23,92 @@ THE SOFTWARE.
 """
 
 import numpy as np
+from pytools.persistent_dict import KeyBuilder as KeyBuilderBase
+from loopy.symbolic import WalkMapper
 
 
 def is_integer(obj):
     return isinstance(obj, (int, long, np.integer))
+
+
+# {{{ custom KeyBuilder subclass
+
+class PersistentHashWalkMapper(WalkMapper):
+    """A :class:`loopy.symbolic.WalkMapper` that feeds an expression tree
+    into a hash for :class:`pytools.persistent_dict.PersistentDict` keys.
+
+    See also :meth:`LoopyKeyBuilder.update_for_pymbolic_expression`.
+    """
+
+    def __init__(self, key_hash):
+        self.key_hash = key_hash
+
+    def visit(self, expr):
+        self.key_hash.update(type(expr).__name__.encode("utf8"))
+        return True  # a falsy result makes WalkMapper skip the children
+
+    def map_variable(self, expr):
+        self.key_hash.update(expr.name.encode("utf8"))
+
+    def map_constant(self, expr):
+        self.key_hash.update(repr(expr).encode("utf8"))
+
+
+class LoopyKeyBuilder(KeyBuilderBase):
+    """A custom :class:`pytools.persistent_dict.KeyBuilder` subclass
+    for objects within :mod:`loopy`.
+    """
+
+    # Lists, sets and dicts aren't immutable. But loopy kernels are, so we're
+    # simply ignoring that fact here.
+    update_for_list = KeyBuilderBase.update_for_tuple
+    update_for_set = KeyBuilderBase.update_for_frozenset
+
+    def update_for_dict(self, key_hash, key):
+        # Order matters for the hash--insert in sorted order.
+        for dict_key in sorted(key.iterkeys()):
+            self.rec(key_hash, (dict_key, key[dict_key]))
+
+    def update_for_BasicSet(self, key_hash, key):
+        from islpy import Printer
+        prn = Printer.to_str(key.get_ctx())
+        getattr(prn, "print_"+key._base_name)(key)
+        key_hash.update(prn.get_str().encode("utf8"))
+
+    def update_for_type(self, key_hash, key):
+        """Hash the type object *key* by dispatching on its name to an
+        ``update_for_type_<name>`` method; raise :exc:`TypeError` if none.
+        """
+        method = getattr(self, "update_for_type_"+key.__name__, None)
+        if method is None:
+            # Report the offending type itself, not its metaclass.
+            raise TypeError("unsupported type for persistent hash keying: %s"
+                    % (key,))
+
+        method(key_hash, key)
+
+    def update_for_type_auto(self, key_hash, key):
+        key_hash.update("auto".encode("utf8"))
+
+    def update_for_pymbolic_expression(self, key_hash, key):
+        if key is None:
+            self.update_for_NoneType(key_hash, key)
+        else:
+            PersistentHashWalkMapper(key_hash)(key)
+
+# }}}
+
+
+def fix_dtype_after_unpickling(dtype):
+    # Work around https://github.com/numpy/numpy/issues/4317
+    from pyopencl.compyte.dtypes import DTYPE_TO_NAME
+    for other_dtype in DTYPE_TO_NAME:
+        # Incredibly, DTYPE_TO_NAME contains strings...
+        if isinstance(other_dtype, np.dtype) and dtype == other_dtype:
+            return other_dtype
+
+    raise RuntimeError(
+            "don't know what to do with (likely broken) unpickled dtype '%s'"
+            % dtype)
+
+# vim: foldmethod=marker
-- 
GitLab