@dataclass(frozen=True, eq=True)
class MyPartitionId:
    """Hashable partition identifier wrapping a plain integer."""
    num: int


def get_partition_id(topo_list, expr) -> MyPartitionId:
    """Return the partition of *expr* given its position in *topo_list*.

    Consecutive pairs of nodes in the topological order share a partition.
    """
    position = topo_list.index(expr)
    return MyPartitionId(position // 2)


def main():
    """Partition a small expression graph, execute it, and compare the
    result against the unpartitioned computation."""
    from functools import partial
    from pytato.visualization import get_dot_graph_from_partitions

    # Build the expression graph.
    x_in = np.random.randn(2, 2)
    x = pt.make_data_wrapper(x_in)
    y = pt.stack([x@x.T, 2*x, 42+x]) + 55

    # Topologically sort the nodes; the order drives the partition IDs.
    topo_mapper = TopoSortMapper()
    topo_mapper(y)
    pfunc = partial(get_partition_id, topo_mapper.topological_order)

    # Find the partitions.
    outputs = pt.DictOfNamedArrays({"out": y})
    parts = find_partitions(outputs, pfunc)

    # Build the dot representation of the partitions (the returned
    # string is not used any further here).
    get_dot_graph_from_partitions(parts)

    # Execute the partitions.
    queue = cl.CommandQueue(cl.create_some_context())
    prg_per_partition = generate_code_for_partitions(parts)
    context = execute_partitions(parts, prg_per_partition, queue)
    final_res = [context[k] for k in outputs.keys()]

    # Execute the unpartitioned code for comparison.
    _, (out, ) = pt.generate_loopy(y)(queue)

    np.testing.assert_allclose([out], final_res)

    print("Partitioning test succeeded.")
conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +from typing import Any, Callable, Dict, Union, Set, List, Hashable, Tuple, TypeVar +from dataclasses import dataclass + + +from pytato.transform import EdgeCachedMapper, CachedWalkMapper +from pytato.array import ( + Array, AbstractResultWithNamedArrays, Placeholder, + DictOfNamedArrays, make_placeholder) + +from pytato.target import BoundProgram + + +__doc__ = """ +.. autoclass:: CodePartitions +.. autoexception:: PartitionInducedCycleError + +.. autofunction:: find_partitions +.. autofunction:: execute_partitions +""" + + +ArrayOrNames = Union[Array, AbstractResultWithNamedArrays] +T = TypeVar("T", Array, AbstractResultWithNamedArrays) +PartitionId = Hashable + + +# {{{ graph partitioner + +class _GraphPartitioner(EdgeCachedMapper): + """Given a function *get_partition_id*, produces subgraphs representing + the computation. Users should not use this class directly, but use + :meth:`find_partitions` instead. 
+ """ + + # {{{ infrastructure + + def __init__(self, get_partition_id: + Callable[[ArrayOrNames], PartitionId]) -> None: + super().__init__() + + # Function to determine the Partition ID + self._get_partition_id: Callable[[ArrayOrNames], PartitionId] = \ + get_partition_id + + # Naming for newly created PlaceHolders at partition edges + from pytools import UniqueNameGenerator + self.name_generator = UniqueNameGenerator(forced_prefix="_part_ph_") + + # "edges" of the partitioned graph, maps an edge between two partitions, + # represented by a tuple of partition identifiers, to a set of placeholder + # names "conveying" information across the edge. + self.partition_pair_to_edges: Dict[Tuple[PartitionId, PartitionId], + Set[str]] = {} + + self.var_name_to_result: Dict[str, Array] = {} + + self._seen_node_to_placeholder: Dict[ArrayOrNames, Placeholder] = {} + + # Reading the seen partition IDs out of partition_pair_to_edges is incorrect: + # e.g. if each partition is self-contained, no edges would appear. Instead, + # we remember each partition ID we see below, to guarantee that we don't + # miss any of them. 
+ self.seen_partition_ids: Set[PartitionId] = set() + + def get_partition_id(self, expr: ArrayOrNames) -> PartitionId: + part_id = self._get_partition_id(expr) + self.seen_partition_ids.add(part_id) + return part_id + + def does_edge_cross_partition_boundary(self, + node1: ArrayOrNames, node2: ArrayOrNames) -> bool: + return self.get_partition_id(node1) != self.get_partition_id(node2) + + def make_new_placeholder_name(self) -> str: + return self.name_generator() + + def add_interpartition_edge(self, target: ArrayOrNames, dependency: ArrayOrNames, + placeholder_name: str) -> None: + pid_target = self.get_partition_id(target) + pid_dependency = self.get_partition_id(dependency) + + self.partition_pair_to_edges.setdefault( + (pid_target, pid_dependency), set()).add(placeholder_name) + + def handle_edge(self, expr: ArrayOrNames, child: ArrayOrNames) -> Any: + if self.does_edge_cross_partition_boundary(expr, child): + try: + ph = self._seen_node_to_placeholder[child] + except KeyError: + ph_name = self.make_new_placeholder_name() + # If an edge crosses a partition boundary, replace the + # depended-upon node (that nominally lives in the other partition) + # with a Placeholder that lives in the current partition. For each + # partition, collect the placeholder names that it’s supposed to + # compute. 
def find_partitions(outputs: DictOfNamedArrays,
        part_func: Callable[[ArrayOrNames], PartitionId]) ->\
        CodePartitions:
    """Partition the arrays in *outputs* according to *part_func*.

    Every edge of the expression graph whose two end points are assigned
    different partition IDs is cut: the depended-upon node is replaced by a
    placeholder in the depending partition, and the placeholder's name is
    recorded as an output of the partition computing it and as an input
    of the partition consuming it.

    Raises :exc:`PartitionInducedCycleError` if the partitioning
    induces a cycle, e.g. for a graph like the following::

           ┌───┐
        ┌──┤ A ├──┐
        │  └───┘  │
        │       ┌─▼─┐
        │       │ B │
        │       └─┬─┘
        │  ┌───┐  │
        └─►│ C │◄─┘
           └───┘

    where ``A`` and ``C`` are in partition 1, and ``B`` is in partition 2.

    :param outputs: The named arrays to partition.
    :param part_func: A callable that returns an instance of
        :class:`Hashable` for a node.
    :returns: An instance of :class:`CodePartitions` that contains the partitions.
    """

    pf = _GraphPartitioner(part_func)
    rewritten_outputs = {name: pf(expr) for name, expr in outputs._data.items()}

    # The partitioner records every partition ID it encounters, so that
    # self-contained partitions (without any inter-partition edge) are
    # represented as well.
    partition_id_to_output_names: Dict[PartitionId, Set[str]] = {
        pid: set() for pid in pf.seen_partition_ids}
    partition_id_to_input_names: Dict[PartitionId, Set[str]] = {
        pid: set() for pid in pf.seen_partition_ids}

    var_name_to_result = pf.var_name_to_result.copy()

    # The user-visible outputs are outputs of the partition their
    # (original) root node belongs to.
    for out_name, rewritten_output in rewritten_outputs.items():
        out_part_id = part_func(outputs._data[out_name])
        partition_id_to_output_names.setdefault(out_part_id, set()).add(out_name)
        var_name_to_result[out_name] = rewritten_output

    # Mapping of nodes to their successors; used to compute the topological order
    partition_nodes_to_targets: Dict[PartitionId, List[PartitionId]] = {
        pid: [] for pid in pf.seen_partition_ids}

    for (pid_target, pid_dependency), var_names in \
            pf.partition_pair_to_edges.items():
        partition_nodes_to_targets[pid_dependency].append(pid_target)

        # Each placeholder crossing this edge is produced by the dependency
        # and consumed by the target.
        partition_id_to_output_names[pid_dependency].update(var_names)
        partition_id_to_input_names[pid_target].update(var_names)

    from pytools.graph import compute_topological_order, CycleError
    try:
        toposorted_partitions = compute_topological_order(partition_nodes_to_targets)
    except CycleError as err:
        raise PartitionInducedCycleError(
                "partitioning induced a cycle in the graph of partitions") from err

    result = CodePartitions(toposorted_partitions, partition_id_to_input_names,
                            partition_id_to_output_names, var_name_to_result)

    if __debug__:
        _check_partition_disjointness(result)

    return result
def generate_code_for_partitions(parts: CodePartitions) \
        -> Dict[PartitionId, BoundProgram]:
    """Return a mapping of partition identifiers to their
    :class:`pytato.target.BoundProgram`.

    :param parts: An instance of :class:`CodePartitions` representing the
        partitioned code. One program is generated per entry of its
        ``toposorted_partitions``, computing all of that partition's
        output names.
    """
    from pytato import generate_loopy

    return {
        pid: generate_loopy(DictOfNamedArrays(
            {var_name: parts.var_name_to_result[var_name]
             for var_name in parts.partition_id_to_output_names[pid]}))
        for pid in parts.toposorted_partitions}


def execute_partitions(parts: CodePartitions, prg_per_partition:
        Dict[PartitionId, BoundProgram], queue: Any) -> Dict[str, Any]:
    """Executes a set of partitions on a :class:`pyopencl.CommandQueue`.

    :param parts: An instance of :class:`CodePartitions` representing the
        partitioned code.
    :param prg_per_partition: A mapping from partition identifier to the
        program for that partition. Each program is called as
        ``prg(queue=queue, **inputs)`` and must return a tuple whose second
        entry is a mapping of result names to values.
    :param queue: An instance of :class:`pyopencl.CommandQueue` to execute the
        code on.
    :returns: A dictionary of variable names mapped to their values.
    """
    context: Dict[str, Any] = {}

    # Partitions are visited in the order given by *toposorted_partitions*,
    # so results of earlier partitions are available to later ones.
    for pid in parts.toposorted_partitions:
        # Input names without a value in *context* are not passed on.
        inputs = {
            name: context[name]
            for name in parts.partition_id_to_input_names[pid]
            if name in context}

        _evt, result_dict = prg_per_partition[pid](queue=queue, **inputs)
        context.update(result_dict)

    return context
class EdgeCachedMapper(CachedMapper[ArrayOrNames], ABC):
    """
    Mapper class to execute a rewriting method (:meth:`handle_edge`) on each
    edge in the graph.

    Each ``map_*`` method rebuilds its node, obtaining every array-valued
    child (including array-valued shape components and indices) from
    :meth:`handle_edge` rather than recursing directly. Subclasses decide
    in :meth:`handle_edge` what the child is replaced with.

    .. automethod:: handle_edge
    """

    @abstractmethod
    def handle_edge(self, expr: ArrayOrNames, child: ArrayOrNames) -> Any:
        """Return the replacement for *child* as seen from its user *expr*."""
        pass

    def _handle_shape(self, expr: Array, shape: Any, *args: Any) -> Tuple[Any, ...]:
        # Only array-valued shape components are edges of the graph;
        # everything else is passed through unchanged.
        return tuple(
            self.handle_edge(expr, dim, *args) if isinstance(dim, Array) else dim
            for dim in shape)

    # {{{ map_xxx methods

    def map_named_array(self, expr: NamedArray, *args: Any) -> NamedArray:
        return NamedArray(
            container=self.handle_edge(expr, expr._container, *args),
            name=expr.name,
            tags=expr.tags)

    def map_index_lambda(self, expr: IndexLambda, *args: Any) -> IndexLambda:
        # *args is forwarded here like in all other map_* methods.
        return IndexLambda(expr=expr.expr,
                shape=self._handle_shape(expr, expr.shape, *args),
                dtype=expr.dtype,
                bindings={name: self.handle_edge(expr, child, *args)
                          for name, child in expr.bindings.items()},
                tags=expr.tags)

    def map_einsum(self, expr: Einsum, *args: Any) -> Einsum:
        return Einsum(
                     access_descriptors=expr.access_descriptors,
                     args=tuple(self.handle_edge(expr, arg, *args)
                                for arg in expr.args),
                     tags=expr.tags)

    def map_matrix_product(self, expr: MatrixProduct, *args: Any) -> MatrixProduct:
        return MatrixProduct(x1=self.handle_edge(expr, expr.x1, *args),
                             x2=self.handle_edge(expr, expr.x2, *args),
                             tags=expr.tags)

    def map_stack(self, expr: Stack, *args: Any) -> Stack:
        return Stack(
                     arrays=tuple(self.handle_edge(expr, ary, *args)
                                  for ary in expr.arrays),
                     axis=expr.axis,
                     tags=expr.tags)

    def map_concatenate(self, expr: Concatenate, *args: Any) -> Concatenate:
        return Concatenate(
                     arrays=tuple(self.handle_edge(expr, ary, *args)
                                  for ary in expr.arrays),
                     axis=expr.axis,
                     tags=expr.tags)

    def map_roll(self, expr: Roll, *args: Any) -> Roll:
        return Roll(array=self.handle_edge(expr, expr.array, *args),
                    shift=expr.shift,
                    axis=expr.axis,
                    tags=expr.tags)

    def map_axis_permutation(self, expr: AxisPermutation, *args: Any) \
            -> AxisPermutation:
        return AxisPermutation(
                array=self.handle_edge(expr, expr.array, *args),
                axes=expr.axes,
                tags=expr.tags)

    def map_reshape(self, expr: Reshape, *args: Any) -> Reshape:
        # *args is forwarded to the shape handling as well, for consistency
        # with the handling of the array operand.
        return Reshape(
            array=self.handle_edge(expr, expr.array, *args),
            newshape=self._handle_shape(expr, expr.newshape, *args),
            order=expr.order,
            tags=expr.tags)

    # NOTE(review): the three index node types below are rebuilt without
    # passing ``tags=expr.tags``, unlike every other node type in this
    # class -- confirm whether their constructors accept tags and whether
    # dropping them is intended.

    def map_basic_index(self, expr: BasicIndex, *args: Any) -> BasicIndex:
        return BasicIndex(
                array=self.handle_edge(expr, expr.array, *args),
                indices=tuple(self.handle_edge(expr, idx, *args)
                                if isinstance(idx, Array) else idx
                                for idx in expr.indices))

    def map_contiguous_advanced_index(self,
                                      expr: AdvancedIndexInContiguousAxes, *args: Any) \
            -> AdvancedIndexInContiguousAxes:
        return AdvancedIndexInContiguousAxes(
                array=self.handle_edge(expr, expr.array, *args),
                indices=tuple(self.handle_edge(expr, idx, *args)
                                if isinstance(idx, Array) else idx
                                for idx in expr.indices))

    def map_non_contiguous_advanced_index(self,
            expr: AdvancedIndexInNoncontiguousAxes, *args: Any) \
            -> AdvancedIndexInNoncontiguousAxes:
        return AdvancedIndexInNoncontiguousAxes(
                array=self.handle_edge(expr, expr.array, *args),
                indices=tuple(self.handle_edge(expr, idx, *args)
                                if isinstance(idx, Array) else idx
                                for idx in expr.indices))

    def map_data_wrapper(self, expr: DataWrapper, *args: Any) -> DataWrapper:
        return DataWrapper(
                name=expr.name,
                data=expr.data,
                shape=self._handle_shape(expr, expr.shape, *args),
                tags=expr.tags)

    def map_placeholder(self, expr: Placeholder, *args: Any) -> Placeholder:
        assert expr.name

        return Placeholder(name=expr.name,
                shape=self._handle_shape(expr, expr.shape, *args),
                dtype=expr.dtype,
                tags=expr.tags)

    def map_size_param(self, expr: SizeParam, *args: Any) -> SizeParam:
        assert expr.name
        return SizeParam(name=expr.name, tags=expr.tags)

    # }}}
5708400f022ed40fbfbefda06ec8f8e79630610e..3dc80beb18e78e7032f540d7c2bd8dc1cf1f1056 100644 --- a/pytato/visualization.py +++ b/pytato/visualization.py @@ -28,7 +28,8 @@ THE SOFTWARE. import contextlib import dataclasses import html -from typing import Callable, Dict, Union, Iterator, List, Mapping +from typing import (Callable, Dict, Union, Iterator, List, Mapping, Hashable, + Set) from pytools import UniqueNameGenerator from pytools.codegen import CodeGenerator as CodeGeneratorBase @@ -40,11 +41,14 @@ from pytato.array import ( from pytato.codegen import normalize_outputs import pytato.transform +from pytato.partition import CodePartitions + __doc__ = """ .. currentmodule:: pytato .. autofunction:: get_dot_graph +.. autofunction:: get_dot_graph_from_partitions .. autofunction:: show_dot_graph .. autofunction:: get_ascii_graph .. autofunction:: show_ascii_graph @@ -163,7 +167,8 @@ class DotEmitter(CodeGeneratorBase): self("}") -def _emit_array(emit: DotEmitter, info: DotNodeInfo, id: str) -> None: +def _emit_array(emit: DotEmitter, info: DotNodeInfo, id: str, + color: str = "white") -> None: td_attrib = 'border="0"' table_attrib = 'border="0" cellborder="1" cellspacing="0"' @@ -179,7 +184,7 @@ def _emit_array(emit: DotEmitter, info: DotNodeInfo, id: str) -> None: % (td_attrib, dot_escape(name), td_attrib, dot_escape(field))) table = "\n%s
def get_dot_graph_from_partitions(parts: CodePartitions) -> str:
    r"""Return a string in the `dot <https://graphviz.org>`_ language depicting
    the graph of the partitioned computation in *parts*.

    Each partition is drawn as a dashed cluster labelled with its
    identifier. Within a cluster, input nodes are filled in blue, output
    nodes in gold and all other nodes in white. An input whose name refers
    to a value recorded in ``parts.var_name_to_result`` additionally gets
    an edge from the node computing that value.

    :arg parts: An instance of :class:`~pytato.partition.CodePartitions`
        (cf. :func:`pytato.partition.find_partitions`).
    """
    # Maps each partition to a dict of its arrays with the node info
    part_id_to_node_to_node_info: Dict[Hashable, Dict[Array, DotNodeInfo]] = {}

    for part_id, out_names in parts.partition_id_to_output_names.items():
        part_node_to_info: Dict[Array, DotNodeInfo] = {}

        mapper = ArrayToDotNodeInfoMapper()
        for out_name in out_names:
            mapper(parts.var_name_to_result[out_name], part_node_to_info)

        part_id_to_node_to_node_info[part_id] = part_node_to_info

    id_gen = UniqueNameGenerator()

    emit = DotEmitter()

    with emit.block("digraph computation"):
        emit("node [shape=rectangle]")
        array_to_id: Dict[Array, str] = {}

        # Fill array_to_id in a first pass, so that the second pass does not
        # depend on parts.toposorted_partitions actually being topologically
        # sorted (which helps when debugging the topological sort itself).
        for part_id in parts.toposorted_partitions:
            for array in part_id_to_node_to_node_info[part_id]:
                array_to_id[array] = id_gen("array")

        # Second pass: emit the graph.
        for part_id in parts.toposorted_partitions:
            part_node_to_info = part_id_to_node_to_node_info[part_id]

            output_arrays: Set[Array] = {
                parts.var_name_to_result[out_name]
                for out_name in parts.partition_id_to_output_names[part_id]}

            input_arrays: List[Array] = [
                array for array in part_node_to_info
                if isinstance(array, InputArgumentBase)]

            # Everything that is neither an input nor an output.
            internal_arrays: List[Array] = [
                array for array in part_node_to_info
                if not isinstance(array, InputArgumentBase)
                and array not in output_arrays]

            with emit.block(f'subgraph "cluster_part_{part_id}"'):
                emit("style=dashed")
                emit(f'label="{part_id}"')

                # Emit inputs
                for array in input_arrays:
                    _emit_array(emit, part_node_to_info[array],
                                array_to_id[array], "deepskyblue")

                    # Emit cross-partition edges. Only names recorded in
                    # var_name_to_result have a producing node; other
                    # named inputs (e.g. user-supplied placeholders) do
                    # not and must not trigger a lookup failure.
                    name = getattr(array, "name", None)
                    if name and name in parts.var_name_to_result:
                        tgt = array_to_id[parts.var_name_to_result[name]]
                        emit(f"{tgt} -> {array_to_id[array]}")

                # Emit internal nodes
                for array in internal_arrays:
                    _emit_array(emit, part_node_to_info[array], array_to_id[array])

                # Emit outputs
                for array in output_arrays:
                    _emit_array(emit, part_node_to_info[array],
                                array_to_id[array], "gold")

                # Emit intra-partition edges
                for array, node in part_node_to_info.items():
                    for label, tail_array in node.edges.items():
                        tail = array_to_id[tail_array]
                        head = array_to_id[array]
                        emit('%s -> %s [label="%s"]' %
                            (tail, head, dot_escape(label)))

    return emit.get()
