diff --git a/pytato/array.py b/pytato/array.py index 03cef8ff529badf00871c1cdb87ff85270a955ee..43a2905d6bf71c9dd8abf0a7be4ae7be75cc9148 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -166,7 +166,7 @@ import attrs from typing import ( Optional, Callable, ClassVar, Dict, Any, Mapping, Tuple, Union, Protocol, Sequence, cast, TYPE_CHECKING, List, Iterator, TypeVar, - FrozenSet) + FrozenSet, Collection) import numpy as np import pymbolic.primitives as prim @@ -2633,12 +2633,27 @@ def broadcast_to(array: Array, shape: ShapeType) -> Array: # {{{ squeeze -def squeeze(array: Array) -> Array: - """Remove single-dimensional entries from the shape of an array.""" +def squeeze(array: Array, axis: Optional[Collection[int]] = None) -> Array: + """ + Remove single-dimensional entries from the shape of an array. + + :arg axis: Subset of 1-long axes of *array* that must be removed. If *None* + all 1-long axes are removed. + """ from pytato.utils import are_shape_components_equal + one_d_axes = frozenset({idim + for idim in range(array.ndim) + if are_shape_components_equal(array.shape[idim], 1)}) + if axis is None: + axis = one_d_axes + else: + axis = frozenset(axis) + if not (axis <= one_d_axes): + raise ValueError("cannot squeeze an axis which is not 1-long") + return array[tuple( - 0 if are_shape_components_equal(s_i, 1) else slice(s_i) + 0 if i in axis else slice(s_i) for i, s_i in enumerate(array.shape))] # }}}