diff --git a/test/test_pytato_arraycontext.py b/test/test_pytato_arraycontext.py index 7dc76c4c9f78ce128044eecf2502e9cdd7a7386c..ae188f0bbae207261dc253ef7864992320216556 100644 --- a/test/test_pytato_arraycontext.py +++ b/test/test_pytato_arraycontext.py @@ -121,6 +121,62 @@ def test_arg_size_limit(actx_factory): assert ran_callback +@pytest.mark.parametrize("pass_allocator", ["auto_none", "auto_true", "auto_false", + "pass_buffer", "pass_svm", + "pass_buffer_pool", "pass_svm_pool"]) +def test_pytato_actx_allocator(actx_factory, pass_allocator): + base_actx = actx_factory() + alloc = None + use_memory_pool = None + + if pass_allocator == "auto_true": + use_memory_pool = True + elif pass_allocator == "auto_false": + use_memory_pool = False + elif pass_allocator == "pass_buffer": + from pyopencl.tools import ImmediateAllocator + alloc = ImmediateAllocator(base_actx.queue) + elif pass_allocator == "pass_svm": + from pyopencl.tools import SVMAllocator + alloc = SVMAllocator(base_actx.queue.context, queue=base_actx.queue) + elif pass_allocator == "pass_buffer_pool": + from pyopencl.tools import ImmediateAllocator, MemoryPool + alloc = MemoryPool(ImmediateAllocator(base_actx.queue)) + elif pass_allocator == "pass_svm_pool": + from pyopencl.tools import SVMAllocator, SVMPool + alloc = SVMPool(SVMAllocator(base_actx.queue.context, queue=base_actx.queue)) + + actx = _PytatoPyOpenCLArrayContextForTests(base_actx.queue, allocator=alloc, + use_memory_pool=use_memory_pool) + + def twice(x): + return 2 * x + + f = actx.compile(twice) + res = actx.to_numpy(f(99)) + + assert res == 198 + + # Also test a case in which SVM is not available + if pass_allocator in ["auto_none", "auto_true", "auto_false"]: + def override_has_svm(dev): + return False + + import pyopencl.characterize as cl_char + + backup_has_coarse_grain_buffer_svm = cl_char.has_coarse_grain_buffer_svm + cl_char.has_coarse_grain_buffer_svm = override_has_svm + + actx = _PytatoPyOpenCLArrayContextForTests(base_actx.queue, allocator=alloc, + use_memory_pool=use_memory_pool) + f = actx.compile(twice) + res = actx.to_numpy(f(99)) + + assert res == 198 + + cl_char.has_coarse_grain_buffer_svm = backup_has_coarse_grain_buffer_svm + + if __name__ == "__main__": import sys if len(sys.argv) > 1: