diff --git a/pyopencl/tools.py b/pyopencl/tools.py index fb4a91e14f98d3cde4c6b68ceeee4d44979aa3e8..6e18512b95fda69e657224624041a6a1f9508696 100644 --- a/pyopencl/tools.py +++ b/pyopencl/tools.py @@ -72,6 +72,18 @@ Testing .. autofunction:: pytest_generate_tests_for_pyopencl +Argument Types +-------------- + +.. autoclass:: Argument +.. autoclass:: DtypedArgument + +.. autoclass:: VectorArg +.. autoclass:: ScalarArg +.. autoclass:: OtherArg + +.. autofunction:: parse_arg_list + Device Characterization ----------------------- @@ -114,8 +126,9 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ - +from abc import ABC, abstractmethod from sys import intern +from typing import Any, List, Union # Do not add a pyopencl import here: This will add an import cycle. @@ -714,27 +727,38 @@ def pytest_generate_tests_for_pyopencl(metafunc): # {{{ C argument lists -class Argument: - pass +class Argument(ABC): + """ + .. automethod:: declarator + """ + + @abstractmethod + def declarator(self) -> str: + pass class DtypedArgument(Argument): - def __init__(self, dtype, name): + """ + .. attribute:: name + .. attribute:: dtype + """ + + def __init__(self, dtype: Any, name: str) -> None: self.dtype = np.dtype(dtype) self.name = name - def __repr__(self): + def __repr__(self) -> str: return "{}({!r}, {})".format( self.__class__.__name__, self.name, self.dtype) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return (type(self) == type(other) and self.dtype == other.dtype and self.name == other.name) - def __hash__(self): + def __hash__(self) -> int: return ( hash(type(self)) ^ hash(self.dtype) @@ -742,11 +766,16 @@ class DtypedArgument(Argument): class VectorArg(DtypedArgument): - def __init__(self, dtype, name, with_offset=False): + """Inherits from :class:`DtypedArgument`. + + .. automethod:: __init__ + """ + + def __init__(self, dtype: Any, name: str, with_offset: bool = False) -> None: super().__init__(dtype, name) self.with_offset = with_offset - def declarator(self): + def declarator(self) -> str: if self.with_offset: # Two underscores -> less likelihood of a name clash. return "__global {} *{}__base, long {}__offset".format( @@ -756,40 +785,42 @@ class VectorArg(DtypedArgument): return result - def __eq__(self, other): + def __eq__(self, other) -> bool: return (super().__eq__(other) and self.with_offset == other.with_offset) - def __hash__(self): + def __hash__(self) -> int: return super().__hash__() ^ hash(self.with_offset) class ScalarArg(DtypedArgument): + """Inherits from :class:`DtypedArgument`.""" + def declarator(self): return "{} {}".format(dtype_to_ctype(self.dtype), self.name) class OtherArg(Argument): - def __init__(self, declarator, name): + def __init__(self, declarator: str, name: str) -> None: self.decl = declarator self.name = name - def declarator(self): + def declarator(self) -> str: return self.decl - def __eq__(self, other): + def __eq__(self, other) -> bool: return (type(self) == type(other) and self.decl == other.decl and self.name == other.name) - def __hash__(self): + def __hash__(self) -> int: return ( hash(type(self)) ^ hash(self.decl) ^ hash(self.name)) -def parse_c_arg(c_arg, with_offset=False): +def parse_c_arg(c_arg: str, with_offset: bool = False) -> DtypedArgument: for aspace in ["__local", "__constant"]: if aspace in c_arg: raise RuntimeError("cannot deal with local or constant " @@ -807,7 +838,9 @@ def parse_c_arg(c_arg, with_offset=False): return parse_c_arg_backend(c_arg, ScalarArg, vec_arg_factory) -def parse_arg_list(arguments, with_offset=False): +def parse_arg_list( + arguments: Union[str, List[str], List[DtypedArgument]], + with_offset: bool = False) -> List[DtypedArgument]: """Parse a list of kernel arguments. *arguments* may be a comma-separate list of C declarators in a string, a list of strings representing C declarators, or :class:`Argument` objects. @@ -816,11 +849,12 @@ def parse_arg_list(arguments, with_offset=False): if isinstance(arguments, str): arguments = arguments.split(",") - def parse_single_arg(obj): + def parse_single_arg(obj: Union[str, DtypedArgument]) -> DtypedArgument: if isinstance(obj, str): from pyopencl.tools import parse_c_arg return parse_c_arg(obj, with_offset=with_offset) else: + assert isinstance(obj, DtypedArgument) return obj return [parse_single_arg(arg) for arg in arguments]