From c17b7d30416482aaa4a9b0597f91d38484d0a487 Mon Sep 17 00:00:00 2001
From: Matthias Diener <mdiener@illinois.edu>
Date: Tue, 22 Mar 2022 16:52:05 -0500
Subject: [PATCH] pass allocator to execute_distributed_partition (#240)

---
 grudge/array_context.py | 27 +++++++++++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/grudge/array_context.py b/grudge/array_context.py
index 4e1ac838..8ee89a35 100644
--- a/grudge/array_context.py
+++ b/grudge/array_context.py
@@ -88,6 +88,19 @@ class PyOpenCLArrayContext(_PyOpenCLArrayContextBase):
     to understand :mod:`grudge`-specific transform metadata. (Of which there isn't
     any, for now.)
     """
+    def __init__(self, queue: "pyopencl.CommandQueue",
+            allocator: Optional["pyopencl.tools.AllocatorInterface"] = None,
+            wait_event_queue_length: Optional[int] = None,
+            force_device_scalars: bool = False) -> None:
+
+        if allocator is None:
+            from warnings import warn
+            warn("No memory allocator specified, please pass one. "
+                 "(Preferably a pyopencl.tools.MemoryPool in order "
+                 "to reduce device allocations)")
+
+        super().__init__(queue, allocator,
+                         wait_event_queue_length, force_device_scalars)
 
 # }}}
 
@@ -99,6 +112,13 @@ class PytatoPyOpenCLArrayContext(_PytatoPyOpenCLArrayContextBase):
     Extends it to understand :mod:`grudge`-specific transform metadata. (Of
     which there isn't any, for now.)
     """
+    def __init__(self, queue, allocator=None):
+        if allocator is None:
+            from warnings import warn
+            warn("No memory allocator specified, please pass one. "
+                 "(Preferably a pyopencl.tools.MemoryPool in order "
+                 "to reduce device allocations)")
+        super().__init__(queue, allocator)
 
 # }}}
 
@@ -210,6 +230,7 @@ class _DistributedCompiledFunction:
         out_dict = execute_distributed_partition(
                 self.distributed_partition, self.part_id_to_prg,
                 self.actx.queue, self.actx.mpi_communicator,
+                allocator=self.actx.allocator,
                 input_args=input_args_for_prg)
 
         def to_output_template(keys, _):
@@ -224,6 +245,12 @@ class MPIPytatoArrayContextBase(MPIBasedArrayContext):
     def __init__(
             self, mpi_communicator, queue, *, mpi_base_tag, allocator=None
             ) -> None:
+        if allocator is None:
+            from warnings import warn
+            warn("No memory allocator specified, please pass one. "
+                 "(Preferably a pyopencl.tools.MemoryPool in order "
+                 "to reduce device allocations)")
+
         super().__init__(queue, allocator)
 
         self.mpi_communicator = mpi_communicator
-- 
GitLab