From adf1b10e48c8ea245ee53e92d82014f340bd8ed7 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Tue, 18 Mar 2025 15:15:46 -0500 Subject: [PATCH] Add typing.assert_tuple --- loopy/transform/data.py | 10 ++++++---- loopy/typing.py | 8 ++++++++ 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/loopy/transform/data.py b/loopy/transform/data.py index 2b0606ec..2662ca15 100644 --- a/loopy/transform/data.py +++ b/loopy/transform/data.py @@ -22,7 +22,6 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ - from dataclasses import dataclass, replace from typing import TYPE_CHECKING, cast from warnings import warn @@ -39,9 +38,12 @@ from loopy.kernel.data import AddressSpace, ImageArg, TemporaryVariable, auto from loopy.kernel.function_interface import CallableKernel, ScalarCallable from loopy.translation_unit import TranslationUnit, for_each_kernel from loopy.types import LoopyType +from loopy.typing import assert_tuple if TYPE_CHECKING: + from pymbolic import ArithmeticExpression + from loopy.typing import Expression @@ -990,7 +992,7 @@ def add_padding_to_avoid_bank_conflicts(kernel, device): @dataclass(frozen=True) class _BaseStorageInfo: name: str - next_offset: Expression + next_offset: ArithmeticExpression approx_nbytes: int | None = None @@ -1086,8 +1088,8 @@ def allocate_temporaries_for_base_storage(kernel: LoopKernel, else tv._base_storage_access_may_be_aliasing)) bs_tv = new_tvs[bsi.name] - assert isinstance(bs_tv.shape, tuple) - bs_size, = bs_tv.shape + bs_size: ArithmeticExpression + bs_size, = assert_tuple(bs_tv.shape) if aliased: new_bs_size = _sym_max(bs_size, ary_size) else: diff --git a/loopy/typing.py b/loopy/typing.py index 5316c356..5d578ac7 100644 --- a/loopy/typing.py +++ b/loopy/typing.py @@ -85,3 +85,11 @@ def integer_expr_or_err(expr: Expression) -> Integer | ExpressionNode: return expr else: raise ValueError(f"expected integer or expression, got {type(expr)}") + + +ElT = TypeVar("ElT") + + +def assert_tuple(obj: tuple[ElT, ...] | object) -> tuple[ElT, ...]: + assert isinstance(obj, tuple) + return obj -- GitLab