From 29a294cf14168f8a930c28b2880c3978e7208376 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 24 Apr 2024 09:42:19 -0500 Subject: [PATCH] Add type aliases for CallablesTable --- loopy/codegen/__init__.py | 6 ++++-- loopy/kernel/function_interface.py | 13 +++++++++++-- loopy/library/function.py | 7 +++---- loopy/schedule/__init__.py | 14 +++++++++----- loopy/transform/precompute.py | 10 +++++----- loopy/transform/realize_reduction.py | 4 ++-- loopy/translation_unit.py | 4 +++- 7 files changed, 37 insertions(+), 21 deletions(-) diff --git a/loopy/codegen/__init__.py b/loopy/codegen/__init__.py index 32f89992a..68c41336c 100644 --- a/loopy/codegen/__init__.py +++ b/loopy/codegen/__init__.py @@ -26,6 +26,9 @@ from typing import (Set, Mapping, Sequence, Any, FrozenSet, Union, Optional, Tuple, TYPE_CHECKING) from dataclasses import dataclass, replace import logging + +from loopy.codegen.result import CodeGenerationResult +from loopy.translation_unit import CallablesTable, TranslationUnit logger = logging.getLogger(__name__) import islpy as isl @@ -40,7 +43,6 @@ from loopy.types import LoopyType from loopy.typing import ExpressionT from loopy.kernel import LoopKernel from loopy.target import TargetBase -from loopy.kernel.function_interface import InKernelCallable from loopy.symbolic import CombineMapper @@ -192,7 +194,7 @@ class CodeGenerationState: var_subst_map: Map[str, ExpressionT] allow_complex: bool - callables_table: Mapping[str, InKernelCallable] + callables_table: CallablesTable is_entrypoint: bool var_name_generator: UniqueNameGenerator is_generating_device_code: bool diff --git a/loopy/kernel/function_interface.py b/loopy/kernel/function_interface.py index eb373a12d..9a21d70f0 100644 --- a/loopy/kernel/function_interface.py +++ b/loopy/kernel/function_interface.py @@ -1,3 +1,5 @@ +from __future__ import annotations + __copyright__ = "Copyright (C) 2018 Andreas Kloeckner, Kaushik Kulkarni" __license__ = """ @@ -20,7 +22,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from typing import ClassVar, Tuple +from typing import ClassVar, FrozenSet, Tuple, TYPE_CHECKING from pytools import ImmutableRecord from loopy.diagnostic import LoopyError @@ -31,6 +33,9 @@ from loopy.kernel.array import ArrayBase from loopy.kernel.data import ValueArg, ArrayArg from loopy.symbolic import DependencyMapper, WalkMapper +if TYPE_CHECKING: + from loopy.translation_unit import CallablesTable, FunctionIdT + __doc__ = """ .. currentmodule:: loopy.kernel.function_interface @@ -453,7 +458,11 @@ class InKernelCallable(ImmutableRecord): """ raise NotImplementedError() - def get_called_callables(self, callables_table, recursive=True): + def get_called_callables( + self, + callables_table: CallablesTable, + recursive: bool = True + ) -> FrozenSet[FunctionIdT]: """ Returns a :class:`frozenset` of callable ids called by *self* that are resolved via *callables_table*. diff --git a/loopy/library/function.py b/loopy/library/function.py index 9c465653f..a42359c03 100644 --- a/loopy/library/function.py +++ b/loopy/library/function.py @@ -22,6 +22,7 @@ THE SOFTWARE. from loopy.kernel.function_interface import ScalarCallable from loopy.diagnostic import LoopyError +from loopy.translation_unit import CallablesTable from loopy.types import NumpyType import numpy as np @@ -105,7 +106,7 @@ class IndexOfCallable(ScalarCallable): target), True -def get_loopy_callables(): +def get_loopy_callables() -> CallablesTable: """ Returns a mapping from function ids to corresponding :class:`loopy.kernel.function_interface.InKernelCallable` for functions @@ -116,13 +117,11 @@ def get_loopy_callables(): - callables that have a predefined meaning in :mod:`loo.py` like ``make_tuple``, ``index_of``, ``indexof_vec``. """ - known_callables = { + return { "make_tuple": MakeTupleCallable(name="make_tuple"), "indexof": IndexOfCallable(name="indexof"), "indexof_vec": IndexOfCallable(name="indexof_vec"), } - return known_callables - # vim: foldmethod=marker diff --git a/loopy/schedule/__init__.py b/loopy/schedule/__init__.py index f80aa6f37..5be848e02 100644 --- a/loopy/schedule/__init__.py +++ b/loopy/schedule/__init__.py @@ -42,7 +42,6 @@ from loopy.version import DATA_MODEL_VERSION if TYPE_CHECKING: from loopy.kernel import LoopKernel - from loopy.kernel.function_interface import InKernelCallable logger = logging.getLogger(__name__) @@ -2000,7 +1999,7 @@ class MinRecursionLimitForScheduling(MinRecursionLimit): def generate_loop_schedules( kernel: LoopKernel, - callables_table: Mapping[str, InKernelCallable], + callables_table: CallablesTable, debug_args: Optional[Dict[str, Any]] = None) -> Iterator[LoopKernel]: """ .. warning:: @@ -2022,7 +2021,7 @@ def generate_loop_schedules( def _generate_loop_schedules_inner( kernel: LoopKernel, - callables_table: Mapping[str, InKernelCallable], + callables_table: CallablesTable, debug_args: Optional[Dict[str, Any]]) -> Iterator[LoopKernel]: if debug_args is None: debug_args = {} @@ -2206,7 +2205,10 @@ schedule_cache = WriteOncePersistentDict( caches.append(schedule_cache) -def _get_one_linearized_kernel_inner(kernel, callables_table): +def _get_one_linearized_kernel_inner( + kernel: LoopKernel, + callables_table: CallablesTable + ) -> LoopKernel: # This helper function exists to ensure that the generator chain is fully # out of scope after the function returns. This allows it to be # garbage-collected in the exit handler of the @@ -2219,7 +2221,9 @@ def _get_one_linearized_kernel_inner(kernel, callables_table): return next(iter(generate_loop_schedules(kernel, callables_table))) -def get_one_linearized_kernel(kernel, callables_table): +def get_one_linearized_kernel( + kernel: LoopKernel, + callables_table: CallablesTable) -> LoopKernel: from loopy import CACHING_ENABLED # must include *callables_table* within the cache key as the preschedule diff --git a/loopy/transform/precompute.py b/loopy/transform/precompute.py index a3f0a5dd5..6ae5139f5 100644 --- a/loopy/transform/precompute.py +++ b/loopy/transform/precompute.py @@ -22,7 +22,7 @@ THE SOFTWARE. from dataclasses import dataclass -from typing import FrozenSet, List, Mapping, Optional, Sequence, Type, Union +from typing import FrozenSet, List, Optional, Sequence, Type, Union from immutables import Map import islpy as isl from pytools.tag import Tag @@ -34,10 +34,9 @@ from loopy.symbolic import (get_dependencies, SubstitutionRuleMappingContext, CombineMapper) from loopy.diagnostic import LoopyError from pymbolic.mapper.substitutor import make_subst_func -from loopy.translation_unit import TranslationUnit +from loopy.translation_unit import CallablesTable, TranslationUnit from loopy.kernel.instruction import InstructionBase, MultiAssignmentBase -from loopy.kernel.function_interface import (CallableKernel, InKernelCallable, - ScalarCallable) +from loopy.kernel.function_interface import CallableKernel, ScalarCallable from loopy.kernel.tools import (kernel_has_global_barriers, find_most_recent_global_barrier) from loopy.kernel.data import AddressSpace @@ -359,7 +358,8 @@ class RuleInvocationReplacer(RuleAwareIdentityMapper): def precompute_for_single_kernel( kernel: LoopKernel, - callables_table: Mapping[str, InKernelCallable], subst_use, + callables_table: CallablesTable, + subst_use, sweep_inames=None, within: ToStackMatchCovertible = None, *, diff --git a/loopy/transform/realize_reduction.py b/loopy/transform/realize_reduction.py index c211ab18e..7ef773408 100644 --- a/loopy/transform/realize_reduction.py +++ b/loopy/transform/realize_reduction.py @@ -40,7 +40,7 @@ from immutables import Map from loopy.kernel.data import make_assignment from loopy.symbolic import ReductionCallbackMapper -from loopy.translation_unit import TranslationUnit +from loopy.translation_unit import ConcreteCallablesTable, TranslationUnit from loopy.kernel.function_interface import CallableKernel from loopy.kernel.data import TemporaryVariable, AddressSpace from loopy.kernel.instruction import ( @@ -90,7 +90,7 @@ class _ReductionRealizationContext: domains: List[isl.BasicSet] additional_iname_tags: Dict[str, Sequence[Tag]] # list only to facilitate mutation - boxed_callables_table: List[Map] + boxed_callables_table: List[ConcreteCallablesTable] # FIXME: This is a broken-by-design concept. Local-parallel scans emit a # reduction internally. This serves to avoid force_scan acting on that diff --git a/loopy/translation_unit.py b/loopy/translation_unit.py index 39fdb2275..364191a30 100644 --- a/loopy/translation_unit.py +++ b/loopy/translation_unit.py @@ -137,6 +137,8 @@ class CallableResolver(RuleAwareIdentityMapper): # {{{ translation unit FunctionIdT = Union[str, ReductionOpFunction] +ConcreteCallablesTable = Map[FunctionIdT, InKernelCallable] +CallablesTable = Mapping[FunctionIdT, InKernelCallable] @dataclass(frozen=True) @@ -191,7 +193,7 @@ class TranslationUnit: """ - callables_table: Map[FunctionIdT, CallableKernel] + callables_table: ConcreteCallablesTable target: TargetBase entrypoints: FrozenSet[str] -- GitLab