diff --git a/pyopencl/scan.py b/pyopencl/scan.py index 6e9c92d8fc57d595580402ab5a10b0106ec97ccf..f2c42bc254e36c0e48591b4f2c521ff9fed0c5af 100644 --- a/pyopencl/scan.py +++ b/pyopencl/scan.py @@ -22,19 +22,25 @@ limitations under the License. Derived from code within the Thrust project, https://github.com/thrust/thrust/ """ +from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import List +from typing import Any, Dict, List, Optional, Set, Tuple, Union import numpy as np import pyopencl as cl -import pyopencl.array # noqa -from pyopencl.tools import (dtype_to_ctype, bitlog2, - KernelTemplateBase, _process_code_for_macro, - get_arg_list_scalar_arg_dtypes, +import pyopencl.array +from pyopencl.tools import ( + KernelTemplateBase, + DtypedArgument, + bitlog2, context_dependent_memoize, + dtype_to_ctype, + get_arg_list_scalar_arg_dtypes, + get_arg_offset_adjuster_code, + _process_code_for_macro, _NumpyTypesKeyBuilder, - get_arg_offset_adjuster_code) + ) import pyopencl._mymako as mako from pyopencl._cluda import CLUDA_PREAMBLE @@ -745,7 +751,7 @@ void ${name_prefix}_final_update( # {{{ helpers -def _round_down_to_power_of_2(val): +def _round_down_to_power_of_2(val: int) -> int: result = 2**bitlog2(val) if result > val: result >>= 1 @@ -839,10 +845,11 @@ _IGNORED_WORDS = set(""" """.split()) -def _make_template(s): +def _make_template(s: str): + import re leftovers = set() - def replace_id(match): + def replace_id(match: "re.Match") -> str: # avoid name clashes with user code by adding 'psc_' prefix to # identifiers. @@ -850,30 +857,28 @@ def _make_template(s): if word in _IGNORED_WORDS: return word elif word in _PREFIX_WORDS: - return "psc_"+word + return f"psc_{word}" else: leftovers.add(word) return word - import re s = re.sub(r"\b([a-zA-Z0-9_]+)\b", replace_id, s) - if leftovers: from warnings import warn warn("leftover words in identifier prefixing: " + " ".join(leftovers)) - return mako.template.Template(s, strict_undefined=True) + return mako.template.Template(s, strict_undefined=True) # type: ignore @dataclass(frozen=True) class _GeneratedScanKernelInfo: scan_src: str kernel_name: str - scalar_arg_dtypes: List["np.dtype"] + scalar_arg_dtypes: List[Optional[np.dtype]] wg_size: int k_group_size: int - def build(self, context, options): + def build(self, context: cl.Context, options: Any) -> "_BuiltScanKernelInfo": program = cl.Program(context, self.scan_src).build(options) kernel = getattr(program, self.kernel_name) kernel.set_scalar_arg_dtypes(self.scalar_arg_dtypes) @@ -894,10 +899,12 @@ class _BuiltScanKernelInfo: class _GeneratedFinalUpdateKernelInfo: source: str kernel_name: str - scalar_arg_dtypes: List["np.dtype"] + scalar_arg_dtypes: List[Optional[np.dtype]] update_wg_size: int - def build(self, context, options): + def build(self, + context: cl.Context, + options: Any) -> "_BuiltFinalUpdateKernelInfo": program = cl.Program(context, self.source).build(options) kernel = getattr(program, self.kernel_name) kernel.set_scalar_arg_dtypes(self.scalar_arg_dtypes) @@ -916,14 +923,25 @@ class ScanPerformanceWarning(UserWarning): pass -class _GenericScanKernelBase: +class _GenericScanKernelBase(ABC): # {{{ constructor, argument processing - def __init__(self, ctx, dtype, - arguments, input_expr, scan_expr, neutral, output_statement, - is_segment_start_expr=None, input_fetch_exprs=None, - index_dtype=np.int32, - name_prefix="scan", options=None, preamble="", devices=None): + def __init__( + self, + ctx: cl.Context, + dtype: Any, + arguments: Union[str, List[DtypedArgument]], + input_expr: str, + scan_expr: str, + neutral: Optional[str], + output_statement: str, + is_segment_start_expr: Optional[str] = None, + input_fetch_exprs: Optional[List[Tuple[str, str, int]]] = None, + index_dtype: Any = np.int32, + name_prefix: str = "scan", + options: Any = None, + preamble: str = "", + devices: Optional[cl.Device] = None) -> None: """ :arg ctx: a :class:`pyopencl.Context` within which the code for this scan kernel will be generated. @@ -1114,8 +1132,9 @@ class _GenericScanKernelBase: # }}} - def finish_setup(self): - raise NotImplementedError + @abstractmethod + def finish_setup(self) -> None: + pass generic_scan_kernel_cache = WriteOncePersistentDict( @@ -1139,10 +1158,9 @@ class GenericScanKernel(_GenericScanKernelBase): a = cl.array.arange(queue, 10000, dtype=np.int32) knl(a, queue=queue) - """ - def finish_setup(self): + def finish_setup(self) -> None: # Before generating the kernel, see if it's cached. from pyopencl.cache import get_device_cache_id devices_key = tuple(get_device_cache_id(device) @@ -1188,7 +1206,7 @@ class GenericScanKernel(_GenericScanKernelBase): self.context, self.options) del self.final_update_gen_info - def _finish_setup_impl(self): + def _finish_setup_impl(self) -> None: # {{{ find usable workgroup/k-group size, build first-level scan trip_count = 0 @@ -1296,7 +1314,7 @@ class GenericScanKernel(_GenericScanKernelBase): second_level_arguments = self.parsed_args + [ VectorArg(self.dtype, "interval_sums")] - second_level_build_kwargs = {} + second_level_build_kwargs: Dict[str, Optional[str]] = {} if self.is_segmented: second_level_arguments.append( VectorArg(self.index_dtype, @@ -1360,12 +1378,14 @@ class GenericScanKernel(_GenericScanKernelBase): # {{{ scan kernel build/properties - def get_local_mem_use(self, k_group_size, wg_size, use_bank_conflict_avoidance): + def get_local_mem_use( + self, k_group_size: int, wg_size: int, + use_bank_conflict_avoidance: bool) -> int: arg_dtypes = {} for arg in self.parsed_args: arg_dtypes[arg.name] = arg.dtype - fetch_expr_offsets = {} + fetch_expr_offsets: Dict[str, Set] = {} for _name, arg_name, ife_offset in self.input_fetch_exprs: fetch_expr_offsets.setdefault(arg_name, set()).add(ife_offset) @@ -1388,10 +1408,17 @@ class GenericScanKernel(_GenericScanKernelBase): for arg_name, ife_offsets in list(fetch_expr_offsets.items()) if -1 in ife_offsets or len(ife_offsets) > 1)) - def generate_scan_kernel(self, max_wg_size, arguments, input_expr, - is_segment_start_expr, input_fetch_exprs, is_first_level, - store_segment_start_flags, k_group_size, - use_bank_conflict_avoidance): + def generate_scan_kernel( + self, + max_wg_size: int, + arguments: List[DtypedArgument], + input_expr: str, + is_segment_start_expr: Optional[str], + input_fetch_exprs: List[Tuple[str, str, int]], + is_first_level: bool, + store_segment_start_flags: bool, + k_group_size: int, + use_bank_conflict_avoidance: bool) -> _GeneratedScanKernelInfo: scalar_arg_dtypes = get_arg_list_scalar_arg_dtypes(arguments) # Empirically found on Nv hardware: no need to be bigger than this size @@ -1437,7 +1464,7 @@ class GenericScanKernel(_GenericScanKernelBase): # }}} - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> cl.Event: # {{{ argument processing allocator = kwargs.get("allocator") @@ -1451,8 +1478,8 @@ class GenericScanKernel(_GenericScanKernelBase): wait_for = list(wait_for) if len(args) != len(self.parsed_args): - raise TypeError("expected %d arguments, got %d" % - (len(self.parsed_args), len(args))) + raise TypeError( + f"expected {len(self.parsed_args)} arguments, got {len(args)}") first_array = args[self.first_array_idx] allocator = allocator or first_array.allocator @@ -1631,7 +1658,7 @@ void ${name_prefix}_debug_scan( class GenericDebugScanKernel(_GenericScanKernelBase): - def finish_setup(self): + def finish_setup(self) -> None: scan_tpl = _make_template(DEBUG_SCAN_TEMPLATE) scan_src = str(scan_tpl.render( output_statement=self.output_statement, @@ -1645,15 +1672,14 @@ class GenericDebugScanKernel(_GenericScanKernelBase): **self.code_variables)) scan_prg = cl.Program(self.context, scan_src).build(self.options) - self.kernel = getattr( - scan_prg, self.name_prefix+"_debug_scan") + self.kernel = getattr(scan_prg, f"{self.name_prefix}_debug_scan") scalar_arg_dtypes = ( [None] + get_arg_list_scalar_arg_dtypes(self.parsed_args) + [self.index_dtype]) self.kernel.set_scalar_arg_dtypes(scalar_arg_dtypes) - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> cl.Event: # {{{ argument processing allocator = kwargs.get("allocator") @@ -1668,8 +1694,8 @@ class GenericDebugScanKernel(_GenericScanKernelBase): wait_for = list(wait_for) if len(args) != len(self.parsed_args): - raise TypeError("expected %d arguments, got %d" % - (len(self.parsed_args), len(args))) + raise TypeError( + f"expected {len(self.parsed_args)} arguments, got {len(args)}") first_array = args[self.first_array_idx] allocator = allocator or first_array.allocator @@ -1763,15 +1789,23 @@ class ExclusiveScanKernel(_LegacyScanKernelBase): # {{{ template class ScanTemplate(KernelTemplateBase): - def __init__(self, - arguments, input_expr, scan_expr, neutral, output_statement, - is_segment_start_expr=None, input_fetch_exprs=None, - name_prefix="scan", preamble="", template_processor=None): + def __init__( + self, + arguments: Union[str, List[DtypedArgument]], + input_expr: str, + scan_expr: str, + neutral: Optional[str], + output_statement: str, + is_segment_start_expr: Optional[str] = None, + input_fetch_exprs: Optional[List[Tuple[str, str, int]]] = None, + name_prefix: str = "scan", + preamble: str = "", + template_processor: Any = None) -> None: + super().__init__(template_processor=template_processor) if input_fetch_exprs is None: input_fetch_exprs = [] - KernelTemplateBase.__init__(self, template_processor=template_processor) self.arguments = arguments self.input_expr = input_expr self.scan_expr = scan_expr