diff --git a/pytato/array.py b/pytato/array.py index 9e73a96c75edc9c8547da5db2205f46f15cb360c..bbf4ae739fdb8c096ff76bc02859882b1f85698a 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -556,9 +556,9 @@ class Array(Taggable): indices = tuple(var(f"_{i}") for i in range(self.ndim)) expr = op(var("_in0")[indices]) - bindings = {"_in0": self} + bindings = Map({"_in0": self}) return IndexLambda( - expr, + expr=expr, shape=self.shape, dtype=self.dtype, bindings=bindings, @@ -2003,7 +2003,7 @@ def make_placeholder(name: str, raise ValueError("'axes' dimensionality mismatch:" f" expected {len(shape)}, got {len(axes)}.") - return Placeholder(name, shape, dtype, axes=axes, + return Placeholder(name=name, shape=shape, dtype=dtype, axes=axes, tags=(tags | _get_default_tags())) @@ -2085,7 +2085,8 @@ def full(shape: ConvertibleToShape, fill_value: ScalarType, else: fill_value = dtype.type(fill_value) - return IndexLambda(fill_value, shape, dtype, {}, + return IndexLambda(expr=fill_value, shape=shape, dtype=dtype, + bindings=Map(), tags=_get_default_tags(), axes=_get_default_axes(len(shape)), var_to_reduction_descr=Map()) @@ -2130,8 +2131,8 @@ def eye(N: int, M: Optional[int] = None, k: int = 0, # noqa: N803 if not isinstance(k, INT_CLASSES): raise ValueError(f"k must be int, got {type(k)}.") - return IndexLambda(parse(f"1 if ((_1 - _0) == {k}) else 0"), - shape=(N, M), dtype=dtype, bindings={}, + return IndexLambda(expr=parse(f"1 if ((_1 - _0) == {k}) else 0"), + shape=(N, M), dtype=dtype, bindings=Map({}), tags=_get_default_tags(), axes=_get_default_axes(2), var_to_reduction_descr=Map()) @@ -2225,8 +2226,8 @@ def arange(*args: Any, **kwargs: Any) -> Array: size = max(0, int(ceil((stop-start)/step))) from pymbolic.primitives import Variable - return IndexLambda(start + Variable("_0") * step, - shape=(size,), dtype=dtype, bindings={}, + return IndexLambda(expr=start + Variable("_0") * step, + shape=(size,), dtype=dtype, bindings=Map(), tags=_get_default_tags(), axes=_get_default_axes(1), var_to_reduction_descr=Map()) @@ -2332,7 +2333,7 @@ def logical_not(x: ArrayOrScalar) -> Union[Array, bool]: assert isinstance(x, Array) from pytato.utils import with_indices_for_broadcasted_shape - return IndexLambda(with_indices_for_broadcasted_shape(prim.Variable("_in0"), + return IndexLambda(expr=with_indices_for_broadcasted_shape(prim.Variable("_in0"), x.shape, x.shape), shape=x.shape, @@ -2389,10 +2390,11 @@ def where(condition: ArrayOrScalar, expr3 = utils.update_bindings_and_get_broadcasted_expr(y, "_in2", bindings, result_shape) - return IndexLambda(prim.If(expr1, expr2, expr3), + return IndexLambda( + expr=prim.If(expr1, expr2, expr3), shape=result_shape, dtype=dtype, - bindings=bindings, + bindings=Map(bindings), tags=_get_default_tags(), axes=_get_default_axes(len(result_shape)), var_to_reduction_descr=Map()) @@ -2446,7 +2448,7 @@ def minimum(x1: ArrayOrScalar, x2: ArrayOrScalar) -> ArrayOrScalar: def make_index_lambda( expression: Union[str, ScalarExpression], - bindings: Dict[str, Array], + bindings: Mapping[str, Array], shape: ShapeType, dtype: Any, var_to_reduction_descr: Optional[Mapping[str, ReductionDescriptor]] = None @@ -2488,7 +2490,7 @@ def make_index_lambda( # }}} return IndexLambda(expr=expression, - bindings=bindings, + bindings=Map(bindings), shape=shape, dtype=dtype, tags=_get_default_tags(), @@ -2574,7 +2576,7 @@ def broadcast_to(array: Array, shape: ShapeType) -> Array: shape)), shape=shape, dtype=array.dtype, - bindings={"in": array}, + bindings=Map({"in": array}), tags=_get_default_tags(), axes=_get_default_axes(len(shape)), var_to_reduction_descr=Map()) @@ -2647,7 +2649,7 @@ def expand_dims(array: Array, axis: Union[Tuple[int, ...], int]) -> Array: assert len(new_shape) == output_ndim - return Reshape(array, tuple(new_shape), "C", + return Reshape(array=array, newshape=tuple(new_shape), order="C", tags=(_get_default_tags() | {ExpandedDimsReshape(tuple(normalized_axis))}), axes=_get_default_axes(len(new_shape))) diff --git a/pytato/cmath.py b/pytato/cmath.py index e83334afdec33f1eed3eae87ae6689f7e3c27504..88fca8fa6dfa4ec5913f12c7a4a838b2c7940ab8 100644 --- a/pytato/cmath.py +++ b/pytato/cmath.py @@ -111,9 +111,9 @@ def _apply_elem_wise_func(inputs: Tuple[ArrayOrScalar, ...], assert ret_dtype is not None return IndexLambda( - prim.Call(var(f"pytato.c99.{func_name}"), + expr=prim.Call(var(f"pytato.c99.{func_name}"), tuple(sym_args)), - shape, ret_dtype, bindings, + shape=shape, dtype=ret_dtype, bindings=Map(bindings), tags=_get_default_tags(), axes=_get_default_axes(len(shape)), var_to_reduction_descr=Map(), diff --git a/pytato/distributed/nodes.py b/pytato/distributed/nodes.py index 7ff84ad81dbd72d89d12e92137018f610f5e6601..465bda312f74138735fcff51c6f09e1e48c1a0aa 100644 --- a/pytato/distributed/nodes.py +++ b/pytato/distributed/nodes.py @@ -222,7 +222,8 @@ def make_distributed_send(sent_data: Array, dest_rank: int, comm_tag: CommTagTyp send_tags: FrozenSet[Tag] = frozenset()) -> \ DistributedSend: """Make a :class:`DistributedSend` object.""" - return DistributedSend(sent_data, dest_rank, comm_tag, send_tags) + return DistributedSend(data=sent_data, dest_rank=dest_rank, comm_tag=comm_tag, + tags=send_tags) def staple_distributed_send(sent_data: Array, dest_rank: int, comm_tag: CommTagType, @@ -233,8 +234,9 @@ def staple_distributed_send(sent_data: Array, dest_rank: int, comm_tag: CommTagT """Make a :class:`DistributedSend` object wrapped in a :class:`DistributedSendRefHolder` object.""" return DistributedSendRefHolder( - DistributedSend(sent_data, dest_rank, comm_tag, send_tags), - stapled_to, tags=ref_holder_tags) + send=DistributedSend(data=sent_data, dest_rank=dest_rank, + comm_tag=comm_tag, tags=send_tags), + passthrough_data=stapled_to, tags=ref_holder_tags) def make_distributed_recv(src_rank: int, comm_tag: CommTagType, @@ -249,7 +251,9 @@ def make_distributed_recv(src_rank: int, comm_tag: CommTagType, axes = _get_default_axes(len(shape)) dtype = np.dtype(dtype) - return DistributedRecv(src_rank, comm_tag, shape, dtype, tags=tags, axes=axes) + return DistributedRecv( + src_rank=src_rank, comm_tag=comm_tag, shape=shape, dtype=dtype, + tags=tags, axes=axes) # }}} diff --git a/pytato/function.py b/pytato/function.py index 4441c52afb885709b489f746e23f1b6c71126362..6e5d044d2731fd77abe246f6b67c824938780cf3 100644 --- a/pytato/function.py +++ b/pytato/function.py @@ -340,11 +340,12 @@ def trace_call(f: Callable[..., ReturnT], raise ValueError(f"Kw argument named '{kw}' not allowed.") # Get placeholders from the ``args``, ``kwargs``. - pl_args = tuple(Placeholder(f"in__pt_{iarg}", arg.shape, arg.dtype, + pl_args = tuple(Placeholder(name=f"in__pt_{iarg}", + shape=arg.shape, dtype=arg.dtype, axes=arg.axes, tags=arg.tags) for iarg, arg in enumerate(args)) - pl_kwargs = {kw: Placeholder(f"in_{kw}", arg.shape, - arg.dtype, axes=arg.axes, tags=arg.tags) + pl_kwargs = {kw: Placeholder(name=f"in_{kw}", shape=arg.shape, + dtype=arg.dtype, axes=arg.axes, tags=arg.tags) for kw, arg in kwargs.items()} # Pass the placeholders diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index ce033cde59caff12bcf5d8291f404f35edcc1899..b7ac02b2fcc8bf9fd7d9e86f8518fb05fe84ee61 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -1309,11 +1309,11 @@ class MPMSMaterializer(Mapper): children_rec = {bnd_name: self.rec(bnd) for bnd_name, bnd in sorted(expr.bindings.items())} - new_expr = IndexLambda(expr.expr, - expr.shape, - expr.dtype, - bindings={bnd_name: bnd.expr - for bnd_name, bnd in sorted(children_rec.items())}, + new_expr = IndexLambda(expr=expr.expr, + shape=expr.shape, + dtype=expr.dtype, + bindings=Map({bnd_name: bnd.expr + for bnd_name, bnd in sorted(children_rec.items())}), axes=expr.axes, var_to_reduction_descr=expr.var_to_reduction_descr, tags=expr.tags) diff --git a/pytato/transform/lower_to_index_lambda.py b/pytato/transform/lower_to_index_lambda.py index 698566ced71b86f0b0580d9dde14a1606d2f4975..4aa1d4ccacef5174f9bc5ae2a72856d9543087c3 100644 --- a/pytato/transform/lower_to_index_lambda.py +++ b/pytato/transform/lower_to_index_lambda.py @@ -96,8 +96,9 @@ class ToIndexLambdaMixin: return IndexLambda(expr=expr.expr, shape=self._rec_shape(expr.shape), dtype=expr.dtype, - bindings={name: self.rec(bnd) - for name, bnd in expr.bindings.items()}, + bindings=Map({name: self.rec(bnd) + for name, bnd + in sorted(expr.bindings.items())}), axes=expr.axes, var_to_reduction_descr=expr.var_to_reduction_descr, tags=expr.tags) @@ -132,7 +133,7 @@ class ToIndexLambdaMixin: shape=self._rec_shape(expr.shape), dtype=expr.dtype, axes=expr.axes, - bindings=bindings, + bindings=Map(bindings), var_to_reduction_descr=Map(), tags=expr.tags) @@ -179,7 +180,7 @@ class ToIndexLambdaMixin: return IndexLambda(expr=concat_expr, shape=self._rec_shape(expr.shape), dtype=expr.dtype, - bindings=bindings, + bindings=Map(bindings), axes=expr.axes, var_to_reduction_descr=Map(), tags=expr.tags) @@ -248,7 +249,7 @@ class ToIndexLambdaMixin: return IndexLambda(expr=inner_expr, shape=self._rec_shape(expr.shape), dtype=expr.dtype, - bindings=bindings, + bindings=Map(bindings), axes=expr.axes, var_to_reduction_descr=Map(var_to_redn_descr), tags=expr.tags) @@ -274,8 +275,8 @@ class ToIndexLambdaMixin: return IndexLambda(expr=index_expr, shape=self._rec_shape(expr.shape), dtype=expr.dtype, - bindings={name: self.rec(bnd) - for name, bnd in bindings.items()}, + bindings=Map({name: self.rec(bnd) + for name, bnd in bindings.items()}), axes=expr.axes, var_to_reduction_descr=Map(), tags=expr.tags) @@ -337,7 +338,7 @@ class ToIndexLambdaMixin: return IndexLambda(expr=prim.Subscript(prim.Variable(in_ary), tuple(indices)), - bindings=bindings, + bindings=Map(bindings), shape=self._rec_shape(expr.shape), dtype=expr.dtype, axes=expr.axes, @@ -399,7 +400,7 @@ class ToIndexLambdaMixin: return IndexLambda(expr=prim.Subscript(prim.Variable(in_ary), tuple(indices)), - bindings=bindings, + bindings=Map(bindings), shape=self._rec_shape(expr.shape), dtype=expr.dtype, axes=expr.axes, @@ -432,7 +433,7 @@ class ToIndexLambdaMixin: return IndexLambda(expr=prim.Subscript(prim.Variable(in_ary), tuple(indices)), - bindings=bindings, + bindings=Map(bindings), shape=self._rec_shape(expr.shape), dtype=expr.dtype, axes=expr.axes, @@ -446,7 +447,7 @@ class ToIndexLambdaMixin: return IndexLambda(expr=index_expr, shape=self._rec_shape(expr.shape), dtype=expr.dtype, - bindings={"_in0": self.rec(expr.array)}, + bindings=Map({"_in0": self.rec(expr.array)}), axes=expr.axes, var_to_reduction_descr=Map(), tags=expr.tags) @@ -461,7 +462,7 @@ class ToIndexLambdaMixin: return IndexLambda(expr=index_expr, shape=self._rec_shape(expr.shape), dtype=expr.dtype, - bindings={"_in0": self.rec(expr.array)}, + bindings=Map({"_in0": self.rec(expr.array)}), axes=expr.axes, var_to_reduction_descr=Map(), tags=expr.tags) diff --git a/pytato/utils.py b/pytato/utils.py index 920d537be48816c1a09e98ab69274857a8883ed2..58bcb5803fa2a92c85c0b358af5aa9839746e360 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -202,10 +202,10 @@ def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar, expr2 = update_bindings_and_get_broadcasted_expr(a2, "_in1", bindings, result_shape) - return IndexLambda(op(expr1, expr2), + return IndexLambda(expr=op(expr1, expr2), shape=result_shape, dtype=result_dtype, - bindings=bindings, + bindings=Map(bindings), tags=_get_default_tags(), var_to_reduction_descr=Map(), axes=_get_default_axes(len(result_shape)))