diff --git a/grudge/array_context.py b/grudge/array_context.py index 171016bfe3ccc2bd236ae7933518a2a8da930e5e..f1a27ba3e8b32b6a7d72dd4bcefb3339bfcdcd06 100644 --- a/grudge/array_context.py +++ b/grudge/array_context.py @@ -74,7 +74,7 @@ from arraycontext.pytest import ( register_pytest_array_context_factory) from arraycontext import ArrayContext from arraycontext.container import ArrayContainer -from arraycontext.impl.pytato.compile import LazilyCompilingFunctionCaller +from arraycontext.impl.pytato.compile import LazilyPyOpenCLCompilingFunctionCaller if TYPE_CHECKING: import pytato as pt @@ -131,7 +131,8 @@ class MPIBasedArrayContext: # {{{ distributed + pytato -class _DistributedLazilyCompilingFunctionCaller(LazilyCompilingFunctionCaller): +class _DistributedLazilyCompilingFunctionCaller( + LazilyPyOpenCLCompilingFunctionCaller): def _dag_to_compiled_func(self, dict_of_named_arrays, input_id_to_name_in_program, output_id_to_name_in_program, output_template): @@ -201,8 +202,8 @@ class _DistributedLazilyCompilingFunctionCaller(LazilyCompilingFunctionCaller): class _DistributedCompiledFunction: """ A callable which captures the :class:`pytato.target.BoundProgram` resulting - from calling :attr:`~LazilyCompilingFunctionCaller.f` with a given set of - input types, and generating :mod:`loopy` IR from it. + from calling :attr:`~LazilyPyOpenCLCompilingFunctionCaller.f` with a given + set of input types, and generating :mod:`loopy` IR from it. .. attribute:: pytato_program @@ -210,8 +211,9 @@ class _DistributedCompiledFunction: A mapping from input id to the placeholder name in :attr:`CompiledFunction.pytato_program`. Input id is represented as the - position of :attr:`~LazilyCompilingFunctionCaller.f`'s argument augmented - with the leaf array's key if the argument is an array container. + position of :attr:`~LazilyPyOpenCLCompilingFunctionCaller.f`'s argument + augmented with the leaf array's key if the argument is an array + container. .. attribute:: output_id_to_name_in_program @@ -243,10 +245,10 @@ class _DistributedCompiledFunction: representation. """ - from arraycontext.impl.pytato.compile import _args_to_cl_buffers + from arraycontext.impl.pytato.compile import _args_to_device_buffers from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array from arraycontext.impl.pytato.utils import get_cl_axes_from_pt_axes - input_args_for_prg = _args_to_cl_buffers( + input_args_for_prg = _args_to_device_buffers( self.actx, self.input_id_to_name_in_program, arg_id_to_arg) from pytato.distributed import execute_distributed_partition