diff --git a/loopy/kernel/__init__.py b/loopy/kernel/__init__.py index 81a64484c096e338f6b1aae400f75f0d78332b51..e66838ef12d1e052b0712094c84ea13f465b3fb3 100644 --- a/loopy/kernel/__init__.py +++ b/loopy/kernel/__init__.py @@ -135,7 +135,6 @@ class LoopKernel(RecordWithoutPickling): on expressions the user specifies later. .. attribute:: cache_manager - .. attribute:: isl_context .. attribute:: options An instance of :class:`loopy.Options` @@ -170,7 +169,6 @@ class LoopKernel(RecordWithoutPickling): applied_iname_rewrites=[], cache_manager=None, index_dtype=np.int32, - isl_context=None, options=None, state=kernel_state.INITIAL, @@ -283,7 +281,6 @@ class LoopKernel(RecordWithoutPickling): function_manglers=function_manglers, symbol_manglers=symbol_manglers, index_dtype=index_dtype, - isl_context=isl_context, options=options, state=state) @@ -472,6 +469,13 @@ class LoopKernel(RecordWithoutPickling): def get_home_domain_index(self, iname): return self._get_home_domain_map()[iname] + @memoize_method + def isl_context(self): + for dom in self.domains: + return dom.get_ctx() + + assert False + @memoize_method def combine_domains(self, domains): """ diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py index e3f7ab14ee4af3f018099e2ad58cc525c2fab4ba..2851051038c988f926a16b56072f12dd144a7212 100644 --- a/loopy/kernel/creation.py +++ b/loopy/kernel/creation.py @@ -1086,7 +1086,6 @@ def make_kernel(domains, instructions, kernel_data=["..."], **kwargs): isl_context = domain.get_ctx() if isl_context is None: isl_context = isl.Context() - kwargs["isl_context"] = isl_context # }}} diff --git a/test/test_loopy.py b/test/test_loopy.py index 07a7379008e7125853130fceadccf82f5bbad8bc..a3f07f299bff6c9b3121e963dd1a612bd6e89725 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -1500,6 +1500,9 @@ def test_rob_stroud_bernstein_full(ctx_factory): slabs=(0, 1)) knl = lp.tag_inames(knl, dict(i2="l.1", alpha1="unr", alpha2="unr")) + from pickle import dumps, loads + knl = loads(dumps(knl)) + print lp.CompiledKernel(ctx, knl).get_highlighted_code( dict( qpts=np.float32,