diff --git a/pytato/__init__.py b/pytato/__init__.py index c45e6224fcccd814af606a8aa2076e9f5f20c5b1..582b8c61e7f45c971fb4dee1a0efa5d22b17d5b1 100644 --- a/pytato/__init__.py +++ b/pytato/__init__.py @@ -30,7 +30,7 @@ from pytato.array import ( make_dict_of_named_arrays, make_placeholder, make_size_param, make_data_wrapper, - matmul, roll, transpose, stack, reshape, + matmul, roll, transpose, stack, reshape, concatenate, ) from pytato.codegen import generate_loopy @@ -44,7 +44,7 @@ __all__ = ( "make_dict_of_named_arrays", "make_placeholder", "make_size_param", "make_data_wrapper", - "matmul", "roll", "transpose", "stack", "reshape", + "matmul", "roll", "transpose", "stack", "reshape", "concatenate", "generate_loopy", diff --git a/pytato/array.py b/pytato/array.py index 563d9f92fba4d69e1695230b14ce078f07ebd8bd..cae62639c72cdd59b4cd53b6826b227fc463a122 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -54,6 +54,7 @@ These functions generally follow the interface of the corresponding functions in .. autofunction:: roll .. autofunction:: transpose .. autofunction:: stack +.. autofunction:: concatenate Supporting Functionality ------------------------ @@ -83,6 +84,7 @@ Built-in Expression Nodes .. autoclass:: MatrixProduct .. autoclass:: LoopyFunction .. autoclass:: Stack +.. autoclass:: Concatenate .. autoclass:: AttributeLookup Index Remapping @@ -966,6 +968,51 @@ class Stack(Array): # }}} +# {{{ concatenate + +class Concatenate(Array): + """Join a sequence of arrays along an existing axis. + + .. attribute:: arrays + + An instance of :class:`tuple` of the arrays to join. The arrays must + have same shape except for the dimension corresponding to *axis*. + + .. attribute:: axis + + The axis along which the *arrays* are to be concatenated. + """ + + _fields = Array._fields + ("arrays", "axis") + _mapper_method = "map_concatenate" + + def __init__(self, + arrays: Tuple[Array, ...], + axis: int, + tags: Optional[TagsType] = None): + super().__init__(tags) + self.arrays = arrays + self.axis = axis + + @property + def namespace(self) -> Namespace: + return self.arrays[0].namespace + + @property + def dtype(self) -> np.dtype: + return np.result_type(*(arr.dtype for arr in self.arrays)) + + @property + def shape(self) -> ShapeType: + common_axis_len = sum(ary.shape[self.axis] for ary in self.arrays) + + return (self.arrays[0].shape[:self.axis] + + (common_axis_len,) + + self.arrays[0].shape[self.axis+1:]) + +# }}} + + # {{{ attribute lookup class AttributeLookup(Array): @@ -1419,6 +1466,41 @@ def stack(arrays: Sequence[Array], axis: int = 0) -> Array: return Stack(tuple(arrays), axis) +def concatenate(arrays: Sequence[Array], axis: int = 0) -> Array: + """Join a sequence of arrays along an existing axis. + + Example:: + + >>> arrays = [pt.zeros(3)] * 4 + >>> pt.concatenate(arrays, axis=0).shape + (12,) + + :param arrays: a finite sequence, each of whose elements is an + :class:`Array` . The arrays are of the same shape except along the + *axis* dimension. + :param axis: The axis along which the arrays will be concatenated. + """ + + if not arrays: + raise ValueError("need at least one array to stack") + + if not all(array.namespace is arrays[0].namespace for array in arrays): + raise ValueError("arrays must belong to the same namespace.") + + def shape_except_axis(ary: Array) -> Tuple[int, ...]: + return ary.shape[:axis] + ary.shape[axis+1:] + + for array in arrays[1:]: + if shape_except_axis(array) != shape_except_axis(arrays[0]): + raise ValueError("arrays must have the same shape expect along" + f" dimension #{axis}.") + + if not (0 <= axis <= arrays[0].ndim): + raise ValueError("invalid axis") + + return Concatenate(tuple(arrays), axis) + + def _make_slice(array: Array, starts: Sequence[int], stops: Sequence[int]) -> Array: """Extract a constant-sized slice from an array with constant offsets. diff --git a/pytato/codegen.py b/pytato/codegen.py index c5e389c33cc02ffeec51c26c6c07abd9090c69ab..b17cc185d49d1cec4c93579cc6013f9fd02e7609 100644 --- a/pytato/codegen.py +++ b/pytato/codegen.py @@ -26,7 +26,8 @@ import dataclasses from functools import partialmethod import re from typing import ( - Union, Optional, Mapping, Dict, Tuple, FrozenSet, Set, Callable, List) + Union, Optional, Mapping, Dict, Tuple, FrozenSet, Set, Callable, List, + Any) import islpy as isl import loopy as lp @@ -38,7 +39,7 @@ from pytato.array import ( Array, DictOfNamedArrays, ShapeType, IndexLambda, SizeParam, DataWrapper, InputArgumentBase, MatrixProduct, Roll, AxisPermutation, Slice, IndexRemappingBase, Stack, Placeholder, - Reshape, Namespace, DataInterface) + Reshape, Concatenate, Namespace, DataInterface) from pytato.program import BoundProgram from pytato.target import Target, PyOpenCLTarget import pytato.scalar_expr as scalar_expr @@ -110,6 +111,7 @@ class CodeGenPreprocessor(CopyMapper): :class:`~pytato.array.AxisPermutation` :class:`~pytato.array.IndexLambda` :class:`~pytato.array.Slice` :class:`~pytato.array.IndexLambda` :class:`~pytato.array.Reshape` :class:`~pytato.array.IndexLambda` + :class:`~pytato.array.Concatenate` :class:`~pytato.array.IndexLambda` ====================================== ===================================== """ @@ -165,6 +167,51 @@ class CodeGenPreprocessor(CopyMapper): dtype=expr.dtype, bindings=bindings) + def map_concatenate(self, expr: Concatenate) -> Array: + from pymbolic.primitives import If, Comparison, Subscript + + def get_subscript(array_index: int, offset: ScalarExpression) -> Subscript: + aggregate = var(f"_in{array_index}") + index = [var(f"_{i}") if i != expr.axis else (var(f"_{i}") - offset) + for i in range(len(expr.shape))] + return Subscript(aggregate, tuple(index)) + + lbounds: List[Any] = [0] + ubounds: List[Any] = [expr.arrays[0].shape[expr.axis]] + + for i, array in enumerate(expr.arrays[1:], start=1): + ubounds.append(ubounds[i-1]+array.shape[expr.axis]) + lbounds.append(ubounds[i-1]) + + # I = axis index + # + # => If(0<=_I < arrays[0].shape[axis], + # _in0[_0, _1, ..., _I, ...], + # If(arrays[0].shape[axis]<= _I < (arrays[1].shape[axis] + # +arrays[0].shape[axis]), + # _in1[_0, _1, ..., _I-arrays[0].shape[axis], ...], + # ... + # _inNm1[_0, _1, ...] ...)) + for i in range(len(expr.arrays) - 1, -1, -1): + lbound, ubound = lbounds[i], ubounds[i] + subarray_expr = get_subscript(i, lbound) + if i == len(expr.arrays) - 1: + stack_expr = subarray_expr + else: + stack_expr = If(Comparison(var(f"_{expr.axis}"), ">=", lbound) + and Comparison(var(f"_{expr.axis}"), "<", ubound), + subarray_expr, + stack_expr) + + bindings = {f"_in{i}": self.rec(array) + for i, array in enumerate(expr.arrays)} + + return IndexLambda(namespace=self.namespace, + expr=stack_expr, + shape=expr.shape, + dtype=expr.dtype, + bindings=bindings) + # {{{ index remapping (roll, axis permutation, slice) def handle_index_remapping(self, diff --git a/pytato/transform.py b/pytato/transform.py index 8d31a36fa2c457bc4fcaccf172ede520e2fef186..6485699c5073690cc36e9e466d49fdc9528a14c9 100644 --- a/pytato/transform.py +++ b/pytato/transform.py @@ -29,7 +29,7 @@ from typing import Any, Callable, Dict, FrozenSet from pytato.array import ( Array, IndexLambda, Namespace, Placeholder, MatrixProduct, Stack, Roll, AxisPermutation, Slice, DataWrapper, SizeParam, - DictOfNamedArrays, Reshape) + DictOfNamedArrays, Reshape, Concatenate) __doc__ = """ .. currentmodule:: pytato.transform @@ -202,6 +202,10 @@ class DependencyMapper(Mapper): def map_reshape(self, expr: Reshape) -> FrozenSet[Array]: return self.combine(frozenset([expr]), self.rec(expr.array)) + def map_concatenate(self, expr: Concatenate) -> FrozenSet[Array]: + return self.combine(frozenset([expr]), *(self.rec(ary) + for ary in expr.arrays)) + # }}} diff --git a/test/test_codegen.py b/test/test_codegen.py index 94b0a16f11a22da651a3528ee778290b1fa104f8..4b7f661004528b0d37f3dbf671fdc1835a517c12 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -38,6 +38,7 @@ import pytest # noqa import pytato as pt from pytato.array import Placeholder +from testlib import assert_allclose_to_numpy def test_basic_codegen(ctx_factory): @@ -159,15 +160,10 @@ def test_roll(ctx_factory, shift, axis): pt.make_size_param(namespace, "n") x = pt.make_placeholder(namespace, name="x", shape=("n", "n"), dtype=np.float) - prog = pt.generate_loopy( - pt.roll(x, shift=shift, axis=axis), - target=pt.PyOpenCLTarget(queue)) - x_in = np.arange(1., 10.).reshape(3, 3) - - _, (x_out,) = prog(x=x_in) - - assert (x_out == np.roll(x_in, shift=shift, axis=axis)).all() + assert_allclose_to_numpy(pt.roll(x, shift=shift, axis=axis), + queue, + {x: x_in}) @pytest.mark.parametrize("axes", ( @@ -187,12 +183,7 @@ def test_axis_permutation(ctx_factory, axes): namespace = pt.Namespace() x = pt.make_data_wrapper(namespace, x_in) - prog = pt.generate_loopy( - pt.transpose(x, axes), - target=pt.PyOpenCLTarget(queue)) - - _, (x_out,) = prog() - assert (x_out == np.transpose(x_in, axes)).all() + assert_allclose_to_numpy(pt.transpose(x, axes), queue) def test_transpose(ctx_factory): @@ -207,10 +198,7 @@ def test_transpose(ctx_factory): namespace = pt.Namespace() x = pt.make_data_wrapper(namespace, x_in) - prog = pt.generate_loopy(x.T, target=pt.PyOpenCLTarget(queue)) - - _, (x_out,) = prog() - assert (x_out == x_in.T).all() + assert_allclose_to_numpy(x.T, queue) # Doesn't include: ? (boolean), g (float128), G (complex256) @@ -397,12 +385,25 @@ def test_stack(ctx_factory, input_dims): y = pt.make_data_wrapper(namespace, y_in) for axis in range(0, 1 + input_dims): - prog = pt.generate_loopy( - pt.stack((x, y), axis=axis), - target=pt.PyOpenCLTarget(queue)) + assert_allclose_to_numpy(pt.stack((x, y), axis=axis), queue) + + +def test_concatenate(ctx_factory): + cl_ctx = ctx_factory() + queue = cl.CommandQueue(cl_ctx) + + from numpy.random import default_rng + rng = default_rng() + x0_in = rng.random(size=(3, 9, 3)) + x1_in = rng.random(size=(3, 11, 3)) + x2_in = rng.random(size=(3, 22, 3)) + + namespace = pt.Namespace() + x0 = pt.make_data_wrapper(namespace, x0_in) + x1 = pt.make_data_wrapper(namespace, x1_in) + x2 = pt.make_data_wrapper(namespace, x2_in) - _, (out,) = prog() - assert (out == np.stack((x_in, y_in), axis=axis)).all() + assert_allclose_to_numpy(pt.concatenate((x0, x1, x2), axis=1), queue) @pytest.mark.parametrize("oldshape", [(36,), @@ -424,15 +425,8 @@ def test_reshape(ctx_factory, oldshape, newshape): namespace = pt.Namespace() x = pt.make_data_wrapper(namespace, x_in) - expected_out = np.reshape(x_in, newshape=newshape) - prog = pt.generate_loopy( - pt.reshape(x, newshape=newshape), - target=pt.PyOpenCLTarget(queue)) - - _, (out,) = prog() - assert expected_out.shape == out.shape - assert (out == expected_out).all() + assert_allclose_to_numpy(pt.reshape(x, newshape=newshape), queue) def test_dict_of_named_array_codegen_avoids_recomputation(): diff --git a/test/test_pytato.py b/test/test_pytato.py index 191dc5b61d3ab69edeb0b76c9839fe42e3a8d69b..77050e101eb94e4fa5c4baf711ff526495209af7 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -136,6 +136,28 @@ def test_zero_length_arrays(): assert all(dom.is_empty() for dom in knl.domains if dom.total_dim() != 0) +def test_concatenate_input_validation(): + namespace = pt.Namespace() + + x = pt.make_placeholder(namespace, name="x", shape=(10, 10), dtype=np.float) + y = pt.make_placeholder(namespace, name="y", shape=(1, 10), dtype=np.float) + + assert pt.concatenate((x, x, x), axis=0).shape == (30, 10) + assert pt.concatenate((x, y), axis=0).shape == (11, 10) + + pt.concatenate((x,), axis=0) + pt.concatenate((x,), axis=1) + + with pytest.raises(ValueError): + pt.concatenate(()) + + with pytest.raises(ValueError): + pt.concatenate((x, y), axis=1) + + with pytest.raises(ValueError): + pt.concatenate((x, x), axis=3) + + def test_reshape_input_validation(): ns = pt.Namespace() diff --git a/test/testlib.py b/test/testlib.py new file mode 100644 index 0000000000000000000000000000000000000000..8d334531171efa388801aeeb454cd7a66b906e16 --- /dev/null +++ b/test/testlib.py @@ -0,0 +1,72 @@ +from typing import (Any, Dict) +import pyopencl as cl +import numpy +import pytato as pt +from pytato.transform import Mapper +from pytato.array import (Array, Placeholder, MatrixProduct, Stack, Roll, + AxisPermutation, Slice, DataWrapper, Reshape, + Concatenate, Namespace) + + +class NumpyBasedEvaluator(Mapper): + """ + Mapper to return the result according to an eager evaluation array package + *np*. + """ + def __init__(self, np: Any, namespace: Namespace, placeholders): + self.np = np + self.namespace = namespace + self.placeholders = placeholders + super().__init__() + + def map_placeholder(self, expr: Placeholder) -> Any: + return self.placeholders[expr] + + def map_data_wrapper(self, expr: DataWrapper) -> Any: + return self.namespace[expr.name].data + + def map_matrix_product(self, expr: MatrixProduct) -> Any: + return self.np.dot(self.rec(expr.x1), self.rec(expr.x2)) + + def map_stack(self, expr: Stack) -> Any: + arrays = [self.rec(array) for array in expr.arrays] + return self.np.stack(arrays, expr.axis) + + def map_roll(self, expr: Roll) -> Any: + return self.np.roll(self.rec(expr.array), expr.shift, expr.axis) + + def map_axis_permutation(self, expr: AxisPermutation) -> Any: + return self.np.transpose(self.rec(expr.array), expr.axes) + + def map_slice(self, expr: Slice) -> Any: + array = self.rec(expr.array) + return array[tuple(slice(start, stop) + for start, stop in zip(expr.starts, expr.stops))] + + def map_reshape(self, expr: Reshape) -> Any: + return self.np.reshape(self.rec(expr.array), expr.newshape, expr.order) + + def map_concatenate(self, expr: Concatenate) -> Any: + arrays = [self.rec(array) for array in expr.arrays] + return self.np.concatenate(arrays, expr.axis) + + +def assert_allclose_to_numpy(expr: Array, queue: cl.CommandQueue, + parameters: Dict[Placeholder, Any] = {}): + """ + Raises an :class:`AssertionError`, if there is a discrepancy between *expr* + evaluated lazily via :mod:`pytato` and eagerly via :mod:`numpy`. + + :arg queue: An instance of :class:`pyopencl.CommandQueue` to which the + generated kernel must be enqueued. + """ + np_result = NumpyBasedEvaluator(numpy, expr.namespace, parameters)(expr) + prog = pt.generate_loopy(expr, target=pt.PyOpenCLTarget(queue)) + + evt, (pt_result,) = prog(**{placeholder.name: data + for placeholder, data in parameters.items()}) + + assert pt_result.shape == np_result.shape + assert pt_result.dtype == np_result.dtype + + numpy.testing.assert_allclose(np_result, pt_result)