Skip to content
Snippets Groups Projects
Commit 828d7c08 authored by Alexandru Fikl's avatar Alexandru Fikl Committed by Andreas Klöckner
Browse files

add types to scan kernels

parent a1504e46
No related branches found
No related tags found
No related merge requests found
...@@ -22,19 +22,25 @@ limitations under the License. ...@@ -22,19 +22,25 @@ limitations under the License.
Derived from code within the Thrust project, https://github.com/thrust/thrust/ Derived from code within the Thrust project, https://github.com/thrust/thrust/
""" """
from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import List from typing import Any, Dict, List, Optional, Set, Tuple, Union
import numpy as np import numpy as np
import pyopencl as cl import pyopencl as cl
import pyopencl.array # noqa import pyopencl.array
from pyopencl.tools import (dtype_to_ctype, bitlog2, from pyopencl.tools import (
KernelTemplateBase, _process_code_for_macro, KernelTemplateBase,
get_arg_list_scalar_arg_dtypes, DtypedArgument,
bitlog2,
context_dependent_memoize, context_dependent_memoize,
dtype_to_ctype,
get_arg_list_scalar_arg_dtypes,
get_arg_offset_adjuster_code,
_process_code_for_macro,
_NumpyTypesKeyBuilder, _NumpyTypesKeyBuilder,
get_arg_offset_adjuster_code) )
import pyopencl._mymako as mako import pyopencl._mymako as mako
from pyopencl._cluda import CLUDA_PREAMBLE from pyopencl._cluda import CLUDA_PREAMBLE
...@@ -745,7 +751,7 @@ void ${name_prefix}_final_update( ...@@ -745,7 +751,7 @@ void ${name_prefix}_final_update(
# {{{ helpers # {{{ helpers
def _round_down_to_power_of_2(val): def _round_down_to_power_of_2(val: int) -> int:
result = 2**bitlog2(val) result = 2**bitlog2(val)
if result > val: if result > val:
result >>= 1 result >>= 1
...@@ -839,10 +845,11 @@ _IGNORED_WORDS = set(""" ...@@ -839,10 +845,11 @@ _IGNORED_WORDS = set("""
""".split()) """.split())
def _make_template(s): def _make_template(s: str):
import re
leftovers = set() leftovers = set()
def replace_id(match): def replace_id(match: "re.Match") -> str:
# avoid name clashes with user code by adding 'psc_' prefix to # avoid name clashes with user code by adding 'psc_' prefix to
# identifiers. # identifiers.
...@@ -850,30 +857,28 @@ def _make_template(s): ...@@ -850,30 +857,28 @@ def _make_template(s):
if word in _IGNORED_WORDS: if word in _IGNORED_WORDS:
return word return word
elif word in _PREFIX_WORDS: elif word in _PREFIX_WORDS:
return "psc_"+word return f"psc_{word}"
else: else:
leftovers.add(word) leftovers.add(word)
return word return word
import re
s = re.sub(r"\b([a-zA-Z0-9_]+)\b", replace_id, s) s = re.sub(r"\b([a-zA-Z0-9_]+)\b", replace_id, s)
if leftovers: if leftovers:
from warnings import warn from warnings import warn
warn("leftover words in identifier prefixing: " + " ".join(leftovers)) warn("leftover words in identifier prefixing: " + " ".join(leftovers))
return mako.template.Template(s, strict_undefined=True) return mako.template.Template(s, strict_undefined=True) # type: ignore
@dataclass(frozen=True) @dataclass(frozen=True)
class _GeneratedScanKernelInfo: class _GeneratedScanKernelInfo:
scan_src: str scan_src: str
kernel_name: str kernel_name: str
scalar_arg_dtypes: List["np.dtype"] scalar_arg_dtypes: List[Optional[np.dtype]]
wg_size: int wg_size: int
k_group_size: int k_group_size: int
def build(self, context, options): def build(self, context: cl.Context, options: Any) -> "_BuiltScanKernelInfo":
program = cl.Program(context, self.scan_src).build(options) program = cl.Program(context, self.scan_src).build(options)
kernel = getattr(program, self.kernel_name) kernel = getattr(program, self.kernel_name)
kernel.set_scalar_arg_dtypes(self.scalar_arg_dtypes) kernel.set_scalar_arg_dtypes(self.scalar_arg_dtypes)
...@@ -894,10 +899,12 @@ class _BuiltScanKernelInfo: ...@@ -894,10 +899,12 @@ class _BuiltScanKernelInfo:
class _GeneratedFinalUpdateKernelInfo: class _GeneratedFinalUpdateKernelInfo:
source: str source: str
kernel_name: str kernel_name: str
scalar_arg_dtypes: List["np.dtype"] scalar_arg_dtypes: List[Optional[np.dtype]]
update_wg_size: int update_wg_size: int
def build(self, context, options): def build(self,
context: cl.Context,
options: Any) -> "_BuiltFinalUpdateKernelInfo":
program = cl.Program(context, self.source).build(options) program = cl.Program(context, self.source).build(options)
kernel = getattr(program, self.kernel_name) kernel = getattr(program, self.kernel_name)
kernel.set_scalar_arg_dtypes(self.scalar_arg_dtypes) kernel.set_scalar_arg_dtypes(self.scalar_arg_dtypes)
...@@ -916,14 +923,25 @@ class ScanPerformanceWarning(UserWarning): ...@@ -916,14 +923,25 @@ class ScanPerformanceWarning(UserWarning):
pass pass
class _GenericScanKernelBase: class _GenericScanKernelBase(ABC):
# {{{ constructor, argument processing # {{{ constructor, argument processing
def __init__(self, ctx, dtype, def __init__(
arguments, input_expr, scan_expr, neutral, output_statement, self,
is_segment_start_expr=None, input_fetch_exprs=None, ctx: cl.Context,
index_dtype=np.int32, dtype: Any,
name_prefix="scan", options=None, preamble="", devices=None): arguments: Union[str, List[DtypedArgument]],
input_expr: str,
scan_expr: str,
neutral: Optional[str],
output_statement: str,
is_segment_start_expr: Optional[str] = None,
input_fetch_exprs: Optional[List[Tuple[str, str, int]]] = None,
index_dtype: Any = np.int32,
name_prefix: str = "scan",
options: Any = None,
preamble: str = "",
devices: Optional[cl.Device] = None) -> None:
""" """
:arg ctx: a :class:`pyopencl.Context` within which the code :arg ctx: a :class:`pyopencl.Context` within which the code
for this scan kernel will be generated. for this scan kernel will be generated.
...@@ -1114,8 +1132,9 @@ class _GenericScanKernelBase: ...@@ -1114,8 +1132,9 @@ class _GenericScanKernelBase:
# }}} # }}}
def finish_setup(self): @abstractmethod
raise NotImplementedError def finish_setup(self) -> None:
pass
generic_scan_kernel_cache = WriteOncePersistentDict( generic_scan_kernel_cache = WriteOncePersistentDict(
...@@ -1139,10 +1158,9 @@ class GenericScanKernel(_GenericScanKernelBase): ...@@ -1139,10 +1158,9 @@ class GenericScanKernel(_GenericScanKernelBase):
a = cl.array.arange(queue, 10000, dtype=np.int32) a = cl.array.arange(queue, 10000, dtype=np.int32)
knl(a, queue=queue) knl(a, queue=queue)
""" """
def finish_setup(self): def finish_setup(self) -> None:
# Before generating the kernel, see if it's cached. # Before generating the kernel, see if it's cached.
from pyopencl.cache import get_device_cache_id from pyopencl.cache import get_device_cache_id
devices_key = tuple(get_device_cache_id(device) devices_key = tuple(get_device_cache_id(device)
...@@ -1188,7 +1206,7 @@ class GenericScanKernel(_GenericScanKernelBase): ...@@ -1188,7 +1206,7 @@ class GenericScanKernel(_GenericScanKernelBase):
self.context, self.options) self.context, self.options)
del self.final_update_gen_info del self.final_update_gen_info
def _finish_setup_impl(self): def _finish_setup_impl(self) -> None:
# {{{ find usable workgroup/k-group size, build first-level scan # {{{ find usable workgroup/k-group size, build first-level scan
trip_count = 0 trip_count = 0
...@@ -1296,7 +1314,7 @@ class GenericScanKernel(_GenericScanKernelBase): ...@@ -1296,7 +1314,7 @@ class GenericScanKernel(_GenericScanKernelBase):
second_level_arguments = self.parsed_args + [ second_level_arguments = self.parsed_args + [
VectorArg(self.dtype, "interval_sums")] VectorArg(self.dtype, "interval_sums")]
second_level_build_kwargs = {} second_level_build_kwargs: Dict[str, Optional[str]] = {}
if self.is_segmented: if self.is_segmented:
second_level_arguments.append( second_level_arguments.append(
VectorArg(self.index_dtype, VectorArg(self.index_dtype,
...@@ -1360,12 +1378,14 @@ class GenericScanKernel(_GenericScanKernelBase): ...@@ -1360,12 +1378,14 @@ class GenericScanKernel(_GenericScanKernelBase):
# {{{ scan kernel build/properties # {{{ scan kernel build/properties
def get_local_mem_use(self, k_group_size, wg_size, use_bank_conflict_avoidance): def get_local_mem_use(
self, k_group_size: int, wg_size: int,
use_bank_conflict_avoidance: bool) -> int:
arg_dtypes = {} arg_dtypes = {}
for arg in self.parsed_args: for arg in self.parsed_args:
arg_dtypes[arg.name] = arg.dtype arg_dtypes[arg.name] = arg.dtype
fetch_expr_offsets = {} fetch_expr_offsets: Dict[str, Set] = {}
for _name, arg_name, ife_offset in self.input_fetch_exprs: for _name, arg_name, ife_offset in self.input_fetch_exprs:
fetch_expr_offsets.setdefault(arg_name, set()).add(ife_offset) fetch_expr_offsets.setdefault(arg_name, set()).add(ife_offset)
...@@ -1388,10 +1408,17 @@ class GenericScanKernel(_GenericScanKernelBase): ...@@ -1388,10 +1408,17 @@ class GenericScanKernel(_GenericScanKernelBase):
for arg_name, ife_offsets in list(fetch_expr_offsets.items()) for arg_name, ife_offsets in list(fetch_expr_offsets.items())
if -1 in ife_offsets or len(ife_offsets) > 1)) if -1 in ife_offsets or len(ife_offsets) > 1))
def generate_scan_kernel(self, max_wg_size, arguments, input_expr, def generate_scan_kernel(
is_segment_start_expr, input_fetch_exprs, is_first_level, self,
store_segment_start_flags, k_group_size, max_wg_size: int,
use_bank_conflict_avoidance): arguments: List[DtypedArgument],
input_expr: str,
is_segment_start_expr: Optional[str],
input_fetch_exprs: List[Tuple[str, str, int]],
is_first_level: bool,
store_segment_start_flags: bool,
k_group_size: int,
use_bank_conflict_avoidance: bool) -> _GeneratedScanKernelInfo:
scalar_arg_dtypes = get_arg_list_scalar_arg_dtypes(arguments) scalar_arg_dtypes = get_arg_list_scalar_arg_dtypes(arguments)
# Empirically found on Nv hardware: no need to be bigger than this size # Empirically found on Nv hardware: no need to be bigger than this size
...@@ -1437,7 +1464,7 @@ class GenericScanKernel(_GenericScanKernelBase): ...@@ -1437,7 +1464,7 @@ class GenericScanKernel(_GenericScanKernelBase):
# }}} # }}}
def __call__(self, *args, **kwargs): def __call__(self, *args: Any, **kwargs: Any) -> cl.Event:
# {{{ argument processing # {{{ argument processing
allocator = kwargs.get("allocator") allocator = kwargs.get("allocator")
...@@ -1451,8 +1478,8 @@ class GenericScanKernel(_GenericScanKernelBase): ...@@ -1451,8 +1478,8 @@ class GenericScanKernel(_GenericScanKernelBase):
wait_for = list(wait_for) wait_for = list(wait_for)
if len(args) != len(self.parsed_args): if len(args) != len(self.parsed_args):
raise TypeError("expected %d arguments, got %d" % raise TypeError(
(len(self.parsed_args), len(args))) f"expected {len(self.parsed_args)} arguments, got {len(args)}")
first_array = args[self.first_array_idx] first_array = args[self.first_array_idx]
allocator = allocator or first_array.allocator allocator = allocator or first_array.allocator
...@@ -1631,7 +1658,7 @@ void ${name_prefix}_debug_scan( ...@@ -1631,7 +1658,7 @@ void ${name_prefix}_debug_scan(
class GenericDebugScanKernel(_GenericScanKernelBase): class GenericDebugScanKernel(_GenericScanKernelBase):
def finish_setup(self): def finish_setup(self) -> None:
scan_tpl = _make_template(DEBUG_SCAN_TEMPLATE) scan_tpl = _make_template(DEBUG_SCAN_TEMPLATE)
scan_src = str(scan_tpl.render( scan_src = str(scan_tpl.render(
output_statement=self.output_statement, output_statement=self.output_statement,
...@@ -1645,15 +1672,14 @@ class GenericDebugScanKernel(_GenericScanKernelBase): ...@@ -1645,15 +1672,14 @@ class GenericDebugScanKernel(_GenericScanKernelBase):
**self.code_variables)) **self.code_variables))
scan_prg = cl.Program(self.context, scan_src).build(self.options) scan_prg = cl.Program(self.context, scan_src).build(self.options)
self.kernel = getattr( self.kernel = getattr(scan_prg, f"{self.name_prefix}_debug_scan")
scan_prg, self.name_prefix+"_debug_scan")
scalar_arg_dtypes = ( scalar_arg_dtypes = (
[None] [None]
+ get_arg_list_scalar_arg_dtypes(self.parsed_args) + get_arg_list_scalar_arg_dtypes(self.parsed_args)
+ [self.index_dtype]) + [self.index_dtype])
self.kernel.set_scalar_arg_dtypes(scalar_arg_dtypes) self.kernel.set_scalar_arg_dtypes(scalar_arg_dtypes)
def __call__(self, *args, **kwargs): def __call__(self, *args: Any, **kwargs: Any) -> cl.Event:
# {{{ argument processing # {{{ argument processing
allocator = kwargs.get("allocator") allocator = kwargs.get("allocator")
...@@ -1668,8 +1694,8 @@ class GenericDebugScanKernel(_GenericScanKernelBase): ...@@ -1668,8 +1694,8 @@ class GenericDebugScanKernel(_GenericScanKernelBase):
wait_for = list(wait_for) wait_for = list(wait_for)
if len(args) != len(self.parsed_args): if len(args) != len(self.parsed_args):
raise TypeError("expected %d arguments, got %d" % raise TypeError(
(len(self.parsed_args), len(args))) f"expected {len(self.parsed_args)} arguments, got {len(args)}")
first_array = args[self.first_array_idx] first_array = args[self.first_array_idx]
allocator = allocator or first_array.allocator allocator = allocator or first_array.allocator
...@@ -1763,15 +1789,23 @@ class ExclusiveScanKernel(_LegacyScanKernelBase): ...@@ -1763,15 +1789,23 @@ class ExclusiveScanKernel(_LegacyScanKernelBase):
# {{{ template # {{{ template
class ScanTemplate(KernelTemplateBase): class ScanTemplate(KernelTemplateBase):
def __init__(self, def __init__(
arguments, input_expr, scan_expr, neutral, output_statement, self,
is_segment_start_expr=None, input_fetch_exprs=None, arguments: Union[str, List[DtypedArgument]],
name_prefix="scan", preamble="", template_processor=None): input_expr: str,
scan_expr: str,
neutral: Optional[str],
output_statement: str,
is_segment_start_expr: Optional[str] = None,
input_fetch_exprs: Optional[List[Tuple[str, str, int]]] = None,
name_prefix: str = "scan",
preamble: str = "",
template_processor: Any = None) -> None:
super().__init__(template_processor=template_processor)
if input_fetch_exprs is None: if input_fetch_exprs is None:
input_fetch_exprs = [] input_fetch_exprs = []
KernelTemplateBase.__init__(self, template_processor=template_processor)
self.arguments = arguments self.arguments = arguments
self.input_expr = input_expr self.input_expr = input_expr
self.scan_expr = scan_expr self.scan_expr = scan_expr
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment