Skip to content
Commits on Source (4)
......@@ -26,8 +26,8 @@ call site. For example, a call to ``sin(x)`` in :mod:`loopy` is type-generic to
begin with, but it is later specialized to either ``sinf``, ``sin`` or ``sinl``
depending on the type of its argument ``x``. A callable's behavior during type
or shape specialization is encoded via
:meth:`~loopy.kernel.function_interface.InKernelCallable.with_types` and
:meth:`~loopy.kernel.function_interface.InKernelCallable.with_descrs`.
:meth:`~loopy.InKernelCallable.with_types` and
:meth:`~loopy.InKernelCallable.with_descrs`.
Registering callables
......
.. currentmodule:: loopy
TranslationUnit
===============
.. autoclass:: TranslationUnit
Reference
---------
Translation Units
=================
.. automodule:: loopy.translation_unit
......@@ -45,7 +45,7 @@ from loopy.kernel.data import (
SubstitutionRule,
CallMangleInfo)
from loopy.kernel.function_interface import (
CallableKernel, ScalarCallable)
InKernelCallable, CallableKernel, ScalarCallable)
from loopy.translation_unit import (
TranslationUnit, make_program)
......@@ -186,7 +186,7 @@ __all__ = [
"CallInstruction", "CInstruction", "NoOpInstruction",
"BarrierInstruction",
"ScalarCallable", "CallableKernel",
"InKernelCallable", "ScalarCallable", "CallableKernel",
"TranslationUnit", "make_program",
......
......@@ -26,6 +26,9 @@ from typing import (Set, Mapping, Sequence, Any, FrozenSet, Union,
Optional, Tuple, TYPE_CHECKING)
from dataclasses import dataclass, replace
import logging
from loopy.codegen.result import CodeGenerationResult
from loopy.translation_unit import CallablesTable, TranslationUnit
logger = logging.getLogger(__name__)
import islpy as isl
......@@ -40,7 +43,6 @@ from loopy.types import LoopyType
from loopy.typing import ExpressionT
from loopy.kernel import LoopKernel
from loopy.target import TargetBase
from loopy.kernel.function_interface import InKernelCallable
from loopy.symbolic import CombineMapper
......@@ -192,7 +194,7 @@ class CodeGenerationState:
var_subst_map: Map[str, ExpressionT]
allow_complex: bool
callables_table: Mapping[str, InKernelCallable]
callables_table: CallablesTable
is_entrypoint: bool
var_name_generator: UniqueNameGenerator
is_generating_device_code: bool
......@@ -310,7 +312,10 @@ class CodeGenerationState:
# }}}
code_gen_cache = WriteOncePersistentDict(
code_gen_cache: WriteOncePersistentDict[
TranslationUnit,
CodeGenerationResult
] = WriteOncePersistentDict(
"loopy-code-gen-cache-v3-"+DATA_MODEL_VERSION,
key_builder=LoopyKeyBuilder())
......@@ -561,13 +566,7 @@ class TranslationUnitCodeGenerationResult:
self.host_programs.values()))
def generate_code_v2(program):
"""
Returns an instance of :class:`CodeGenerationResult`.
:param program: An instance of :class:`loopy.TranslationUnit`.
"""
def generate_code_v2(t_unit: TranslationUnit) -> CodeGenerationResult:
from loopy.kernel import LoopKernel
from loopy.translation_unit import make_program
......@@ -576,46 +575,46 @@ def generate_code_v2(program):
from loopy import CACHING_ENABLED
if CACHING_ENABLED:
input_program = program
input_t_unit = t_unit
try:
result = code_gen_cache[input_program]
logger.debug(f"TranslationUnit with entrypoints {program.entrypoints}:"
result = code_gen_cache[input_t_unit]
logger.debug(f"TranslationUnit with entrypoints {t_unit.entrypoints}:"
" code generation cache hit")
return result
except KeyError:
logger.debug(f"TranslationUnit with entrypoints {program.entrypoints}:"
logger.debug(f"TranslationUnit with entrypoints {t_unit.entrypoints}:"
" code generation cache miss")
# }}}
if isinstance(program, LoopKernel):
program = make_program(program)
if isinstance(t_unit, LoopKernel):
t_unit = make_program(t_unit)
from loopy.kernel import KernelState
if program.state < KernelState.PREPROCESSED:
if t_unit.state < KernelState.PREPROCESSED:
# Note that the callable kernels cannot be preprocessed separately,
# since the preprocessing of each one now depends on the others.
# So if any one of the callable kernels is not preprocessed,
# we have to do the preprocessing of every other kernel as well.
from loopy.preprocess import preprocess_program
program = preprocess_program(program)
t_unit = preprocess_program(t_unit)
from loopy.type_inference import infer_unknown_types
program = infer_unknown_types(program, expect_completion=True)
t_unit = infer_unknown_types(t_unit, expect_completion=True)
if program.state < KernelState.LINEARIZED:
if t_unit.state < KernelState.LINEARIZED:
from loopy.schedule import linearize
program = linearize(program)
t_unit = linearize(t_unit)
# Why diverge? Generated code for a non-entrypoint kernel and an entrypoint
# kernel isn't the same for a general loopy target. For example in OpenCL, a
# kernel callable from host and the one supposed to be callable from device
# have different function signatures. To generate correct code, each
# callable should be exclusively an entrypoint or a non-entrypoint kernel.
program = diverge_callee_entrypoints(program)
t_unit = diverge_callee_entrypoints(t_unit)
from loopy.check import pre_codegen_checks
pre_codegen_checks(program)
pre_codegen_checks(t_unit)
host_programs = {}
device_programs = []
......@@ -624,13 +623,13 @@ def generate_code_v2(program):
# {{{ collect host/device programs
for func_id in sorted(key for key, val in program.callables_table.items()
for func_id in sorted(key for key, val in t_unit.callables_table.items()
if isinstance(val, CallableKernel)):
cgr = generate_code_for_a_single_kernel(program[func_id],
program.callables_table,
program.target,
func_id in program.entrypoints)
if func_id in program.entrypoints:
cgr = generate_code_for_a_single_kernel(t_unit[func_id],
t_unit.callables_table,
t_unit.target,
func_id in t_unit.entrypoints)
if func_id in t_unit.entrypoints:
host_programs[func_id] = cgr.host_program
else:
assert len(cgr.device_programs) == 1
......@@ -643,14 +642,14 @@ def generate_code_v2(program):
# {{{ collect preambles
for clbl in program.callables_table.values():
device_preambles.extend(list(clbl.generate_preambles(program.target)))
for clbl in t_unit.callables_table.values():
device_preambles.extend(list(clbl.generate_preambles(t_unit.target)))
# }}}
# adding the callee fdecls to the device_programs
device_programs = ([device_programs[0].copy(
ast=program.target.get_device_ast_builder().ast_module.Collection(
ast=t_unit.target.get_device_ast_builder().ast_module.Collection(
callee_fdecls+[device_programs[0].ast]))] +
device_programs[1:])
cgr = TranslationUnitCodeGenerationResult(
......@@ -659,7 +658,7 @@ def generate_code_v2(program):
device_preambles=device_preambles)
if CACHING_ENABLED:
code_gen_cache.store_if_not_present(input_program, cgr)
code_gen_cache.store_if_not_present(input_t_unit, cgr)
return cgr
......
from __future__ import annotations
__copyright__ = "Copyright (C) 2018 Andreas Kloeckner, Kaushik Kulkarni"
__license__ = """
......@@ -20,7 +22,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from typing import ClassVar, Tuple
from typing import ClassVar, FrozenSet, Tuple, TYPE_CHECKING
from pytools import ImmutableRecord
from loopy.diagnostic import LoopyError
......@@ -31,6 +33,9 @@ from loopy.kernel.array import ArrayBase
from loopy.kernel.data import ValueArg, ArrayArg
from loopy.symbolic import DependencyMapper, WalkMapper
if TYPE_CHECKING:
from loopy.translation_unit import CallablesTable, FunctionIdT
__doc__ = """
.. currentmodule:: loopy.kernel.function_interface
......@@ -38,6 +43,8 @@ __doc__ = """
.. autoclass:: ArrayArgDescriptor
.. currentmodule:: loopy
.. autoclass:: InKernelCallable
.. autoclass:: CallableKernel
......@@ -64,7 +71,7 @@ class ArrayArgDescriptor(ImmutableRecord):
"""
Records information about an array argument to an in-kernel callable. To be
passed to and returned from
:meth:`InKernelCallable.with_descrs`, used for
:meth:`~loopy.InKernelCallable.with_descrs`, used for
matching shape and address space of caller and callee kernels.
.. attribute:: shape
......@@ -367,9 +374,10 @@ class InKernelCallable(ImmutableRecord):
def with_descrs(self, arg_id_to_descr, clbl_inf_ctx):
"""
:arg arg_id_to_descr: a mapping from argument identifiers (integers for
positional arguments) to instances of :class:`ArrayArgDescriptor`
or :class:`ValueArgDescriptor`. Unspecified/unknown descriptors are
not represented in *arg_id_to_type*.
positional arguments) to instances of
:class:`~loopy.kernel.function_interface.ArrayArgDescriptor`
or :class:`~loopy.kernel.function_interface.ValueArgDescriptor`.
Unspecified/unknown descriptors are not represented in *arg_id_to_descr*.
Return values are denoted by negative integers, with the first
returned value identified as *-1*.
......@@ -453,7 +461,11 @@ class InKernelCallable(ImmutableRecord):
"""
raise NotImplementedError()
def get_called_callables(self, callables_table, recursive=True):
def get_called_callables(
self,
callables_table: CallablesTable,
recursive: bool = True
) -> FrozenSet[FunctionIdT]:
"""
Returns a :class:`frozenset` of callable ids called by *self* that are
resolved via *callables_table*.
......
......@@ -22,6 +22,7 @@ THE SOFTWARE.
from loopy.kernel.function_interface import ScalarCallable
from loopy.diagnostic import LoopyError
from loopy.translation_unit import CallablesTable
from loopy.types import NumpyType
import numpy as np
......@@ -105,7 +106,7 @@ class IndexOfCallable(ScalarCallable):
target), True
def get_loopy_callables():
def get_loopy_callables() -> CallablesTable:
"""
Returns a mapping from function ids to corresponding
:class:`loopy.InKernelCallable` for functions
......@@ -116,13 +117,11 @@ def get_loopy_callables():
- callables that have a predefined meaning in :mod:`loo.py` like
``make_tuple``, ``index_of``, ``indexof_vec``.
"""
known_callables = {
return {
"make_tuple": MakeTupleCallable(name="make_tuple"),
"indexof": IndexOfCallable(name="indexof"),
"indexof_vec": IndexOfCallable(name="indexof_vec"),
}
return known_callables
# vim: foldmethod=marker
......@@ -42,7 +42,8 @@ from loopy.version import DATA_MODEL_VERSION
if TYPE_CHECKING:
from loopy.kernel import LoopKernel
from loopy.kernel.function_interface import InKernelCallable
from loopy.translation_unit import CallablesTable, TranslationUnit
logger = logging.getLogger(__name__)
......@@ -2000,7 +2001,7 @@ class MinRecursionLimitForScheduling(MinRecursionLimit):
def generate_loop_schedules(
kernel: LoopKernel,
callables_table: Mapping[str, InKernelCallable],
callables_table: CallablesTable,
debug_args: Optional[Dict[str, Any]] = None) -> Iterator[LoopKernel]:
"""
.. warning::
......@@ -2022,7 +2023,7 @@ def generate_loop_schedules(
def _generate_loop_schedules_inner(
kernel: LoopKernel,
callables_table: Mapping[str, InKernelCallable],
callables_table: CallablesTable,
debug_args: Optional[Dict[str, Any]]) -> Iterator[LoopKernel]:
if debug_args is None:
debug_args = {}
......@@ -2198,7 +2199,10 @@ def _generate_loop_schedules_inner(
# }}}
schedule_cache = WriteOncePersistentDict(
schedule_cache: WriteOncePersistentDict[
Tuple[LoopKernel, CallablesTable],
LoopKernel
] = WriteOncePersistentDict(
"loopy-schedule-cache-v4-"+DATA_MODEL_VERSION,
key_builder=LoopyKeyBuilder())
......@@ -2206,7 +2210,10 @@ schedule_cache = WriteOncePersistentDict(
caches.append(schedule_cache)
def _get_one_linearized_kernel_inner(kernel, callables_table):
def _get_one_linearized_kernel_inner(
kernel: LoopKernel,
callables_table: CallablesTable
) -> LoopKernel:
# This helper function exists to ensure that the generator chain is fully
# out of scope after the function returns. This allows it to be
# garbage-collected in the exit handler of the
......@@ -2219,7 +2226,9 @@ def _get_one_linearized_kernel_inner(kernel, callables_table):
return next(iter(generate_loop_schedules(kernel, callables_table)))
def get_one_linearized_kernel(kernel, callables_table):
def get_one_linearized_kernel(
kernel: LoopKernel,
callables_table: CallablesTable) -> LoopKernel:
from loopy import CACHING_ENABLED
# must include *callables_table* within the cache key as the preschedule
......@@ -2257,7 +2266,7 @@ def get_one_scheduled_kernel(kernel, callables_table):
return get_one_linearized_kernel(kernel, callables_table)
def linearize(t_unit):
def linearize(t_unit: TranslationUnit) -> TranslationUnit:
from loopy.kernel.function_interface import (CallableKernel,
ScalarCallable)
from loopy.check import pre_schedule_checks
......
......@@ -21,7 +21,7 @@ THE SOFTWARE.
"""
from typing import (Callable, Tuple, Union, Set, FrozenSet, List, Dict,
from typing import (Callable, Mapping, Tuple, Union, Set, FrozenSet, List, Dict,
Optional, Sequence, Any)
from dataclasses import dataclass
......@@ -721,7 +721,10 @@ class ExecutionWrapperGeneratorBase(ABC):
# }}}
typed_and_scheduled_cache = WriteOncePersistentDict(
typed_and_scheduled_cache: WriteOncePersistentDict[
Tuple[str, TranslationUnit, Optional[Mapping[str, LoopyType]]],
TranslationUnit
] = WriteOncePersistentDict(
"loopy-typed-and-scheduled-cache-v1-"+DATA_MODEL_VERSION,
key_builder=LoopyKeyBuilder())
......@@ -729,7 +732,10 @@ typed_and_scheduled_cache = WriteOncePersistentDict(
caches.append(typed_and_scheduled_cache)
invoker_cache = WriteOncePersistentDict(
invoker_cache: WriteOncePersistentDict[
Tuple[str, TranslationUnit, str],
str
] = WriteOncePersistentDict(
"loopy-invoker-cache-v10-"+DATA_MODEL_VERSION,
key_builder=LoopyKeyBuilder())
......@@ -848,12 +854,12 @@ class ExecutorBase:
logger.debug("%s: typed-and-scheduled cache miss" %
self.t_unit.entrypoints)
kernel = self.get_typed_and_scheduled_translation_unit_uncached(arg_to_dtype)
t_unit = self.get_typed_and_scheduled_translation_unit_uncached(arg_to_dtype)
if CACHING_ENABLED:
typed_and_scheduled_cache.store_if_not_present(cache_key, kernel)
typed_and_scheduled_cache.store_if_not_present(cache_key, t_unit)
return kernel
return t_unit
def arg_to_dtype(self, kwargs) -> Optional[Map[str, LoopyType]]:
if not self.has_runtime_typed_args:
......
......@@ -22,7 +22,7 @@ THE SOFTWARE.
from dataclasses import dataclass
from typing import FrozenSet, List, Mapping, Optional, Sequence, Type, Union
from typing import FrozenSet, List, Optional, Sequence, Type, Union
from immutables import Map
import islpy as isl
from pytools.tag import Tag
......@@ -34,10 +34,9 @@ from loopy.symbolic import (get_dependencies,
SubstitutionRuleMappingContext, CombineMapper)
from loopy.diagnostic import LoopyError
from pymbolic.mapper.substitutor import make_subst_func
from loopy.translation_unit import TranslationUnit
from loopy.translation_unit import CallablesTable, TranslationUnit
from loopy.kernel.instruction import InstructionBase, MultiAssignmentBase
from loopy.kernel.function_interface import (CallableKernel, InKernelCallable,
ScalarCallable)
from loopy.kernel.function_interface import CallableKernel, ScalarCallable
from loopy.kernel.tools import (kernel_has_global_barriers,
find_most_recent_global_barrier)
from loopy.kernel.data import AddressSpace
......@@ -359,7 +358,8 @@ class RuleInvocationReplacer(RuleAwareIdentityMapper):
def precompute_for_single_kernel(
kernel: LoopKernel,
callables_table: Mapping[str, InKernelCallable], subst_use,
callables_table: CallablesTable,
subst_use,
sweep_inames=None,
within: ToStackMatchCovertible = None,
*,
......
......@@ -40,7 +40,7 @@ from immutables import Map
from loopy.kernel.data import make_assignment
from loopy.symbolic import ReductionCallbackMapper
from loopy.translation_unit import TranslationUnit
from loopy.translation_unit import ConcreteCallablesTable, TranslationUnit
from loopy.kernel.function_interface import CallableKernel
from loopy.kernel.data import TemporaryVariable, AddressSpace
from loopy.kernel.instruction import (
......@@ -90,7 +90,7 @@ class _ReductionRealizationContext:
domains: List[isl.BasicSet]
additional_iname_tags: Dict[str, Sequence[Tag]]
# list only to facilitate mutation
boxed_callables_table: List[Map]
boxed_callables_table: List[ConcreteCallablesTable]
# FIXME: This is a broken-by-design concept. Local-parallel scans emit a
# reduction internally. This serves to avoid force_scan acting on that
......
......@@ -48,6 +48,21 @@ if TYPE_CHECKING:
__doc__ = """
.. class:: FunctionIdT
A type for a function identifier.
A :class:`~loopy.library.reduction.ReductionOpFunction` or a :class:`str`.
.. class:: CallablesTable
A type alias for callables tables, mapping from :class:`FunctionIdT`
to :class:`~loopy.InKernelCallable`.
.. currentmodule:: loopy
.. autoclass:: TranslationUnit
.. currentmodule:: loopy.translation_unit
.. autoclass:: CallablesInferenceContext
......@@ -137,6 +152,8 @@ class CallableResolver(RuleAwareIdentityMapper):
# {{{ translation unit
FunctionIdT = Union[str, ReductionOpFunction]
ConcreteCallablesTable = Map[FunctionIdT, InKernelCallable]
CallablesTable = Mapping[FunctionIdT, InKernelCallable]
@dataclass(frozen=True)
......@@ -191,7 +208,7 @@ class TranslationUnit:
"""
callables_table: Map[FunctionIdT, CallableKernel]
callables_table: ConcreteCallablesTable
target: TargetBase
entrypoints: FrozenSet[str]
......@@ -790,7 +807,7 @@ def add_callable_to_table(callables_table, clbl_id, clbl):
# {{{ resolve_callables
def resolve_callables(program):
def resolve_callables(t_unit: TranslationUnit) -> TranslationUnit:
"""
Returns a :class:`TranslationUnit` with known :class:`pymbolic.primitives.Call`
expression nodes converted to :class:`loopy.symbolic.ResolvedFunction`.
......@@ -799,21 +816,21 @@ def resolve_callables(program):
from loopy.check import validate_kernel_call_sites
from loopy.kernel import KernelState
if program.state >= KernelState.CALLS_RESOLVED:
if t_unit.state >= KernelState.CALLS_RESOLVED:
# the translation unit's callables have already been resolved
return program
return t_unit
# get registered callables
known_callables = dict(program.callables_table)
known_callables = dict(t_unit.callables_table)
# get target specific callables
known_callables.update(program.target.get_device_ast_builder().known_callables)
known_callables.update(t_unit.target.get_device_ast_builder().known_callables)
# get loopy specific callables
known_callables.update(get_loopy_callables())
callables_table = {}
# callables: name of the calls seen in the program
callables = {name for name, clbl in program.callables_table.items()
callables = {name for name, clbl in t_unit.callables_table.items()
if isinstance(clbl, CallableKernel)}
while callables:
......@@ -841,11 +858,11 @@ def resolve_callables(program):
else:
raise NotImplementedError(f"{type(clbl)}")
program = program.copy(callables_table=Map(callables_table))
t_unit = t_unit.copy(callables_table=Map(callables_table))
validate_kernel_call_sites(program)
validate_kernel_call_sites(t_unit)
return program
return t_unit
# }}}
......
......@@ -84,7 +84,7 @@ setup(name="loopy",
python_requires="~=3.8",
install_requires=[
"pytools>=2023.1.1",
"pytools>=2024.1.2",
"pymbolic>=2022.1",
"genpy>=2016.1.2",
......