diff --git a/pytato/__init__.py b/pytato/__init__.py index 64d10257b721f74327bd28ad5c1596c6c500293c..c45e6224fcccd814af606a8aa2076e9f5f20c5b1 100644 --- a/pytato/__init__.py +++ b/pytato/__init__.py @@ -30,7 +30,7 @@ from pytato.array import ( make_dict_of_named_arrays, make_placeholder, make_size_param, make_data_wrapper, - matmul, roll, transpose, stack, + matmul, roll, transpose, stack, reshape, ) from pytato.codegen import generate_loopy @@ -44,7 +44,7 @@ __all__ = ( "make_dict_of_named_arrays", "make_placeholder", "make_size_param", "make_data_wrapper", - "matmul", "roll", "transpose", "stack", + "matmul", "roll", "transpose", "stack", "reshape", "generate_loopy", diff --git a/pytato/array.py b/pytato/array.py index a07155e41fc9188025b9b4fffa7961278ca9eb0e..563d9f92fba4d69e1695230b14ce078f07ebd8bd 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -153,7 +153,7 @@ from pytools import is_single_valued, memoize_method, UniqueNameGenerator from pytools.tag import Tag, UniqueTag, TagsType, tag_dataclass import pytato.scalar_expr as scalar_expr -from pytato.scalar_expr import ScalarExpression +from pytato.scalar_expr import ScalarExpression, IntegralScalarExpression # Get a type variable that represents the type of '...' @@ -247,11 +247,11 @@ class Namespace(Mapping[str, "Array"]): # {{{ shape -ShapeType = Tuple[ScalarExpression, ...] +ShapeType = Tuple[IntegralScalarExpression, ...] ConvertibleToShapeComponent = Union[int, prim.Expression, str] ConvertibleToShape = Union[ str, - ScalarExpression, + IntegralScalarExpression, Tuple[ConvertibleToShapeComponent, ...]] @@ -306,7 +306,8 @@ def normalize_shape( if isinstance(shape, (Number, prim.Expression)): shape = (shape,) - return tuple(normalize_shape_component(s) for s in shape) + # https://github.com/python/mypy/issues/3186 + return tuple(normalize_shape_component(s) for s in shape) # type: ignore # }}} @@ -1076,8 +1077,40 @@ class AxisPermutation(IndexRemappingBase): class Reshape(IndexRemappingBase): """ + Reshape an array. + + .. attribute:: array + + The array to be reshaped + + .. attribute:: newshape + + The output shape + + .. attribute:: order + + Output layout order, either ``C`` or ``F``. """ + _fields = Array._fields + ("array", "newshape", "order") + _mapper_method = "map_reshape" + + def __init__(self, + array: Array, + newshape: Tuple[int, ...], + order: str, + tags: Optional[TagsType] = None): + # FIXME: Get rid of this restriction + assert order == "C" + + super().__init__(array, tags) + self.newshape = newshape + self.order = order + + @property + def shape(self) -> Tuple[int, ...]: + return self.newshape + # }}} @@ -1431,6 +1464,60 @@ def _make_slice(array: Array, starts: Sequence[int], stops: Sequence[int]) -> Ar return Slice(array, tuple(starts), tuple(stops)) +def reshape(array: Array, newshape: Sequence[int], order: str = "C") -> Array: + """ + :param array: array to be reshaped + :param newshape: shape of the resulting array + :param order: ``"C"`` or ``"F"``. Layout order of the result array. Only + ``"C"`` allowed for now. + + .. note:: + + reshapes of arrays with symbolic shapes not yet implemented. + """ + from pytools import product + + if newshape.count(-1) > 1: + raise ValueError("can only specify one unknown dimension") + + if not all(isinstance(axis_len, int) for axis_len in array.shape): + raise ValueError("reshape of arrays with symbolic lengths not allowed") + + if order != "C": + raise NotImplementedError("Reshapes to a 'F'-ordered arrays") + + newshape_explicit = [] + + for new_axislen in newshape: + if not isinstance(new_axislen, int): + raise ValueError("Symbolic reshapes not allowed.") + + if not(new_axislen > 0 or new_axislen == -1): + raise ValueError("newshape should be either sequence of positive ints or" + " -1") + + # {{{ infer the axis length corresponding to axis marked "-1" + + if new_axislen == -1: + size_of_rest_of_newaxes = -1 * product(newshape) + + if array.size % size_of_rest_of_newaxes != 0: + raise ValueError(f"cannot reshape array of size {array.size}" + f" into ({size_of_rest_of_newaxes})") + + new_axislen = array.size // size_of_rest_of_newaxes + + # }}} + + newshape_explicit.append(new_axislen) + + if product(newshape_explicit) != array.size: + raise ValueError(f"cannot reshape array of size {array.size}" + f" into {newshape}") + + return Reshape(array, tuple(newshape_explicit), order) + + def make_dict_of_named_arrays(data: Dict[str, Array]) -> DictOfNamedArrays: """Make a :class:`DictOfNamedArrays` object and ensure that all arrays share the same namespace. diff --git a/pytato/codegen.py b/pytato/codegen.py index 6ad3958f02f4641a5a52a875bff03e419899e9e9..c5e389c33cc02ffeec51c26c6c07abd9090c69ab 100644 --- a/pytato/codegen.py +++ b/pytato/codegen.py @@ -38,7 +38,7 @@ from pytato.array import ( Array, DictOfNamedArrays, ShapeType, IndexLambda, SizeParam, DataWrapper, InputArgumentBase, MatrixProduct, Roll, AxisPermutation, Slice, IndexRemappingBase, Stack, Placeholder, - Namespace, DataInterface) + Reshape, Namespace, DataInterface) from pytato.program import BoundProgram from pytato.target import Target, PyOpenCLTarget import pytato.scalar_expr as scalar_expr @@ -109,6 +109,7 @@ class CodeGenPreprocessor(CopyMapper): :class:`~pytato.array.Roll` :class:`~pytato.array.IndexLambda` :class:`~pytato.array.AxisPermutation` :class:`~pytato.array.IndexLambda` :class:`~pytato.array.Slice` :class:`~pytato.array.IndexLambda` + :class:`~pytato.array.Reshape` :class:`~pytato.array.IndexLambda` ====================================== ===================================== """ @@ -198,11 +199,31 @@ class CodeGenPreprocessor(CopyMapper): def _indices_for_slice(self, expr: Slice) -> SymbolicIndex: return tuple(var(f"_{d}") + expr.starts[d] for d in range(expr.ndim)) + def _indices_for_reshape(self, expr: Reshape) -> SymbolicIndex: + newstrides = [1] # reshaped array strides + for axis_len in reversed(expr.shape[1:]): + newstrides.insert(0, newstrides[0]*axis_len) + + flattened_idx = sum(prim.Variable(f"_{i}")*stride + for i, stride in enumerate(newstrides)) + + oldstrides = [1] # input array strides + for axis_len in reversed(expr.array.shape[1:]): + oldstrides.insert(0, oldstrides[0]*axis_len) + + oldsizetills = [expr.array.shape[-1]] # input array size till for axes idx + for axis_len in reversed(expr.array.shape[:-1]): + oldsizetills.insert(0, oldsizetills[0]*axis_len) + + return tuple(((flattened_idx % sizetill) // stride) + for stride, sizetill in zip(oldstrides, oldsizetills)) + # https://github.com/python/mypy/issues/8619 map_roll = partialmethod(handle_index_remapping, _indices_for_roll) # type: ignore # noqa map_axis_permutation = ( partialmethod(handle_index_remapping, _indices_for_axis_permutation)) # type: ignore # noqa map_slice = partialmethod(handle_index_remapping, _indices_for_slice) # type: ignore # noqa + map_reshape = partialmethod(handle_index_remapping, _indices_for_reshape) # noqa # }}} diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index 0cbc33935004f0a07f34bb4ca51a249b5f8d739b..a0921917c037b053a991f8dfb5ff6217967e1312 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -58,6 +58,7 @@ Scalar Expressions # {{{ scalar expressions +IntegralScalarExpression = Union[int, prim.Expression] ScalarExpression = Union[Number, prim.Expression] diff --git a/pytato/transform.py b/pytato/transform.py index f9aa5f024fe3261813855dce95e3dc248ebe70cf..8d31a36fa2c457bc4fcaccf172ede520e2fef186 100644 --- a/pytato/transform.py +++ b/pytato/transform.py @@ -29,7 +29,7 @@ from typing import Any, Callable, Dict, FrozenSet from pytato.array import ( Array, IndexLambda, Namespace, Placeholder, MatrixProduct, Stack, Roll, AxisPermutation, Slice, DataWrapper, SizeParam, - DictOfNamedArrays) + DictOfNamedArrays, Reshape) __doc__ = """ .. currentmodule:: pytato.transform @@ -199,6 +199,9 @@ class DependencyMapper(Mapper): def map_slice(self, expr: Slice) -> FrozenSet[Array]: return self.combine(frozenset([expr]), self.rec(expr.array)) + def map_reshape(self, expr: Reshape) -> FrozenSet[Array]: + return self.combine(frozenset([expr]), self.rec(expr.array)) + # }}} diff --git a/test/test_codegen.py b/test/test_codegen.py index 028fecff9e433d15fa65d394275ceebd8fde07f3..94b0a16f11a22da651a3528ee778290b1fa104f8 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -405,6 +405,36 @@ def test_stack(ctx_factory, input_dims): assert (out == np.stack((x_in, y_in), axis=axis)).all() +@pytest.mark.parametrize("oldshape", [(36,), + (3, 3, 4), + (12, 3), + (2, 2, 3, 3, 1)]) +@pytest.mark.parametrize("newshape", [(-1,), + (-1, 6), + (4, 9), + (9, -1), + (36, -1)]) +def test_reshape(ctx_factory, oldshape, newshape): + cl_ctx = ctx_factory() + queue = cl.CommandQueue(cl_ctx) + + from numpy.random import default_rng + rng = default_rng() + x_in = rng.random(size=oldshape) + + namespace = pt.Namespace() + x = pt.make_data_wrapper(namespace, x_in) + expected_out = np.reshape(x_in, newshape=newshape) + + prog = pt.generate_loopy( + pt.reshape(x, newshape=newshape), + target=pt.PyOpenCLTarget(queue)) + + _, (out,) = prog() + assert expected_out.shape == out.shape + assert (out == expected_out).all() + + def test_dict_of_named_array_codegen_avoids_recomputation(): ns = pt.Namespace() x = pt.make_placeholder(ns, shape=(10, 4), dtype=float, name="x") diff --git a/test/test_pytato.py b/test/test_pytato.py index 1923807b0cd9835904a31fe0045153ab96a7a249..191dc5b61d3ab69edeb0b76c9839fe42e3a8d69b 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -136,6 +136,25 @@ def test_zero_length_arrays(): assert all(dom.is_empty() for dom in knl.domains if dom.total_dim() != 0) +def test_reshape_input_validation(): + ns = pt.Namespace() + + x = pt.make_placeholder(ns, shape=(3, 3, 4), dtype=np.float) + + assert pt.reshape(x, (-1,)).shape == (36,) + assert pt.reshape(x, (-1, 6)).shape == (6, 6) + assert pt.reshape(x, (4, -1)).shape == (4, 9) + assert pt.reshape(x, (36, -1)).shape == (36, 1) + + with pytest.raises(ValueError): + # 36 not a multiple of 25 + pt.reshape(x, (5, 5)) + + with pytest.raises(ValueError): + # 2 unknown dimensions + pt.reshape(x, (-1, -1, 3)) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])