From 144bbc139f3affeb46a54a9fdc6c5f51ffd9da4b Mon Sep 17 00:00:00 2001
From: Matthias Diener <mdiener@illinois.edu>
Date: Thu, 3 Mar 2022 23:02:26 -0600
Subject: [PATCH] get_reasonable_array_context_class improvements (#230)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* get_reasonable_array_context_class improvements
- add documentation
- add logging
- check for correct loopy branch

* reword doc

Co-authored-by: Andreas Klöckner <inform@tiker.net>

* clarify which branches are needed

* add warning about missing branches

* check for mismatched loopy/meshmode branches

Co-authored-by: Andreas Klöckner <inform@tiker.net>
---
 grudge/array_context.py | 49 +++++++++++++++++++++++++++++++++++------
 1 file changed, 42 insertions(+), 7 deletions(-)

diff --git a/grudge/array_context.py b/grudge/array_context.py
index 675e47a2..4e1ac838 100644
--- a/grudge/array_context.py
+++ b/grudge/array_context.py
@@ -39,11 +39,30 @@ from meshmode.array_context import (
         PyOpenCLArrayContext as _PyOpenCLArrayContextBase,
         PytatoPyOpenCLArrayContext as _PytatoPyOpenCLArrayContextBase)
 
+import logging
+logger = logging.getLogger(__name__)
+
 try:
     # FIXME: temporary workaround while SingleGridWorkBalancingPytatoArrayContext
     # is not available in meshmode's main branch
+    # (it currently needs
+    # https://github.com/kaushikcfd/meshmode/tree/pytato-array-context-transforms)
     from meshmode.array_context import SingleGridWorkBalancingPytatoArrayContext
-    _HAVE_SINGLE_GRID_WORK_BALANCING = True
+
+    try:
+        # Crude check if we have the correct loopy branch
+        # (https://github.com/kaushikcfd/loopy/tree/pytato-array-context-transforms)
+        from loopy.codegen.result import get_idis_for_kernel  # noqa
+    except ImportError:
+        from warnings import warn
+        warn("Your loopy and meshmode branches are mismatched. "
+             "Please make sure that you have the "
+             "https://github.com/kaushikcfd/loopy/tree/pytato-array-context-transforms "  # noqa
+             "branch of loopy.")
+        _HAVE_SINGLE_GRID_WORK_BALANCING = False
+    else:
+        _HAVE_SINGLE_GRID_WORK_BALANCING = True
+
 except ImportError:
     _HAVE_SINGLE_GRID_WORK_BALANCING = False
 
@@ -317,23 +336,39 @@ register_pytest_array_context_factory("grudge.pytato-pyopencl",
 def get_reasonable_array_context_class(
         lazy: bool = True, distributed: bool = True
         ) -> Type[ArrayContext]:
+    """Returns a reasonable :class:`PyOpenCLArrayContext` currently
+    supported given the constraints of *lazy* and *distributed*."""
     if lazy:
+        if not _HAVE_SINGLE_GRID_WORK_BALANCING:
+            from warnings import warn
+            warn("No device-parallel actx available, execution will be slow. "
+                 "Please make sure you have the right branches for loopy "
+                 "(https://github.com/kaushikcfd/loopy/tree/pytato-array-context-transforms) "  # noqa
+                 "and meshmode "
+                 "(https://github.com/kaushikcfd/meshmode/tree/pytato-array-context-transforms).")  # noqa
         # lazy, non-distributed
         if not distributed:
             if _HAVE_SINGLE_GRID_WORK_BALANCING:
-                return SingleGridWorkBalancingPytatoArrayContext
+                actx_class = SingleGridWorkBalancingPytatoArrayContext
             else:
-                return PytatoPyOpenCLArrayContext
+                actx_class = PytatoPyOpenCLArrayContext
         # distributed+lazy:
         if _HAVE_SINGLE_GRID_WORK_BALANCING:
-            return MPISingleGridWorkBalancingPytatoArrayContext
+            actx_class = MPISingleGridWorkBalancingPytatoArrayContext
         else:
-            return MPIBasePytatoPyOpenCLArrayContext
+            actx_class = MPIBasePytatoPyOpenCLArrayContext
     else:
         if distributed:
-            return MPIPyOpenCLArrayContext
+            actx_class = MPIPyOpenCLArrayContext
         else:
-            return PyOpenCLArrayContext
+            actx_class = PyOpenCLArrayContext
+
+    logger.info("get_reasonable_array_context_class: %s lazy=%r distributed=%r "
+                "device-parallel=%r",
+                actx_class.__name__, lazy, distributed,
+                # eager is always device-parallel:
+                (_HAVE_SINGLE_GRID_WORK_BALANCING or not lazy))
+    return actx_class
 
 # }}}
 
-- 
GitLab