diff --git a/grudge/array_context.py b/grudge/array_context.py
index 00e6fb0d64b41d28c1fd1b081ddfad3fc19eb3ec..d672a68ad32b2df70610770ad626148212f3ec5c 100644
--- a/grudge/array_context.py
+++ b/grudge/array_context.py
@@ -232,7 +232,21 @@ class _DistributedLazilyPyOpenCLCompilingFunctionCaller(
         self.actx._compile_trace_callback(self.f, "pre_find_distributed_partition",
                 dict_of_named_arrays)
 
-        distributed_partition = pt.find_distributed_partition(dict_of_named_arrays)
+        # https://github.com/inducer/pytato/pull/393 changes the function signature
+        try:
+            # pylint: disable=too-many-function-args
+            distributed_partition = pt.find_distributed_partition(
+                # pylint-ignore-reason:
+                # '_BasePytatoArrayContext' has no
+                # 'mpi_communicator' member
+                # pylint: disable=no-member
+                self.actx.mpi_communicator, dict_of_named_arrays)
+        except TypeError as e:
+            if "find_distributed_partition() takes 1 positional" in str(e):
+                distributed_partition = pt.find_distributed_partition(
+                    dict_of_named_arrays)
+            else:
+                raise
 
         if __debug__:
             # pylint-ignore-reason: