From 8e9dbb860c39dba3ef6b2b47b5600dcc290d8160 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 12 Feb 2025 09:48:02 -0600 Subject: [PATCH] Better typing for iname tag filtering --- loopy/check.py | 2 ++ loopy/kernel/__init__.py | 13 ++++++++----- loopy/kernel/data.py | 20 ++++++++++++++++---- 3 files changed, 26 insertions(+), 9 deletions(-) diff --git a/loopy/check.py b/loopy/check.py index 1a63c90b..0a901e3a 100644 --- a/loopy/check.py +++ b/loopy/check.py @@ -50,6 +50,7 @@ from loopy.kernel.data import ( AddressSpace, ArrayArg, ArrayDimImplementationTag, + AxisTag, InameImplementationTag, TemporaryVariable, auto, @@ -1426,6 +1427,7 @@ def _check_for_unused_hw_axes_in_kernel_chunk( iname, AutoLocalInameTagBase, max_num=1) if ltags: + tag: AxisTag tag, = ltags local_axes_used.add(tag.axis) elif gtags: diff --git a/loopy/kernel/__init__.py b/loopy/kernel/__init__.py index f487078c..af783bab 100644 --- a/loopy/kernel/__init__.py +++ b/loopy/kernel/__init__.py @@ -38,6 +38,7 @@ from functools import cached_property from sys import intern from typing import ( TYPE_CHECKING, + AbstractSet, Any, Callable, ClassVar, @@ -59,7 +60,7 @@ from pytools import ( memoize_method, natsorted, ) -from pytools.tag import Tag, Taggable +from pytools.tag import Tag, Taggable, TagT import loopy.codegen import loopy.kernel.data # to help out Sphinx @@ -539,14 +540,16 @@ class LoopKernel(Taggable): def iname_tags(self, iname): return self.inames[iname].tags - def iname_tags_of_type(self, iname, tag_type_or_types, - max_num=None, min_num=None): + def iname_tags_of_type( + self, iname: str, + tag_type_or_types: type[TagT] | tuple[type[TagT], ...], + max_num: int | None = None, + min_num: int | None = None + ) -> AbstractSet[TagT]: """Return a subset of *tags* that matches type *tag_type*. Raises exception if the number of tags found were greater than *max_num* or less than *min_num*. - :arg tags: An iterable of tags. - :arg tag_type_or_types: a subclass of :class:`loopy.kernel.data.InameTag`. :arg max_num: the maximum number of tags expected to be found. :arg min_num: the minimum number of tags expected to be found. """ diff --git a/loopy/kernel/data.py b/loopy/kernel/data.py index 3dd1cf82..474f444e 100644 --- a/loopy/kernel/data.py +++ b/loopy/kernel/data.py @@ -43,7 +43,7 @@ import numpy # FIXME: imported as numpy to allow sphinx to resolve things import numpy as np from pytools import ImmutableRecord -from pytools.tag import Tag, Taggable, UniqueTag as UniqueTagBase +from pytools.tag import Tag, Taggable, TagT, UniqueTag as UniqueTagBase from loopy.diagnostic import LoopyError from loopy.kernel.array import ArrayBase, ArrayDimImplementationTag @@ -64,7 +64,7 @@ from loopy.typing import Expression, ShapeType, auto if TYPE_CHECKING: - from collections.abc import Mapping + from collections.abc import Iterable, Mapping from pymbolic import ArithmeticExpression, Variable @@ -98,6 +98,10 @@ References .. class:: ToLoopyTypeConvertible See :class:`loopy.ToLoopyTypeConvertible`. + +.. class:: TagT + + A type variable with a lower bound of :class:`pytools.tag.Tag`. """ # This docstring is included in ref_internals. Do not include parts of the public @@ -143,7 +147,12 @@ def _names_from_dim_tags( # {{{ iname tags -def filter_iname_tags_by_type(tags, tag_type, max_num=None, min_num=None): +def filter_iname_tags_by_type( + tags: Iterable[Tag], + tag_type: type[TagT] | tuple[type[TagT], ...], + max_num: int | None = None, + min_num: int | None = None, + ) -> set[TagT]: """Return a subset of *tags* that matches type *tag_type*. Raises exception if the number of tags found were greater than *max_num* or less than *min_num*. @@ -154,7 +163,9 @@ def filter_iname_tags_by_type(tags, tag_type, max_num=None, min_num=None): :arg min_num: the minimum number of tags expected to be found. """ - result = {tag for tag in tags if isinstance(tag, tag_type)} + result: set[TagT] = cast( + "set[TagT]", + {tag for tag in tags if isinstance(tag, tag_type)}) def strify_tag_type(): if isinstance(tag_type, tuple): @@ -170,6 +181,7 @@ def filter_iname_tags_by_type(tags, tag_type, max_num=None, min_num=None): if len(result) < min_num: raise LoopyError("must have more than {} tags " "of type(s): {}".format(max_num, strify_tag_type())) + return result -- GitLab