Skip to content
Snippets Groups Projects
Commit 3a5aea7f authored by Matthias Diener's avatar Matthias Diener Committed by Andreas Klöckner
Browse files

test the various allocation options

parent 9e75a457
No related branches found
No related tags found
No related merge requests found
...@@ -121,6 +121,62 @@ def test_arg_size_limit(actx_factory): ...@@ -121,6 +121,62 @@ def test_arg_size_limit(actx_factory):
assert ran_callback assert ran_callback
@pytest.mark.parametrize("pass_allocator", ["auto_none", "auto_true", "auto_false",
"pass_buffer", "pass_svm",
"pass_buffer_pool", "pass_svm_pool"])
def test_pytato_actx_allocator(actx_factory, pass_allocator):
base_actx = actx_factory()
alloc = None
use_memory_pool = None
if pass_allocator == "auto_true":
use_memory_pool = True
elif pass_allocator == "auto_false":
use_memory_pool = False
elif pass_allocator == "pass_buffer":
from pyopencl.tools import ImmediateAllocator
alloc = ImmediateAllocator(base_actx.queue)
elif pass_allocator == "pass_svm":
from pyopencl.tools import SVMAllocator
alloc = SVMAllocator(base_actx.queue.context, queue=base_actx.queue)
elif pass_allocator == "pass_buffer_pool":
from pyopencl.tools import ImmediateAllocator, MemoryPool
alloc = MemoryPool(ImmediateAllocator(base_actx.queue))
elif pass_allocator == "pass_svm_pool":
from pyopencl.tools import SVMAllocator, SVMPool
alloc = SVMPool(SVMAllocator(base_actx.queue.context, queue=base_actx.queue))
actx = _PytatoPyOpenCLArrayContextForTests(base_actx.queue, allocator=alloc,
use_memory_pool=use_memory_pool)
def twice(x):
return 2 * x
f = actx.compile(twice)
res = actx.to_numpy(f(99))
assert res == 198
# Also test a case in which SVM is not available
if pass_allocator in ["auto_none", "auto_true", "auto_false"]:
def override_has_svm(dev):
return False
import pyopencl.characterize as cl_char
backup_has_coarse_grain_buffer_svm = cl_char.has_coarse_grain_buffer_svm
cl_char.has_coarse_grain_buffer_svm = override_has_svm
actx = _PytatoPyOpenCLArrayContextForTests(base_actx.queue, allocator=alloc,
use_memory_pool=use_memory_pool)
f = actx.compile(twice)
res = actx.to_numpy(f(99))
assert res == 198
cl_char.has_coarse_grain_buffer_svm = backup_has_coarse_grain_buffer_svm
if __name__ == "__main__": if __name__ == "__main__":
import sys import sys
if len(sys.argv) > 1: if len(sys.argv) > 1:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment