Skip to content
Snippets Groups Projects
Unverified Commit c17b7d30 authored by Matthias Diener's avatar Matthias Diener Committed by GitHub
Browse files

pass allocator to execute_distributed_partition (#240)

parent 41d30aaa
No related branches found
No related tags found
No related merge requests found
Pipeline #276868 passed
...@@ -88,6 +88,19 @@ class PyOpenCLArrayContext(_PyOpenCLArrayContextBase): ...@@ -88,6 +88,19 @@ class PyOpenCLArrayContext(_PyOpenCLArrayContextBase):
to understand :mod:`grudge`-specific transform metadata. (Of which there isn't to understand :mod:`grudge`-specific transform metadata. (Of which there isn't
any, for now.) any, for now.)
""" """
def __init__(self, queue: "pyopencl.CommandQueue",
allocator: Optional["pyopencl.tools.AllocatorInterface"] = None,
wait_event_queue_length: Optional[int] = None,
force_device_scalars: bool = False) -> None:
if allocator is None:
from warnings import warn
warn("No memory allocator specified, please pass one. "
"(Preferably a pyopencl.tools.MemoryPool in order "
"to reduce device allocations)")
super().__init__(queue, allocator,
wait_event_queue_length, force_device_scalars)
# }}} # }}}
...@@ -99,6 +112,13 @@ class PytatoPyOpenCLArrayContext(_PytatoPyOpenCLArrayContextBase): ...@@ -99,6 +112,13 @@ class PytatoPyOpenCLArrayContext(_PytatoPyOpenCLArrayContextBase):
Extends it to understand :mod:`grudge`-specific transform metadata. (Of Extends it to understand :mod:`grudge`-specific transform metadata. (Of
which there isn't any, for now.) which there isn't any, for now.)
""" """
def __init__(self, queue, allocator=None):
if allocator is None:
from warnings import warn
warn("No memory allocator specified, please pass one. "
"(Preferably a pyopencl.tools.MemoryPool in order "
"to reduce device allocations)")
super().__init__(queue, allocator)
# }}} # }}}
...@@ -210,6 +230,7 @@ class _DistributedCompiledFunction: ...@@ -210,6 +230,7 @@ class _DistributedCompiledFunction:
out_dict = execute_distributed_partition( out_dict = execute_distributed_partition(
self.distributed_partition, self.part_id_to_prg, self.distributed_partition, self.part_id_to_prg,
self.actx.queue, self.actx.mpi_communicator, self.actx.queue, self.actx.mpi_communicator,
allocator=self.actx.allocator,
input_args=input_args_for_prg) input_args=input_args_for_prg)
def to_output_template(keys, _): def to_output_template(keys, _):
...@@ -224,6 +245,12 @@ class MPIPytatoArrayContextBase(MPIBasedArrayContext): ...@@ -224,6 +245,12 @@ class MPIPytatoArrayContextBase(MPIBasedArrayContext):
def __init__( def __init__(
self, mpi_communicator, queue, *, mpi_base_tag, allocator=None self, mpi_communicator, queue, *, mpi_base_tag, allocator=None
) -> None: ) -> None:
if allocator is None:
from warnings import warn
warn("No memory allocator specified, please pass one. "
"(Preferably a pyopencl.tools.MemoryPool in order "
"to reduce device allocations)")
super().__init__(queue, allocator) super().__init__(queue, allocator)
self.mpi_communicator = mpi_communicator self.mpi_communicator = mpi_communicator
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment