diff --git a/pyopencl/scan.py b/pyopencl/scan.py
index 6e9c92d8fc57d595580402ab5a10b0106ec97ccf..f2c42bc254e36c0e48591b4f2c521ff9fed0c5af 100644
--- a/pyopencl/scan.py
+++ b/pyopencl/scan.py
@@ -22,19 +22,25 @@ limitations under the License.
 Derived from code within the Thrust project, https://github.com/thrust/thrust/
 """
 
+from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import List
+from typing import Any, Dict, List, Optional, Set, Tuple, Union
 
 import numpy as np
 
 import pyopencl as cl
-import pyopencl.array  # noqa
-from pyopencl.tools import (dtype_to_ctype, bitlog2,
-        KernelTemplateBase, _process_code_for_macro,
-        get_arg_list_scalar_arg_dtypes,
+import pyopencl.array
+from pyopencl.tools import (
+        KernelTemplateBase,
+        DtypedArgument,
+        bitlog2,
         context_dependent_memoize,
+        dtype_to_ctype,
+        get_arg_list_scalar_arg_dtypes,
+        get_arg_offset_adjuster_code,
+        _process_code_for_macro,
         _NumpyTypesKeyBuilder,
-        get_arg_offset_adjuster_code)
+        )
 
 import pyopencl._mymako as mako
 from pyopencl._cluda import CLUDA_PREAMBLE
@@ -745,7 +751,7 @@ void ${name_prefix}_final_update(
 
 # {{{ helpers
 
-def _round_down_to_power_of_2(val):
+def _round_down_to_power_of_2(val: int) -> int:
     result = 2**bitlog2(val)
     if result > val:
         result >>= 1
@@ -839,10 +845,11 @@ _IGNORED_WORDS = set("""
         """.split())
 
 
-def _make_template(s):
+def _make_template(s: str):
+    import re
     leftovers = set()
 
-    def replace_id(match):
+    def replace_id(match: "re.Match") -> str:
         # avoid name clashes with user code by adding 'psc_' prefix to
         # identifiers.
 
@@ -850,30 +857,28 @@ def _make_template(s):
         if word in _IGNORED_WORDS:
             return word
         elif word in _PREFIX_WORDS:
-            return "psc_"+word
+            return f"psc_{word}"
         else:
             leftovers.add(word)
             return word
 
-    import re
     s = re.sub(r"\b([a-zA-Z0-9_]+)\b", replace_id, s)
-
     if leftovers:
         from warnings import warn
         warn("leftover words in identifier prefixing: " + " ".join(leftovers))
 
-    return mako.template.Template(s, strict_undefined=True)
+    return mako.template.Template(s, strict_undefined=True)     # type: ignore
 
 
 @dataclass(frozen=True)
 class _GeneratedScanKernelInfo:
     scan_src: str
     kernel_name: str
-    scalar_arg_dtypes: List["np.dtype"]
+    scalar_arg_dtypes: List[Optional[np.dtype]]
     wg_size: int
     k_group_size: int
 
-    def build(self, context, options):
+    def build(self, context: cl.Context, options: Any) -> "_BuiltScanKernelInfo":
         program = cl.Program(context, self.scan_src).build(options)
         kernel = getattr(program, self.kernel_name)
         kernel.set_scalar_arg_dtypes(self.scalar_arg_dtypes)
@@ -894,10 +899,12 @@ class _BuiltScanKernelInfo:
 class _GeneratedFinalUpdateKernelInfo:
     source: str
     kernel_name: str
-    scalar_arg_dtypes: List["np.dtype"]
+    scalar_arg_dtypes: List[Optional[np.dtype]]
     update_wg_size: int
 
-    def build(self, context, options):
+    def build(self,
+              context: cl.Context,
+              options: Any) -> "_BuiltFinalUpdateKernelInfo":
         program = cl.Program(context, self.source).build(options)
         kernel = getattr(program, self.kernel_name)
         kernel.set_scalar_arg_dtypes(self.scalar_arg_dtypes)
@@ -916,14 +923,25 @@ class ScanPerformanceWarning(UserWarning):
     pass
 
 
-class _GenericScanKernelBase:
+class _GenericScanKernelBase(ABC):
     # {{{ constructor, argument processing
 
-    def __init__(self, ctx, dtype,
-            arguments, input_expr, scan_expr, neutral, output_statement,
-            is_segment_start_expr=None, input_fetch_exprs=None,
-            index_dtype=np.int32,
-            name_prefix="scan", options=None, preamble="", devices=None):
+    def __init__(
+            self,
+            ctx: cl.Context,
+            dtype: Any,
+            arguments: Union[str, List[DtypedArgument]],
+            input_expr: str,
+            scan_expr: str,
+            neutral: Optional[str],
+            output_statement: str,
+            is_segment_start_expr: Optional[str] = None,
+            input_fetch_exprs: Optional[List[Tuple[str, str, int]]] = None,
+            index_dtype: Any = np.int32,
+            name_prefix: str = "scan",
+            options: Any = None,
+            preamble: str = "",
+            devices: Optional[cl.Device] = None) -> None:
         """
         :arg ctx: a :class:`pyopencl.Context` within which the code
             for this scan kernel will be generated.
@@ -1114,8 +1132,9 @@ class _GenericScanKernelBase:
 
     # }}}
 
-    def finish_setup(self):
-        raise NotImplementedError
+    @abstractmethod
+    def finish_setup(self) -> None:
+        pass
 
 
 generic_scan_kernel_cache = WriteOncePersistentDict(
@@ -1139,10 +1158,9 @@ class GenericScanKernel(_GenericScanKernelBase):
 
         a = cl.array.arange(queue, 10000, dtype=np.int32)
         knl(a, queue=queue)
-
     """
 
-    def finish_setup(self):
+    def finish_setup(self) -> None:
         # Before generating the kernel, see if it's cached.
         from pyopencl.cache import get_device_cache_id
         devices_key = tuple(get_device_cache_id(device)
@@ -1188,7 +1206,7 @@ class GenericScanKernel(_GenericScanKernelBase):
                 self.context, self.options)
         del self.final_update_gen_info
 
-    def _finish_setup_impl(self):
+    def _finish_setup_impl(self) -> None:
         # {{{ find usable workgroup/k-group size, build first-level scan
 
         trip_count = 0
@@ -1296,7 +1314,7 @@ class GenericScanKernel(_GenericScanKernelBase):
         second_level_arguments = self.parsed_args + [
                 VectorArg(self.dtype, "interval_sums")]
 
-        second_level_build_kwargs = {}
+        second_level_build_kwargs: Dict[str, Optional[str]] = {}
         if self.is_segmented:
             second_level_arguments.append(
                     VectorArg(self.index_dtype,
@@ -1360,12 +1378,14 @@ class GenericScanKernel(_GenericScanKernelBase):
 
     # {{{ scan kernel build/properties
 
-    def get_local_mem_use(self, k_group_size, wg_size, use_bank_conflict_avoidance):
+    def get_local_mem_use(
+            self, k_group_size: int, wg_size: int,
+            use_bank_conflict_avoidance: bool) -> int:
         arg_dtypes = {}
         for arg in self.parsed_args:
             arg_dtypes[arg.name] = arg.dtype
 
-        fetch_expr_offsets = {}
+        fetch_expr_offsets: Dict[str, Set] = {}
         for _name, arg_name, ife_offset in self.input_fetch_exprs:
             fetch_expr_offsets.setdefault(arg_name, set()).add(ife_offset)
 
@@ -1388,10 +1408,17 @@ class GenericScanKernel(_GenericScanKernelBase):
                     for arg_name, ife_offsets in list(fetch_expr_offsets.items())
                     if -1 in ife_offsets or len(ife_offsets) > 1))
 
-    def generate_scan_kernel(self, max_wg_size, arguments, input_expr,
-            is_segment_start_expr, input_fetch_exprs, is_first_level,
-            store_segment_start_flags, k_group_size,
-            use_bank_conflict_avoidance):
+    def generate_scan_kernel(
+            self,
+            max_wg_size: int,
+            arguments: List[DtypedArgument],
+            input_expr: str,
+            is_segment_start_expr: Optional[str],
+            input_fetch_exprs: List[Tuple[str, str, int]],
+            is_first_level: bool,
+            store_segment_start_flags: bool,
+            k_group_size: int,
+            use_bank_conflict_avoidance: bool) -> _GeneratedScanKernelInfo:
         scalar_arg_dtypes = get_arg_list_scalar_arg_dtypes(arguments)
 
         # Empirically found on Nv hardware: no need to be bigger than this size
@@ -1437,7 +1464,7 @@ class GenericScanKernel(_GenericScanKernelBase):
 
     # }}}
 
-    def __call__(self, *args, **kwargs):
+    def __call__(self, *args: Any, **kwargs: Any) -> cl.Event:
         # {{{ argument processing
 
         allocator = kwargs.get("allocator")
@@ -1451,8 +1478,8 @@ class GenericScanKernel(_GenericScanKernelBase):
             wait_for = list(wait_for)
 
         if len(args) != len(self.parsed_args):
-            raise TypeError("expected %d arguments, got %d" %
-                    (len(self.parsed_args), len(args)))
+            raise TypeError(
+                f"expected {len(self.parsed_args)} arguments, got {len(args)}")
 
         first_array = args[self.first_array_idx]
         allocator = allocator or first_array.allocator
@@ -1631,7 +1658,7 @@ void ${name_prefix}_debug_scan(
 
 
 class GenericDebugScanKernel(_GenericScanKernelBase):
-    def finish_setup(self):
+    def finish_setup(self) -> None:
         scan_tpl = _make_template(DEBUG_SCAN_TEMPLATE)
         scan_src = str(scan_tpl.render(
             output_statement=self.output_statement,
@@ -1645,15 +1672,14 @@ class GenericDebugScanKernel(_GenericScanKernelBase):
             **self.code_variables))
 
         scan_prg = cl.Program(self.context, scan_src).build(self.options)
-        self.kernel = getattr(
-                scan_prg, self.name_prefix+"_debug_scan")
+        self.kernel = getattr(scan_prg, f"{self.name_prefix}_debug_scan")
         scalar_arg_dtypes = (
                 [None]
                 + get_arg_list_scalar_arg_dtypes(self.parsed_args)
                 + [self.index_dtype])
         self.kernel.set_scalar_arg_dtypes(scalar_arg_dtypes)
 
-    def __call__(self, *args, **kwargs):
+    def __call__(self, *args: Any, **kwargs: Any) -> cl.Event:
         # {{{ argument processing
 
         allocator = kwargs.get("allocator")
@@ -1668,8 +1694,8 @@ class GenericDebugScanKernel(_GenericScanKernelBase):
             wait_for = list(wait_for)
 
         if len(args) != len(self.parsed_args):
-            raise TypeError("expected %d arguments, got %d" %
-                    (len(self.parsed_args), len(args)))
+            raise TypeError(
+                f"expected {len(self.parsed_args)} arguments, got {len(args)}")
 
         first_array = args[self.first_array_idx]
         allocator = allocator or first_array.allocator
@@ -1763,15 +1789,23 @@ class ExclusiveScanKernel(_LegacyScanKernelBase):
 # {{{ template
 
 class ScanTemplate(KernelTemplateBase):
-    def __init__(self,
-            arguments, input_expr, scan_expr, neutral, output_statement,
-            is_segment_start_expr=None, input_fetch_exprs=None,
-            name_prefix="scan", preamble="", template_processor=None):
+    def __init__(
+            self,
+            arguments: Union[str, List[DtypedArgument]],
+            input_expr: str,
+            scan_expr: str,
+            neutral: Optional[str],
+            output_statement: str,
+            is_segment_start_expr: Optional[str] = None,
+            input_fetch_exprs: Optional[List[Tuple[str, str, int]]] = None,
+            name_prefix: str = "scan",
+            preamble: str = "",
+            template_processor: Any = None) -> None:
+        super().__init__(template_processor=template_processor)
 
         if input_fetch_exprs is None:
             input_fetch_exprs = []
 
-        KernelTemplateBase.__init__(self, template_processor=template_processor)
         self.arguments = arguments
         self.input_expr = input_expr
         self.scan_expr = scan_expr