From 6e423050a6c546ebcc194450fc7a793dcb4254ec Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Mon, 17 Apr 2023 20:49:08 -0500 Subject: [PATCH] Add `axis` argument to pytato.squeeze --- pytato/array.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 03cef8f..43a2905 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -166,7 +166,7 @@ import attrs from typing import ( Optional, Callable, ClassVar, Dict, Any, Mapping, Tuple, Union, Protocol, Sequence, cast, TYPE_CHECKING, List, Iterator, TypeVar, - FrozenSet) + FrozenSet, Collection) import numpy as np import pymbolic.primitives as prim @@ -2633,12 +2633,27 @@ def broadcast_to(array: Array, shape: ShapeType) -> Array: # {{{ squeeze -def squeeze(array: Array) -> Array: - """Remove single-dimensional entries from the shape of an array.""" +def squeeze(array: Array, axis: Optional[Collection[int]] = None) -> Array: + """ + Remove single-dimensional entries from the shape of an array. + + :arg axis: Subset of 1-long axes of *array* that must be removed. If *None* + all 1-long axes are removed. + """ from pytato.utils import are_shape_components_equal + one_d_axes = frozenset({idim + for idim in range(array.ndim) + if are_shape_components_equal(array.shape[idim], 1)}) + if axis is None: + axis = one_d_axes + else: + axis = frozenset(axis) + if not (axis <= one_d_axes): + raise ValueError("cannot squeeze an axis which is not 1-long") + return array[tuple( - 0 if are_shape_components_equal(s_i, 1) else slice(s_i) + 0 if i in axis else slice(s_i) for i, s_i in enumerate(array.shape))] # }}} -- GitLab