diff --git a/grudge/array_context.py b/grudge/array_context.py index 8ee89a35e76aad7843dc62c8bd534898648cb616..171016bfe3ccc2bd236ae7933518a2a8da930e5e 100644 --- a/grudge/array_context.py +++ b/grudge/array_context.py @@ -32,12 +32,14 @@ THE SOFTWARE. # {{{ imports from typing import ( - TYPE_CHECKING, Mapping, Tuple, Any, Callable, Optional, Type) + TYPE_CHECKING, Mapping, Tuple, Any, Callable, Optional, Type, + FrozenSet) from dataclasses import dataclass - +from pytools.tag import Tag from meshmode.array_context import ( PyOpenCLArrayContext as _PyOpenCLArrayContextBase, PytatoPyOpenCLArrayContext as _PytatoPyOpenCLArrayContextBase) +from pyrsistent import pmap import logging logger = logging.getLogger(__name__) @@ -160,6 +162,8 @@ class _DistributedLazilyCompilingFunctionCaller(LazilyCompilingFunctionCaller): # }}} part_id_to_prg = {} + name_in_program_to_tags = pmap() + name_in_program_to_axes = pmap() from pytato import DictOfNamedArrays for part in distributed_partition.parts.values(): @@ -167,7 +171,20 @@ class _DistributedLazilyCompilingFunctionCaller(LazilyCompilingFunctionCaller): {var_name: distributed_partition.var_name_to_result[var_name] for var_name in part.output_names }) - part_id_to_prg[part.pid], _, _ = self._dag_to_transformed_loopy_prg(d) + ( + part_id_to_prg[part.pid], + part_prg_name_to_tags, + part_prg_name_to_axes + ) = self._dag_to_transformed_loopy_prg(d) + + assert not (set(name_in_program_to_tags.keys()) + & set(part_prg_name_to_tags.keys())) + assert not (set(name_in_program_to_axes.keys()) + & set(part_prg_name_to_axes.keys())) + name_in_program_to_tags = name_in_program_to_tags.update( + part_prg_name_to_tags) + name_in_program_to_axes = name_in_program_to_axes.update( + part_prg_name_to_axes) return _DistributedCompiledFunction( actx=self.actx, @@ -175,6 +192,8 @@ class _DistributedLazilyCompilingFunctionCaller(LazilyCompilingFunctionCaller): part_id_to_prg=part_id_to_prg, input_id_to_name_in_program=input_id_to_name_in_program, output_id_to_name_in_program=output_id_to_name_in_program, + name_in_program_to_tags=name_in_program_to_tags, + name_in_program_to_axes=name_in_program_to_axes, output_template=output_template) @@ -213,6 +232,8 @@ class _DistributedCompiledFunction: part_id_to_prg: "Mapping[PartId, pt.target.BoundProgram]" input_id_to_name_in_program: Mapping[Tuple[Any, ...], str] output_id_to_name_in_program: Mapping[Tuple[Any, ...], str] + name_in_program_to_tags: Mapping[str, FrozenSet[Tag]] + name_in_program_to_axes: Mapping[str, Tuple["pt.Axis", ...]] output_template: ArrayContainer def __call__(self, arg_id_to_arg) -> ArrayContainer: @@ -223,6 +244,8 @@ class _DistributedCompiledFunction: """ from arraycontext.impl.pytato.compile import _args_to_cl_buffers + from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array + from arraycontext.impl.pytato.utils import get_cl_axes_from_pt_axes input_args_for_prg = _args_to_cl_buffers( self.actx, self.input_id_to_name_in_program, arg_id_to_arg) @@ -234,7 +257,12 @@ class _DistributedCompiledFunction: input_args=input_args_for_prg) def to_output_template(keys, _): - return self.actx.thaw(out_dict[self.output_id_to_name_in_program[keys]]) + ary_name_in_prg = self.output_id_to_name_in_program[keys] + return self.actx.thaw(to_tagged_cl_array( + out_dict[ary_name_in_prg], + axes=get_cl_axes_from_pt_axes( + self.name_in_program_to_axes[ary_name_in_prg]), + tags=self.name_in_program_to_tags[ary_name_in_prg])) from arraycontext.container.traversal import rec_keyed_map_array_container return rec_keyed_map_array_container(to_output_template,