diff --git a/loopy/__init__.py b/loopy/__init__.py index 54c3523d5107d5a8516e1cf7cf7a6bbceef1b991..92b7fca77a85d5bf4531100040ab9b8e772d8a8c 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -331,18 +331,34 @@ def register_preamble_generators(kernel, preamble_generators): :returns: *kernel* with *manglers* registered """ + from loopy.tools import unpickles_equally + new_pgens = kernel.preamble_generators[:] for pgen in preamble_generators: if pgen not in new_pgens: + if not unpickles_equally(pgen): + raise LoopyError("preamble generator '%s' does not " + "compare equally after being upickled " + "and would thus disrupt loopy's caches" + % pgen) + new_pgens.insert(0, pgen) return kernel.copy(preamble_generators=new_pgens) def register_symbol_manglers(kernel, manglers): + from loopy.tools import unpickles_equally + new_manglers = kernel.symbol_manglers[:] for m in manglers: if m not in new_manglers: + if not unpickles_equally(m): + raise LoopyError("mangler '%s' does not " + "compare equally after being upickled " + "and would disrupt loopy's caches" + % m) + new_manglers.insert(0, m) return kernel.copy(symbol_manglers=new_manglers) @@ -354,9 +370,17 @@ def register_function_manglers(kernel, manglers): returning a :class:`loopy.CallMangleInfo`. :returns: *kernel* with *manglers* registered """ + from loopy.tools import unpickles_equally + new_manglers = kernel.function_manglers[:] for m in manglers: if m not in new_manglers: + if not unpickles_equally(m): + raise LoopyError("mangler '%s' does not " + "compare equally after being upickled " + "and would disrupt loopy's caches" + % m) + new_manglers.insert(0, m) return kernel.copy(function_manglers=new_manglers) diff --git a/loopy/tools.py b/loopy/tools.py index 15d2a859a9cbe7c7a4e0711e705c6ccce6fff61b..8c5d36390d75123ca433a30947ac2631d734f779 100644 --- a/loopy/tools.py +++ b/loopy/tools.py @@ -1,5 +1,4 @@ from __future__ import division, absolute_import -import six __copyright__ = "Copyright (C) 2012 Andreas Kloeckner" @@ -23,6 +22,8 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +import six + import collections import numpy as np from pytools import memoize_method @@ -581,6 +582,11 @@ class LazilyUnpicklingListWithEqAndPersistentHashing(LazilyUnpicklingList): # }}} +def unpickles_equally(obj): + from six.moves.cPickle import loads, dumps + return loads(dumps(obj)) == obj + + def is_interned(s): return s is None or intern(s) is s diff --git a/test/test_loopy.py b/test/test_loopy.py index 07179ae14b3a8e92c31a7718af3e0f6d8f16f22a..3ceca5a75095053f65b652f8c0f0ec66c5603a86 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -2719,9 +2719,13 @@ def test_preamble_with_separate_temporaries(ctx_factory): read_only=True), lp.GlobalArg('data', shape=(data.size,), dtype=np.float64)], ) + # fixt params, and add manglers / preamble - from testlib import SeparateTemporariesPreambleTestHelper - preamble_with_sep_helper = SeparateTemporariesPreambleTestHelper( + from testlib import ( + SeparateTemporariesPreambleTestMangler, + SeparateTemporariesPreambleTestPreambleGenerator, + ) + func_info = dict( func_name='indirect', func_arg_dtypes=(np.int32, np.int32, np.int32), func_result_dtypes=(np.int32,), @@ -2730,9 +2734,9 @@ def test_preamble_with_separate_temporaries(ctx_factory): kernel = lp.fix_parameters(kernel, **{'n': n}) kernel = lp.register_preamble_generators( - kernel, [preamble_with_sep_helper.preamble_gen]) + kernel, [SeparateTemporariesPreambleTestPreambleGenerator(**func_info)]) kernel = lp.register_function_manglers( - kernel, [preamble_with_sep_helper.mangler]) + kernel, [SeparateTemporariesPreambleTestMangler(**func_info)]) print(lp.generate_code(kernel)[0]) # and call (functionality unimportant, more that it compiles) diff --git a/test/testlib.py b/test/testlib.py index 73de4199d31736230026eb7f2eb7939a93806369..ad290ee7c60297aadd4a6baa0814b8976403cb53 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -17,14 +17,29 @@ class GridOverride(object): # {{{ test_preamble_with_separate_temporaries -class SeparateTemporariesPreambleTestHelper: +class SeparateTemporariesPreambleTestDataHolder: def __init__(self, func_name, func_arg_dtypes, func_result_dtypes, arr): self.func_name = func_name self.func_arg_dtypes = func_arg_dtypes self.func_result_dtypes = func_result_dtypes self.arr = arr - def mangler(self, kernel, name, arg_dtypes): + def __eq__(self, other): + import numpy as np + return ( + isinstance(other, type(self)) + and self.func_name == other.func_name + and self.func_arg_dtypes == other.func_arg_dtypes + and self.func_result_dtypes == other.func_result_dtypes + and np.array_equal(self.arr, other.arr)) + + def __ne__(self, other): + return not self.__eq__(other) + + +class SeparateTemporariesPreambleTestMangler( + SeparateTemporariesPreambleTestDataHolder): + def __call__(self, kernel, name, arg_dtypes): """ A function that will return a :class:`loopy.kernel.data.CallMangleInfo` to interface with the calling :class:`loopy.LoopKernel` @@ -61,7 +76,10 @@ class SeparateTemporariesPreambleTestHelper: self.func_result_dtypes), arg_dtypes=arg_dtypes) - def preamble_gen(self, preamble_info): + +class SeparateTemporariesPreambleTestPreambleGenerator( + SeparateTemporariesPreambleTestDataHolder): + def __call__(self, preamble_info): from loopy.kernel.data import temp_var_scope as scopes # find a function matching our name