diff --git a/test/testlib.py b/test/testlib.py index c9cf1959ecdc0d906d93f051a7960387ff8de75c..4ff50920e18923c33254d07473f6adbd2544aa6c 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -244,6 +244,76 @@ def make_random_dag(rdagc: RandomDAGContext) -> Any: # }}} +# {{{ get_random_dag_w_no_placholders + +def get_random_pt_dag(seed: int, + *, + additional_generators: Optional[ + Sequence[Tuple[int, + Callable[[RandomDAGContext], Array]]] + ] = None, + axis_len: int = 4, + convert_dws_to_placeholders: bool = False + ) -> pt.DictOfNamedArrays: + if additional_generators is None: + additional_generators = [] + + from testlib import RandomDAGContext, make_random_dag + from typing import cast + + rdagc_comm = RandomDAGContext(np.random.default_rng(seed=seed), + axis_len=axis_len, use_numpy=False, + additional_generators=additional_generators) + dag = pt.make_dict_of_named_arrays({"result": make_random_dag(rdagc_comm)}) + + if convert_dws_to_placeholders: + from pytools import UniqueNameGenerator + vng = UniqueNameGenerator() + + def make_dws_placeholder(expr: pt.transform.ArrayOrNames + ) -> pt.transform.ArrayOrNames: + if isinstance(expr, pt.DataWrapper): + return pt.make_placeholder(vng("_pt_ph"), + expr.shape, expr.dtype) + else: + return expr + + dag = cast(pt.DictOfNamedArrays, + pt.transform.map_and_copy(dag, make_dws_placeholder)) + + return dag + + +def get_random_pt_dag_with_send_recv_nodes( + seed: int, + rank: int, + size: int, + *, + comm_fake_probability: int = 500, + axis_len: int = 4, + convert_dws_to_placeholders: bool = False + ) -> pt.DictOfNamedArrays: + comm_tag = 17 + + def gen_comm(rdagc: RandomDAGContext) -> pt.Array: + nonlocal comm_tag + comm_tag += 1 + inner = make_random_dag(rdagc) + return pt.staple_distributed_send( + inner, dest_rank=(rank-1) % size, comm_tag=comm_tag, + stapled_to=pt.make_distributed_recv( + src_rank=(rank+1) % size, comm_tag=comm_tag, + shape=inner.shape, dtype=inner.dtype)) + + return get_random_pt_dag( + seed=seed, + axis_len=axis_len, + convert_dws_to_placeholders=convert_dws_to_placeholders, + additional_generators=[(comm_fake_probability, gen_comm)]) + +# }}} + + # {{{ tags used only by the regression tests class FooInameTag(Tag):