From c17b7d30416482aaa4a9b0597f91d38484d0a487 Mon Sep 17 00:00:00 2001 From: Matthias Diener <mdiener@illinois.edu> Date: Tue, 22 Mar 2022 16:52:05 -0500 Subject: [PATCH] pass allocator to execute_distributed_partition (#240) --- grudge/array_context.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/grudge/array_context.py b/grudge/array_context.py index 4e1ac838..8ee89a35 100644 --- a/grudge/array_context.py +++ b/grudge/array_context.py @@ -88,6 +88,19 @@ class PyOpenCLArrayContext(_PyOpenCLArrayContextBase): to understand :mod:`grudge`-specific transform metadata. (Of which there isn't any, for now.) """ + def __init__(self, queue: "pyopencl.CommandQueue", + allocator: Optional["pyopencl.tools.AllocatorInterface"] = None, + wait_event_queue_length: Optional[int] = None, + force_device_scalars: bool = False) -> None: + + if allocator is None: + from warnings import warn + warn("No memory allocator specified, please pass one. " + "(Preferably a pyopencl.tools.MemoryPool in order " + "to reduce device allocations)") + + super().__init__(queue, allocator, + wait_event_queue_length, force_device_scalars) # }}} @@ -99,6 +112,13 @@ class PytatoPyOpenCLArrayContext(_PytatoPyOpenCLArrayContextBase): Extends it to understand :mod:`grudge`-specific transform metadata. (Of which there isn't any, for now.) """ + def __init__(self, queue, allocator=None): + if allocator is None: + from warnings import warn + warn("No memory allocator specified, please pass one. " + "(Preferably a pyopencl.tools.MemoryPool in order " + "to reduce device allocations)") + super().__init__(queue, allocator) # }}} @@ -210,6 +230,7 @@ class _DistributedCompiledFunction: out_dict = execute_distributed_partition( self.distributed_partition, self.part_id_to_prg, self.actx.queue, self.actx.mpi_communicator, + allocator=self.actx.allocator, input_args=input_args_for_prg) def to_output_template(keys, _): @@ -224,6 +245,12 @@ class MPIPytatoArrayContextBase(MPIBasedArrayContext): def __init__( self, mpi_communicator, queue, *, mpi_base_tag, allocator=None ) -> None: + if allocator is None: + from warnings import warn + warn("No memory allocator specified, please pass one. " + "(Preferably a pyopencl.tools.MemoryPool in order " + "to reduce device allocations)") + super().__init__(queue, allocator) self.mpi_communicator = mpi_communicator -- GitLab