def test_partitioner(ctx_factory):
    """Check that partitioned execution of random DAGs matches numpy.

    For a number of seeds, a random DAG is built once with numpy (as the
    reference) and once with pytato. The pytato graph is split into
    partitions of three consecutive nodes of its topological order, each
    partition is compiled and executed, and the final result is compared
    against the reference. Partitionings that induce a cycle are skipped,
    but must remain a minority.
    """
    from dataclasses import dataclass
    from functools import partial

    from testlib import RandomDAGContext, make_random_dag
    from pytato.transform import TopoSortMapper, materialize_with_mpms
    from pytato.partition import (find_partitions,
        execute_partitions, generate_code_for_partitions,
        PartitionInducedCycleError)

    @dataclass(frozen=True, eq=True)
    class MyPartitionId:
        num: int

    def get_partition_id(topo_list, expr) -> MyPartitionId:
        # Wrap the number so that the returned value matches the annotation
        # and non-integer (merely hashable) partition IDs are exercised.
        return MyPartitionId(topo_list.index(expr) // 3)

    ctx = ctx_factory()
    queue = cl.CommandQueue(ctx)

    axis_len = 5

    ntests = 50
    ncycles = 0
    for i in range(ntests):
        seed = 120 + i
        rdagc_pt = RandomDAGContext(np.random.default_rng(seed=seed),
                axis_len=axis_len, use_numpy=False)
        rdagc_np = RandomDAGContext(np.random.default_rng(seed=seed),
                axis_len=axis_len, use_numpy=True)

        ref_result = make_random_dag(rdagc_np)

        dict_named_arys = materialize_with_mpms(pt.DictOfNamedArrays(
                {"result": make_random_dag(rdagc_pt)}))

        tm = TopoSortMapper()
        tm(dict_named_arys)

        part_func = partial(get_partition_id, tm.topological_order)

        try:
            parts = find_partitions(dict_named_arys, part_func)
        except PartitionInducedCycleError:
            print("CYCLE!")
            # FIXME *shrug* nothing preventing that currently
            ncycles += 1
            continue

        # Execute the partitioned code
        prg_per_partition = generate_code_for_partitions(parts)

        context = execute_partitions(parts, prg_per_partition, queue)

        pt_part_res, = [context[k] for k in dict_named_arys]

        np.testing.assert_allclose(pt_part_res, ref_result)

    # Assert that at least 2/3 of our tests did not get skipped because of cycles
    assert ncycles < ntests // 3