diff --git a/loopy/target/numba.py b/loopy/target/numba.py index 6946063ee04f52a4890344b4cbff9446bacb6923..81614371079e1073049530d9fce217c790157d92 100644 --- a/loopy/target/numba.py +++ b/loopy/target/numba.py @@ -93,6 +93,7 @@ class NumbaTarget(TargetBase): from warnings import warn warn("The Numba targets are not yet feature-complete", LoopyWarning, stacklevel=2) + self._jit_func_cache = {} def split_kernel_at_global_barriers(self): return False @@ -103,6 +104,18 @@ class NumbaTarget(TargetBase): def get_device_ast_builder(self): return NumbaJITASTBuilder(self) + def get_kernel_executor_cache_key(self, *args, **kwargs): + return None + + def get_kernel_executor(self, knl, *args, **kwargs): + if knl not in self._jit_func_cache: + from loopy import generate_code + code, _ = generate_code(knl) + ns = {} + exec(code, ns) + self._jit_func_cache[knl] = ns[knl.name] + return self._jit_func_cache[knl] + # {{{ types @memoize_method diff --git a/test/test_target.py b/test/test_target.py index b656383e7bbe008892f45159faadd2d195d67a3b..046e6df5724afb4ee0d6d2ebb15bef575910bde4 100644 --- a/test/test_target.py +++ b/test/test_target.py @@ -224,6 +224,20 @@ def test_numba_cuda_target(): print(lp.generate_code_v2(knl).all_code()) +def test_numba_call(): + try: + import numba # noqa + except ImportError: + pytest.skip("Numba not available.") + target = lp.NumbaTarget() + knl = lp.make_kernel("{ [i]: 0<=i 1: exec(sys.argv[1])