From a71280ac1a70e75eb771c0ecf862ea5249d5da96 Mon Sep 17 00:00:00 2001 From: Mit Kotak Date: Tue, 9 Aug 2022 12:01:09 -0500 Subject: [PATCH 1/8] Add HUGEVAL support for CudaTarget MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Andreas Klöckner --- loopy/target/cuda.py | 6 ++++++ loopy/target/opencl.py | 1 + 2 files changed, 7 insertions(+) diff --git a/loopy/target/cuda.py b/loopy/target/cuda.py index 4a311f887..f87dddc4b 100644 --- a/loopy/target/cuda.py +++ b/loopy/target/cuda.py @@ -316,6 +316,12 @@ class CUDACASTBuilder(CFamilyASTBuilder): callables.update(get_cuda_callables()) return callables + def symbol_manglers(self): + from loopy.target.opencl import opencl_symbol_mangler + return ( + super().symbol_manglers() + [ + opencl_symbol_mangler + ]) # }}} # {{{ top-level codegen diff --git a/loopy/target/opencl.py b/loopy/target/opencl.py index 6957d3f96..5fa012d8b 100644 --- a/loopy/target/opencl.py +++ b/loopy/target/opencl.py @@ -437,6 +437,7 @@ def get_opencl_callables(): # {{{ symbol mangler def opencl_symbol_mangler(kernel, name): + # Also being used in loopy.target.cuda.CudaCASTBuilder.symbol_manglers # FIXME: should be more picky about exact names if name.startswith("FLT_"): return NumpyType(np.dtype(np.float32)), name -- GitLab From 4543a232a74cdfd3d94ce43c7914208ff0b3e7c1 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Thu, 4 Aug 2022 11:22:24 -0500 Subject: [PATCH 2/8] CUDACASTBuilder -> CudaCASTBuilder loopy.target.cuda conventionally uses Cuda instead of CUDA. --- loopy/target/cuda.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/loopy/target/cuda.py b/loopy/target/cuda.py index f87dddc4b..120147cea 100644 --- a/loopy/target/cuda.py +++ b/loopy/target/cuda.py @@ -228,7 +228,7 @@ class CudaTarget(CFamilyTarget): return True def get_device_ast_builder(self): - return CUDACASTBuilder(self) + return CudaCASTBuilder(self) # {{{ types @@ -304,7 +304,7 @@ def cuda_preamble_generator(preamble_info): # {{{ ast builder -class CUDACASTBuilder(CFamilyASTBuilder): +class CudaCASTBuilder(CFamilyASTBuilder): preamble_function_qualifier = "inline __device__" -- GitLab From 1eb9a3d9e7754e2230d68a46ed67a0716ba39507 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Fri, 5 Aug 2022 00:48:08 -0500 Subject: [PATCH 3/8] fixes bug in int pow to be compatible with all targets --- loopy/target/c/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/loopy/target/c/__init__.py b/loopy/target/c/__init__.py index de6a32a68..76696e198 100644 --- a/loopy/target/c/__init__.py +++ b/loopy/target/c/__init__.py @@ -236,7 +236,7 @@ def _preamble_generator(preamble_info, func_qualifier="inline"): }""") yield (f"07_{func.c_name}", f""" - inline {res_ctype} {func.c_name}({base_ctype} x, {exp_ctype} n) {{ + {func_qualifier} {res_ctype} {func.c_name}({base_ctype} x, {exp_ctype} n) {{ if (n == 0) return 1; {re.sub("^", 14*" ", signed_exponent_preamble, flags=re.M)} @@ -254,7 +254,8 @@ def _preamble_generator(preamble_info, func_qualifier="inline"): }} return x*y; - }}""") + }}""" # noqa: E501 + ) # }}} -- GitLab From cf562325834d9fe6942dbb3181b16a76ee4b6d3f Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Thu, 4 Aug 2022 19:29:37 -0500 Subject: [PATCH 4/8] Implements PyCudaTarget --- loopy/__init__.py | 3 +- loopy/target/pycuda.py | 351 ++++++++++++++++++++++++++++ loopy/target/pycuda_execution.py | 390 +++++++++++++++++++++++++++++++ setup.cfg | 3 + 4 files changed, 746 insertions(+), 1 deletion(-) create mode 100644 loopy/target/pycuda.py create mode 100644 loopy/target/pycuda_execution.py diff --git a/loopy/__init__.py b/loopy/__init__.py index ce3ba1439..c6c5b1dde 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -157,6 +157,7 @@ from loopy.target.c import (CFamilyTarget, CTarget, ExecutableCTarget, from loopy.target.cuda import CudaTarget from loopy.target.opencl import OpenCLTarget from loopy.target.pyopencl import PyOpenCLTarget +from loopy.target.pycuda import PyCudaTarget from loopy.target.ispc import ISPCTarget from loopy.tools import Optional, t_unit_to_python, memoize_on_disk @@ -303,7 +304,7 @@ __all__ = [ "CWithGNULibcTarget", "ExecutableCWithGNULibcTarget", "CudaTarget", "OpenCLTarget", "PyOpenCLTarget", "ISPCTarget", - "ASTBuilderBase", + "PyCudaTarget", "ASTBuilderBase", "Optional", "memoize_on_disk", diff --git a/loopy/target/pycuda.py b/loopy/target/pycuda.py new file mode 100644 index 000000000..ace4706c8 --- /dev/null +++ b/loopy/target/pycuda.py @@ -0,0 +1,351 @@ +"""CUDA target integrated with PyCUDA.""" + +__copyright__ = """ +Copyright (C) 2015 Andreas Kloeckner +Copyright (C) 2022 Kaushik Kulkarni +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +import numpy as np +import pymbolic.primitives as p +import genpy + +from loopy.target.cuda import (CudaTarget, CudaCASTBuilder, + ExpressionToCudaCExpressionMapper) +from loopy.target.python import PythonASTBuilderBase +from typing import Sequence, Tuple +from loopy.codegen import CodeGenerationState +from loopy.codegen.result import CodeGenerationResult + +import logging +logger = logging.getLogger(__name__) + + +# {{{ preamble generator + +def pycuda_preamble_generator(preamble_info): + has_complex = False + + for dtype in preamble_info.seen_dtypes: + if dtype.involves_complex(): + has_complex = True + + if has_complex: + yield ("03_include_complex_header", """ + #include + """) + +# }}} + + +# {{{ expression mapper + +def _get_complex_tmplt_arg(dtype) -> str: + if dtype == np.complex128: + return "double" + elif dtype == np.complex64: + return "float" + else: + raise RuntimeError(f"unsupported complex type {dtype}.") + + +class ExpressionToPyCudaCExpressionMapper(ExpressionToCudaCExpressionMapper): + """ + .. note:: + + - PyCUDA (very conveniently) provides access to complex arithmetic + headers which is not the default in CUDA. To access such additional + features we introduce this mapper. + """ + def wrap_in_typecast_lazy(self, actual_type_func, needed_dtype, s): + if needed_dtype.is_complex(): + return self.wrap_in_typecast(actual_type_func(), needed_dtype, s) + else: + return super().wrap_in_typecast_lazy(actual_type_func, + needed_dtype, s) + + def wrap_in_typecast(self, actual_type, needed_dtype, s): + if not actual_type.is_complex() and needed_dtype.is_complex(): + tmplt_arg = _get_complex_tmplt_arg(needed_dtype.numpy_dtype) + return p.Variable(f"pycuda::complex<{tmplt_arg}>")(s) + else: + return super().wrap_in_typecast_lazy(actual_type, + needed_dtype, s) + + def map_constant(self, expr, type_context): + if isinstance(expr, (complex, np.complexfloating)): + try: + dtype = expr.dtype + except AttributeError: + # (COMPLEX_GUESS_LOGIC) This made it through type 'guessing' in + # type inference, and it was concluded there (search for + # COMPLEX_GUESS_LOGIC in loopy.type_inference), that no + # accuracy was lost by using single precision. + dtype = np.complex64 + else: + tmplt_arg = _get_complex_tmplt_arg(dtype) + + return p.Variable(f"pycuda::complex<{tmplt_arg}>")(self.rec(expr.real, + type_context), + self.rec(expr.imag, + type_context)) + + return super().map_constant(expr, type_context) + +# }}} + + +# {{{ target + +class PyCudaTarget(CudaTarget): + """A code generation target that takes special advantage of :mod:`pycuda` + features such as run-time knowledge of the target device (to generate + warnings) and support for complex numbers. + """ + + # FIXME make prefixes conform to naming rules + # (see Reference: Loopy’s Model of a Kernel) + + host_program_name_prefix = "_lpy_host_" + host_program_name_suffix = "" + + def __init__(self, pycuda_module_name="_lpy_cuda"): + # import pycuda.tools import to populate the TYPE_REGISTRY + import pycuda.tools # noqa: F401 + super().__init__() + self.pycuda_module_name = pycuda_module_name + + # NB: Not including 'device', as that is handled specially here. + hash_fields = CudaTarget.hash_fields + ( + "pycuda_module_name",) + comparison_fields = CudaTarget.comparison_fields + ( + "pycuda_module_name",) + + def get_host_ast_builder(self): + return PyCudaPythonASTBuilder(self) + + def get_device_ast_builder(self): + return PyCudaCASTBuilder(self) + + def get_kernel_executor_cache_key(self, **kwargs): + return (kwargs["entrypoint"],) + + def get_dtype_registry(self): + from pycuda.compyte.dtypes import TYPE_REGISTRY + return TYPE_REGISTRY + + def preprocess_translation_unit_for_passed_args(self, t_unit, epoint, + passed_args_dict): + + # {{{ ValueArgs -> GlobalArgs if passed as array shapes + + from loopy.kernel.data import ValueArg, GlobalArg + import pycuda.gpuarray as cu_np + + knl = t_unit[epoint] + new_args = [] + + for arg in knl.args: + if isinstance(arg, ValueArg): + if (arg.name in passed_args_dict + and isinstance(passed_args_dict[arg.name], cu_np.GPUArray) + and passed_args_dict[arg.name].shape == ()): + arg = GlobalArg(name=arg.name, dtype=arg.dtype, shape=(), + is_output=False, is_input=True) + + new_args.append(arg) + + knl = knl.copy(args=new_args) + + t_unit = t_unit.with_kernel(knl) + + # }}} + + return t_unit + + def get_kernel_executor(self, t_unit, **kwargs): + from loopy.target.pycuda_execution import PyCudaKernelExecutor + + epoint = kwargs.pop("entrypoint") + t_unit = self.preprocess_translation_unit_for_passed_args(t_unit, + epoint, + kwargs) + + return PyCudaKernelExecutor(t_unit, entrypoint=epoint) + +# }}} + + +# {{{ host ast builder + +class PyCudaPythonASTBuilder(PythonASTBuilderBase): + """A Python host AST builder for integration with PyCuda. + """ + + # {{{ code generation guts + + def get_function_definition( + self, codegen_state, codegen_result, + schedule_index: int, function_decl, function_body: genpy.Generable + ) -> genpy.Function: + assert schedule_index == 0 + + from loopy.schedule.tools import get_kernel_arg_info + kai = get_kernel_arg_info(codegen_state.kernel) + + args = ( + ["_lpy_cuda_functions"] + + [arg_name for arg_name in kai.passed_arg_names] + + ["wait_for=()", "allocator=None", "stream=None"]) + + from genpy import (For, Function, Suite, Return, Line, Statement as S) + return Function( + codegen_result.current_program(codegen_state).name, + args, + Suite([ + Line(), + ] + [ + Line(), + function_body, + Line(), + ] + ([ + For("_tv", "_global_temporaries", + # Free global temporaries. + # Zero-size temporaries allocate as None, tolerate that. + S("if _tv is not None: _tv.free()")) + ] if self._get_global_temporaries(codegen_state) else [] + ) + [ + Line(), + Return("_lpy_evt"), + ])) + + def get_function_declaration( + self, codegen_state: CodeGenerationState, + codegen_result: CodeGenerationResult, schedule_index: int + ) -> Tuple[Sequence[Tuple[str, str]], genpy.Generable]: + # no such thing in Python + return [], None + + def _get_global_temporaries(self, codegen_state): + from loopy.kernel.data import AddressSpace + + return sorted( + (tv for tv in codegen_state.kernel.temporary_variables.values() + if tv.address_space == AddressSpace.GLOBAL), + key=lambda tv: tv.name) + + def get_temporary_decls(self, codegen_state, schedule_index): + from genpy import Assign, Comment, Line + + from pymbolic.mapper.stringifier import PREC_NONE + ecm = self.get_expression_to_code_mapper(codegen_state) + + global_temporaries = self._get_global_temporaries(codegen_state) + if not global_temporaries: + return [] + + allocated_var_names = [] + code_lines = [] + code_lines.append(Line()) + code_lines.append(Comment("{{{ allocate global temporaries")) + code_lines.append(Line()) + + for tv in global_temporaries: + if not tv.base_storage: + nbytes_str = ecm(tv.nbytes, PREC_NONE, "i") + allocated_var_names.append(tv.name) + code_lines.append(Assign(tv.name, + f"allocator({nbytes_str})")) + + code_lines.append(Assign("_global_temporaries", "[{tvs}]".format( + tvs=", ".join(tv for tv in allocated_var_names)))) + + code_lines.append(Line()) + code_lines.append(Comment("}}}")) + code_lines.append(Line()) + + return code_lines + + def get_kernel_call(self, + codegen_state, subkernel_name, grid, block): + from genpy import Suite, Assign, Line, Comment, Statement + from pymbolic.mapper.stringifier import PREC_NONE + + from loopy.schedule.tools import get_subkernel_arg_info + skai = get_subkernel_arg_info( + codegen_state.kernel, subkernel_name) + + # make grid/block a 3-tuple + grid = grid + (1,) * (3-len(grid)) + block = block + (1,) * (3-len(block)) + global_size = grid[0] * grid[1] * grid[2] * block[0] * block[1] * block[2] + ecm = self.get_expression_to_code_mapper(codegen_state) + + grid_str = ecm(grid, prec=PREC_NONE, type_context="i") + block_str = ecm(block, prec=PREC_NONE, type_context="i") + global_size_str = ecm(global_size, prec=PREC_NONE, type_context="i") + + return Suite([ + Comment("{{{ launch %s" % subkernel_name), + Line(), + Statement("for _lpy_cu_evt in wait_for: _lpy_cu_evt.synchronize()"), + Line(), + Assign("_lpy_knl", f"_lpy_cuda_functions['{subkernel_name}']"), + Line(), + + Statement(f"if {global_size_str}: _lpy_knl.prepared_async_call(" + f"{grid_str}, {block_str}, " + "stream, " + f"{', '.join(arg_name for arg_name in skai.passed_names)}" + ")",), + Assign("_lpy_evt", "_lpy_cuda.Event().record(stream)"), + Assign("wait_for", "[_lpy_evt]"), + Line(), + Comment("}}}"), + Line(), + ]) + + # }}} + +# }}} + + +# {{{ device ast builder + +class PyCudaCASTBuilder(CudaCASTBuilder): + """A C device AST builder for integration with PyCUDA. + """ + + # {{{ library + + def preamble_generators(self): + return ([pycuda_preamble_generator] + + super().preamble_generators()) + + # }}} + + def get_expression_to_c_expression_mapper(self, codegen_state): + return ExpressionToPyCudaCExpressionMapper(codegen_state) + +# }}} + +# vim: foldmethod=marker diff --git a/loopy/target/pycuda_execution.py b/loopy/target/pycuda_execution.py new file mode 100644 index 000000000..fc8135bc8 --- /dev/null +++ b/loopy/target/pycuda_execution.py @@ -0,0 +1,390 @@ +__copyright__ = """ +Copyright (C) 2012 Andreas Kloeckner +Copyright (C) 2022 Kaushik Kulkarni +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + + +from typing import (Sequence, Tuple, Union, Callable, Any, Optional, + TYPE_CHECKING) +from dataclasses import dataclass + +import numpy as np +from immutables import Map + +from pytools import memoize_method +from pytools.codegen import Indentation, CodeGenerator + +from loopy.types import LoopyType +from loopy.typing import ExpressionT +from loopy.kernel import LoopKernel +from loopy.kernel.data import ArrayArg +from loopy.translation_unit import TranslationUnit +from loopy.schedule.tools import KernelArgInfo +from loopy.target.execution import ( + KernelExecutorBase, ExecutionWrapperGeneratorBase) +import logging +logger = logging.getLogger(__name__) + + +if TYPE_CHECKING: + import pycuda.driver as cuda + + +# {{{ invoker generation + +# /!\ This code runs in a namespace controlled by the user. +# Prefix all auxiliary variables with "_lpy". + + +class PyCudaExecutionWrapperGenerator(ExecutionWrapperGeneratorBase): + """ + Specialized form of the :class:`ExecutionWrapperGeneratorBase` for + pycuda execution + """ + + def __init__(self): + system_args = [ + "_lpy_cuda_functions", "stream=None", "allocator=None", "wait_for=()", + # ignored if options.no_numpy + "out_host=None" + ] + super().__init__(system_args) + + def python_dtype_str_inner(self, dtype): + from pycuda.tools import dtype_to_ctype + # Test for types built into numpy. dtype.isbuiltin does not work: + # https://github.com/numpy/numpy/issues/4317 + # Guided by https://numpy.org/doc/stable/reference/arrays.scalars.html + if issubclass(dtype.type, (np.bool_, np.number)): + name = dtype.name + if dtype.type == np.bool_: + name = "bool8" + return f"_lpy_np.dtype(_lpy_np.{name})" + else: + return ('_lpy_cuda_tools.get_or_register_dtype("%s")' + % dtype_to_ctype(dtype)) + + # {{{ handle non-numpy args + + def handle_non_numpy_arg(self, gen, arg): + gen("if isinstance(%s, _lpy_np.ndarray):" % arg.name) + with Indentation(gen): + gen("# retain originally passed array") + gen(f"_lpy_{arg.name}_np_input = {arg.name}") + gen("# synchronous, nothing to worry about") + gen("%s = _lpy_cuda_array.to_gpu_async(" + "%s, allocator=allocator, stream=stream)" + % (arg.name, arg.name)) + gen("_lpy_encountered_numpy = True") + gen("elif %s is not None:" % arg.name) + with Indentation(gen): + gen("_lpy_encountered_dev = True") + gen("_lpy_%s_np_input = None" % arg.name) + gen("else:") + with Indentation(gen): + gen("_lpy_%s_np_input = None" % arg.name) + + gen("") + + # }}} + + # {{{ handle allocation of unspecified arguments + + def handle_alloc( + self, gen: CodeGenerator, arg: ArrayArg, + strify: Callable[[Union[ExpressionT, Tuple[ExpressionT]]], str], + skip_arg_checks: bool) -> None: + """ + Handle allocation of non-specified arguments for pycuda execution + """ + from pymbolic import var + + from loopy.kernel.array import get_strides + strides = get_strides(arg) + num_axes = len(strides) + + assert arg.dtype is not None + itemsize = arg.dtype.numpy_dtype.itemsize + for i in range(num_axes): + gen("_lpy_ustrides_%d = %s" % (i, strify(strides[i]))) + + if not skip_arg_checks: + for i in range(num_axes): + gen("assert _lpy_ustrides_%d >= 0, " + "\"'%s' has negative stride in axis %d\"" + % (i, arg.name, i)) + + assert isinstance(arg.shape, tuple) + sym_ustrides = tuple( + var("_lpy_ustrides_%d" % i) + for i in range(num_axes)) + sym_shape = tuple(arg.shape[i] for i in range(num_axes)) + + size_expr = (sum(astrd*(alen-1) + for alen, astrd in zip(sym_shape, sym_ustrides)) + + 1) + + gen("_lpy_size = %s" % strify(size_expr)) + sym_strides = tuple(itemsize*s_i for s_i in sym_ustrides) + + dtype_name = self.python_dtype_str(gen, arg.dtype.numpy_dtype) + gen(f"{arg.name} = _lpy_cuda_array.GPUArray({strify(sym_shape)}, " + f"{dtype_name}, strides={strify(sym_strides)}, " + f"gpudata=allocator({strify(itemsize * var('_lpy_size'))}), " + "allocator=allocator)") + + for i in range(num_axes): + gen("del _lpy_ustrides_%d" % i) + gen("del _lpy_size") + gen("") + + # }}} + + def target_specific_preamble(self, gen): + """ + Add default pycuda imports to preamble + """ + gen.add_to_preamble("import numpy as _lpy_np") + gen.add_to_preamble("import pycuda.driver as _lpy_cuda") + gen.add_to_preamble("import pycuda.gpuarray as _lpy_cuda_array") + gen.add_to_preamble("import pycuda.tools as _lpy_cuda_tools") + from loopy.target.c.c_execution import DEF_EVEN_DIV_FUNCTION + gen.add_to_preamble(DEF_EVEN_DIV_FUNCTION) + + def initialize_system_args(self, gen): + """ + Initializes possibly empty system arguments + """ + gen("if allocator is None:") + with Indentation(gen): + gen("allocator =" + " lambda nbytes: _lpy_cuda.mem_alloc(nbytes) if nbytes else 0") + gen("") + + # {{{ generate invocation + + def generate_invocation(self, gen: CodeGenerator, kernel: LoopKernel, + kai: KernelArgInfo, host_program_name: str, args: Sequence[str]) -> None: + arg_list = (["_lpy_cuda_functions"] + + list(args) + + ["stream=stream", "wait_for=wait_for", "allocator=allocator"]) + gen(f"_lpy_evt = {host_program_name}({', '.join(arg_list)})") + + # }}} + + # {{{ generate_output_handler + + def generate_output_handler(self, gen: CodeGenerator, + kernel: LoopKernel, kai: KernelArgInfo) -> None: + options = kernel.options + + if not options.no_numpy: + gen("if out_host is None and (_lpy_encountered_numpy " + "and not _lpy_encountered_dev):") + with Indentation(gen): + gen("out_host = True") + + for arg_name in kai.passed_arg_names: + arg = kernel.arg_dict[arg_name] + if arg.is_output: + np_name = "_lpy_%s_np_input" % arg.name + gen("if out_host or %s is not None:" % np_name) + with Indentation(gen): + gen("%s = %s.get(stream=stream, ary=%s)" + % (arg.name, arg.name, np_name)) + + gen("") + + if options.return_dict: + gen("return _lpy_evt, {%s}" + % ", ".join(f'"{arg_name}": {arg_name}' + for arg_name in kai.passed_arg_names + if kernel.arg_dict[arg_name].is_output)) + else: + out_names = [arg_name for arg_name in kai.passed_arg_names + if kernel.arg_dict[arg_name].is_output] + if out_names: + gen("return _lpy_evt, (%s,)" + % ", ".join(out_names)) + else: + gen("return _lpy_evt, ()") + + # }}} + + def generate_host_code(self, gen, codegen_result): + gen.add_to_preamble(codegen_result.host_code()) + + def get_arg_pass(self, arg): + return "%s.__cuda_array_interface__['data'][0]" % arg.name + +# }}} + + +@dataclass(frozen=True) +class _KernelInfo: + t_unit: TranslationUnit + cuda_functions: Map[str, "cuda.Function"] + invoker: Callable[..., Any] + + +# {{{ kernel executor + + +def _get_arg_dtypes(knl, subkernel_name): + from loopy.schedule.tools import get_subkernel_arg_info + from loopy.kernel.data import ValueArg + + skai = get_subkernel_arg_info(knl, subkernel_name) + arg_dtypes = [] + for arg in skai.passed_names: + if arg in skai.passed_inames: + arg_dtypes.append(knl.index_dtype.numpy_dtype) + elif arg in skai.passed_temporaries: + arg_dtypes.append("P") + else: + assert arg in knl.arg_dict + if isinstance(knl.arg_dict[arg], ValueArg): + arg_dtypes.append(knl.arg_dict[arg].dtype.numpy_dtype) + else: + # Array Arg + arg_dtypes.append("P") + + return arg_dtypes + + +class PyCudaKernelExecutor(KernelExecutorBase): + """ + An object connecting a kernel to a :mod:`pycuda` + for execution. + + .. automethod:: __init__ + .. automethod:: __call__ + """ + + def get_invoker_uncached(self, t_unit, entrypoint, codegen_result): + generator = PyCudaExecutionWrapperGenerator() + return generator(t_unit, entrypoint, codegen_result) + + def get_wrapper_generator(self): + return PyCudaExecutionWrapperGenerator() + + @memoize_method + def translation_unit_info(self, + arg_to_dtype: Optional[Map[str, LoopyType]] = None + ) -> _KernelInfo: + t_unit = self.get_typed_and_scheduled_translation_unit(self.entrypoint, + arg_to_dtype) + + # FIXME: now just need to add the types to the arguments + from loopy.codegen import generate_code_v2 + from loopy.target.execution import get_highlighted_code + codegen_result = generate_code_v2(t_unit) + + dev_code = codegen_result.device_code() + epoint_knl = t_unit[self.entrypoint] + + if t_unit[self.entrypoint].options.write_code: + #FIXME: redirect to "translation unit" level option as well. + output = dev_code + if self.t_unit[self.entrypoint].options.allow_terminal_colors: + output = get_highlighted_code(output) + + if epoint_knl.options.write_code is True: + print(output) + else: + with open(epoint_knl.options.write_code, "w") as outf: + outf.write(output) + + if epoint_knl.options.edit_code: + #FIXME: redirect to "translation unit" level option as well. + from pytools import invoke_editor + dev_code = invoke_editor(dev_code, "code.cu") + + from pycuda.compiler import SourceModule + from loopy.kernel.tools import get_subkernels + + #FIXME: redirect to "translation unit" level option as well. + src_module = SourceModule(dev_code, + options=epoint_knl.options.build_options) + + cuda_functions = Map({name: (src_module + .get_function(name) + .prepare(_get_arg_dtypes(epoint_knl, name)) + ) + for name in get_subkernels(epoint_knl)}) + return _KernelInfo( + t_unit=t_unit, + cuda_functions=cuda_functions, + invoker=self.get_invoker(t_unit, self.entrypoint, codegen_result)) + + def __call__(self, *, + stream=None, allocator=None, wait_for=(), out_host=None, + **kwargs): + """ + :arg allocator: a callable that accepts a byte count and returns + an instance of :class:`pycuda.driver.DeviceAllocation`. Typically + one of :func:`pycuda.driver.mem_alloc` or + :meth:`pycuda.tools.DeviceMemoryPool.allocate`. + :arg wait_for: A sequence of :class:`pycuda.driver.Event` instances + for which to wait before launching the CUDA kernels. + :arg out_host: :class:`bool` + Decides whether output arguments (i.e. arguments + written by the kernel) are to be returned as + :mod:`numpy` arrays. *True* for yes, *False* for no. + + For the default value of *None*, if all (input) array + arguments are :mod:`numpy` arrays, defaults to + returning :mod:`numpy` arrays as well. + + :returns: ``(evt, output)`` where *evt* is a + :class:`pycuda.driver.Event` that is recorded right after the + kernel has been launched and output is a tuple of output arguments + (arguments that are written as part of the kernel). The order is + given by the order of kernel arguments. If this order is + unspecified (such as when kernel arguments are inferred + automatically), enable :attr:`loopy.Options.return_dict` to make + *output* a :class:`dict` instead, with keys of argument names and + values of the returned arrays. + """ + + if "entrypoint" in kwargs: + assert kwargs.pop("entrypoint") == self.entrypoint + from warnings import warn + warn("Obtained a redundant argument 'entrypoint'. This will" + " be an error in 2023.", DeprecationWarning, stacklevel=2) + + if __debug__: + self.check_for_required_array_arguments(kwargs.keys()) + + if self.packing_controller is not None: + kwargs = self.packing_controller(kwargs) + + translation_unit_info = self.translation_unit_info(self.arg_to_dtype(kwargs)) + + return translation_unit_info.invoker( + translation_unit_info.cuda_functions, stream, allocator, wait_for, + out_host, **kwargs) + +# }}} + +# vim: foldmethod=marker diff --git a/setup.cfg b/setup.cfg index 077a856f9..8a938c633 100644 --- a/setup.cfg +++ b/setup.cfg @@ -61,3 +61,6 @@ ignore_missing_imports = True [mypy-IPython.*] ignore_missing_imports = True + +[mypy-pycuda.*] +ignore_missing_imports = True -- GitLab From 67f974f815363c34bf0f2bd98ff34a927fd71d41 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sun, 7 Aug 2022 12:35:07 -0500 Subject: [PATCH 5/8] adds pycuda-specific callables --- loopy/target/pycuda.py | 67 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/loopy/target/pycuda.py b/loopy/target/pycuda.py index ace4706c8..cce1f0af4 100644 --- a/loopy/target/pycuda.py +++ b/loopy/target/pycuda.py @@ -35,6 +35,9 @@ from loopy.target.python import PythonASTBuilderBase from typing import Sequence, Tuple from loopy.codegen import CodeGenerationState from loopy.codegen.result import CodeGenerationResult +from loopy.target.c import CMathCallable +from loopy.diagnostic import LoopyError +from loopy.types import NumpyType import logging logger = logging.getLogger(__name__) @@ -57,6 +60,64 @@ def pycuda_preamble_generator(preamble_info): # }}} +# {{{ PyCudaCallable + +class PyCudaCallable(CMathCallable): + def with_types(self, arg_id_to_dtype, callables_table): + if any(dtype.is_complex() for dtype in arg_id_to_dtype.values()): + if self.name in ["abs", "real", "imag"]: + if not (set(arg_id_to_dtype) <= {0, -1}): + raise LoopyError(f"'{self.name}' takes only one argument") + if arg_id_to_dtype.get(0) is None: + # not specialized enough + return (self.copy(arg_id_to_dtype=arg_id_to_dtype), + callables_table) + else: + real_dtype = np.empty(0, + arg_id_to_dtype[0].numpy_dtype).real.dtype + arg_id_to_dtype = arg_id_to_dtype.copy() + arg_id_to_dtype[-1] = NumpyType(real_dtype) + return (self.copy(arg_id_to_dtype=arg_id_to_dtype, + name_in_target=self.name), + callables_table) + elif self.name in ["sqrt", "conj", + "sin", "cos", "tan", + "sinh", "cosh", "tanh", "exp", + "log", "log10"]: + if not (set(arg_id_to_dtype) <= {0, -1}): + raise LoopyError(f"'{self.name}' takes only one argument") + if arg_id_to_dtype.get(0) is None: + # not specialized enough + return (self.copy(arg_id_to_dtype=arg_id_to_dtype), + callables_table) + else: + arg_id_to_dtype = arg_id_to_dtype.copy() + arg_id_to_dtype[-1] = arg_id_to_dtype[0] + return (self.copy(arg_id_to_dtype=arg_id_to_dtype, + name_in_target=self.name), + callables_table) + else: + raise LoopyError(f"'{self.name}' does not take complex" + " arguments.") + else: + if self.name in ["real", "imag", "conj"]: + if arg_id_to_dtype.get(0): + raise NotImplementedError("'{self.name}' for real arguments" + ", not yet supported.") + return super().with_types(arg_id_to_dtype, callables_table) + + +def get_pycuda_callables(): + cmath_ids = ["abs", "acos", "asin", "atan", "cos", "cosh", "sin", + "sinh", "pow", "atan2", "tanh", "exp", "log", "log10", + "sqrt", "ceil", "floor", "max", "min", "fmax", "fmin", + "fabs", "tan", "erf", "erfc", "isnan", "real", "imag", + "conj"] + return {id_: PyCudaCallable(id_) for id_ in cmath_ids} + +# }}} + + # {{{ expression mapper def _get_complex_tmplt_arg(dtype) -> str: @@ -341,6 +402,12 @@ class PyCudaCASTBuilder(CudaCASTBuilder): return ([pycuda_preamble_generator] + super().preamble_generators()) + @property + def known_callables(self): + callables = super().known_callables + callables.update(get_pycuda_callables()) + return callables + # }}} def get_expression_to_c_expression_mapper(self, codegen_state): -- GitLab From 6c88c8538ea7de838f99b24183800fff31a9d89e Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Fri, 5 Aug 2022 00:49:22 -0500 Subject: [PATCH 6/8] test pycuda execution support --- .gitlab-ci.yml | 2 +- test/test_pycuda_invoker.py | 261 ++++++++++++++++++++++++++++++++++++ 2 files changed, 262 insertions(+), 1 deletion(-) create mode 100644 test/test_pycuda_invoker.py diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 1f9786949..49cd96ee6 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -23,7 +23,7 @@ Pytest Nvidia Titan V: stage: test script: - export PYOPENCL_TEST=nvi:titan - - export EXTRA_INSTALL="pybind11 numpy mako" + - export EXTRA_INSTALL="pybind11 numpy mako git+https://github.com/inducer/pycuda.git" - export LOOPY_NO_CACHE=1 - source /opt/enable-intel-cl.sh - curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/build-and-test-py-project.sh diff --git a/test/test_pycuda_invoker.py b/test/test_pycuda_invoker.py new file mode 100644 index 000000000..8d7afc013 --- /dev/null +++ b/test/test_pycuda_invoker.py @@ -0,0 +1,261 @@ +__copyright__ = "Copyright (C) 2022 Kaushik Kulkarni" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +import sys +import numpy as np +import loopy as lp +import pytest +pytest.importorskip("pycuda") +import pycuda.gpuarray as cu_np +import itertools + +import logging +logger = logging.getLogger(__name__) + +try: + import faulthandler +except ImportError: + pass +else: + faulthandler.enable() + +from typing import Tuple, Any +from pycuda.tools import init_cuda_context_fixture +from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 # noqa + + +@pytest.fixture(autouse=True) +def init_cuda_context(): + yield from init_cuda_context_fixture() + + +def get_random_array(rng, shape: Tuple[int, ...], dtype: np.dtype[Any]): + if np.issubdtype(dtype, np.complexfloating): + subdtype = np.empty(0, dtype=dtype).real.dtype + return (get_random_array(rng, shape, subdtype) + + dtype.type(1j) * get_random_array(rng, shape, subdtype)) + else: + assert np.issubdtype(dtype, np.floating) + return rng.random(shape, dtype=dtype) + + +def test_pycuda_invoker(): + m = 5 + n = 6 + + knl = lp.make_kernel( + "{[i, j]: 0<=i tmp[i] = sin(x[i]) + z[i] = 2 * tmp[i] + """, + target=lp.PyCudaTarget()) + knl = lp.set_temporary_address_space(knl, "tmp", lp.AddressSpace.GLOBAL) + + evt, (out,) = knl(x=x, out_host=False) + np.testing.assert_allclose(2*np.sin(x), out.get(), rtol=1e-6) + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_multi_entrypoints(dtype): + rng = np.random.default_rng(seed=314) + x = rng.random(42, dtype=dtype) + + knl1 = lp.make_kernel( + "{[i]: 0<=i tmp[i] = 21*sin(x[i]) + 864.5*cos(y[i]) + z[i] = 2 * tmp[i] + """, + [lp.GlobalArg("x,y", + offset=lp.auto, shape=lp.auto), + ...], + target=lp.PyCudaTarget()) + knl = lp.set_temporary_address_space(knl, "tmp", lp.AddressSpace.GLOBAL) + + evt, (out,) = knl(x=x, y=y) + np.testing.assert_allclose(42*np.sin(x) + 1729*np.cos(y), out) + + +@pytest.mark.parametrize("dtype,rtol", [(np.complex64, 1e-6), + (np.complex128, 1e-14), + (np.float32, 1e-6), + (np.float64, 1e-14)]) +def test_sum_of_array(dtype, rtol): + # Reported by Mit Kotak + rng = np.random.default_rng(seed=0) + knl = lp.make_kernel( + "{[i]: 0 <= i < N}", + """ + out = sum(i, x[i]) + """, + target=lp.PyCudaTarget()) + x = get_random_array(rng, (42,), np.dtype(dtype)) + evt, (out,) = knl(x=x) + np.testing.assert_allclose(np.sum(x), out, rtol=rtol) + + +@pytest.mark.parametrize("dtype,rtol", [(np.complex64, 1e-6), + (np.complex128, 1e-14), + (np.float32, 1e-6), + (np.float64, 1e-14)]) +def test_int_pow(dtype, rtol): + rng = np.random.default_rng(seed=0) + knl = lp.make_kernel( + "{[i]: 0 <= i < N}", + """ + out[i] = x[i] ** i + """, + target=lp.PyCudaTarget()) + x = get_random_array(rng, (10,), np.dtype(dtype)) + evt, (out,) = knl(x=x) + np.testing.assert_allclose(x ** np.arange(10, dtype=np.int32), out, + rtol=rtol) + + +if __name__ == "__main__": + if len(sys.argv) > 1: + exec(sys.argv[1]) + else: + from pytest import main + main([__file__]) + +# vim: foldmethod=marker -- GitLab From 178cbe5ea5501386e293c9370ff699dfac215c18 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sun, 7 Aug 2022 12:42:41 -0500 Subject: [PATCH 7/8] test math functions on complex dtyped args --- test/test_pycuda_invoker.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/test/test_pycuda_invoker.py b/test/test_pycuda_invoker.py index 8d7afc013..1c4449fe3 100644 --- a/test/test_pycuda_invoker.py +++ b/test/test_pycuda_invoker.py @@ -251,6 +251,26 @@ def test_int_pow(dtype, rtol): rtol=rtol) +@pytest.mark.parametrize("dtype", [np.complex64, np.complex128, + np.float32, np.float64]) +@pytest.mark.parametrize("func", ["abs", "sqrt", + "sin", "cos", "tan", + "sinh", "cosh", "tanh", + "exp", "log", "log10"]) +def test_math_functions(dtype, func): + rng = np.random.default_rng(seed=0) + knl = lp.make_kernel( + "{[i]: 0 <= i < N}", + f""" + y[i] = {func}(x[i]) + """, + target=lp.PyCudaTarget()) + x = get_random_array(rng, (42,), np.dtype(dtype)) + _, (out,) = knl(x=x) + np.testing.assert_allclose(getattr(np, func)(x), + out, rtol=1e-6) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) -- GitLab From e694703c8383e48d2da04825380c5fe572760c45 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Mon, 8 Aug 2022 16:29:58 -0500 Subject: [PATCH 8/8] adds PyCudaWithPackedArgsTarget --- loopy/__init__.py | 4 +- loopy/target/pycuda.py | 242 +++++++++++++++++++++++++++++++ loopy/target/pycuda_execution.py | 53 ++++--- test/test_pycuda_invoker.py | 85 ++++++++--- 4 files changed, 337 insertions(+), 47 deletions(-) diff --git a/loopy/__init__.py b/loopy/__init__.py index c6c5b1dde..2ea7bf1e9 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -157,7 +157,7 @@ from loopy.target.c import (CFamilyTarget, CTarget, ExecutableCTarget, from loopy.target.cuda import CudaTarget from loopy.target.opencl import OpenCLTarget from loopy.target.pyopencl import PyOpenCLTarget -from loopy.target.pycuda import PyCudaTarget +from loopy.target.pycuda import PyCudaTarget, PyCudaWithPackedArgsTarget from loopy.target.ispc import ISPCTarget from loopy.tools import Optional, t_unit_to_python, memoize_on_disk @@ -304,7 +304,7 @@ __all__ = [ "CWithGNULibcTarget", "ExecutableCWithGNULibcTarget", "CudaTarget", "OpenCLTarget", "PyOpenCLTarget", "ISPCTarget", - "PyCudaTarget", "ASTBuilderBase", + "PyCudaTarget", "PyCudaWithPackedArgsTarget", "ASTBuilderBase", "Optional", "memoize_on_disk", diff --git a/loopy/target/pycuda.py b/loopy/target/pycuda.py index cce1f0af4..86995c08b 100644 --- a/loopy/target/pycuda.py +++ b/loopy/target/pycuda.py @@ -38,6 +38,9 @@ from loopy.codegen.result import CodeGenerationResult from loopy.target.c import CMathCallable from loopy.diagnostic import LoopyError from loopy.types import NumpyType +from loopy.codegen import CodeGenerationState +from loopy.codegen.result import CodeGenerationResult +from cgen import Generable import logging logger = logging.getLogger(__name__) @@ -253,6 +256,25 @@ class PyCudaTarget(CudaTarget): return PyCudaKernelExecutor(t_unit, entrypoint=epoint) + +class PyCudaWithPackedArgsTarget(PyCudaTarget): + + def get_kernel_executor(self, t_unit, **kwargs): + from loopy.target.pycuda_execution import PyCudaWithPackedArgsKernelExecutor + + epoint = kwargs.pop("entrypoint") + t_unit = self.preprocess_translation_unit_for_passed_args(t_unit, + epoint, + kwargs) + + return PyCudaWithPackedArgsKernelExecutor(t_unit, entrypoint=epoint) + + def get_host_ast_builder(self): + return PyCudaWithPackedArgsPythonASTBuilder(self) + + def get_device_ast_builder(self): + return PyCudaWithPackedArgsCASTBuilder(self) + # }}} @@ -387,6 +409,83 @@ class PyCudaPythonASTBuilder(PythonASTBuilderBase): # }}} + +class PyCudaWithPackedArgsPythonASTBuilder(PyCudaPythonASTBuilder): + + def get_kernel_call(self, + codegen_state, subkernel_name, grid, block): + from genpy import Suite, Assign, Line, Comment, Statement + from pymbolic.mapper.stringifier import PREC_NONE + from loopy.kernel.data import ValueArg, ArrayArg + + from loopy.schedule.tools import get_subkernel_arg_info + kernel = codegen_state.kernel + skai = get_subkernel_arg_info(kernel, subkernel_name) + + # make grid/block a 3-tuple + grid = grid + (1,) * (3-len(grid)) + block = block + (1,) * (3-len(block)) + ecm = self.get_expression_to_code_mapper(codegen_state) + + grid_str = ecm(grid, prec=PREC_NONE, type_context="i") + block_str = ecm(block, prec=PREC_NONE, type_context="i") + + struct_format = [] + for arg_name in skai.passed_names: + if arg_name in codegen_state.kernel.all_inames(): + struct_format.append(kernel.index_dtype.numpy_dtype.char) + if kernel.index_dtype.numpy_dtype.itemsize < 8: + struct_format.append("x") + elif arg_name in codegen_state.kernel.temporary_variables: + struct_format.append("P") + else: + knl_arg = codegen_state.kernel.arg_dict[arg_name] + if isinstance(knl_arg, ValueArg): + struct_format.append(knl_arg.dtype.numpy_dtype.char) + if knl_arg.dtype.numpy_dtype.itemsize < 8: + struct_format.append("x") + else: + struct_format.append("P") + + def _arg_cast(arg_name: str) -> str: + if arg_name in skai.passed_inames: + return ("_lpy_np" + f".{codegen_state.kernel.index_dtype.numpy_dtype.name}" + f"({arg_name})") + elif arg_name in skai.passed_temporaries: + return f"_lpy_np.uintp(int({arg_name}))" + else: + knl_arg = codegen_state.kernel.arg_dict[arg_name] + if isinstance(knl_arg, ValueArg): + assert knl_arg.dtype is not None + return f"_lpy_np.{knl_arg.dtype.numpy_dtype.name}({arg_name})" + else: + assert isinstance(knl_arg, ArrayArg) + return f"_lpy_np.uintp(int({arg_name}))" + + return Suite([ + Comment("{{{ launch %s" % subkernel_name), + Line(), + Statement("for _lpy_cu_evt in wait_for: _lpy_cu_evt.synchronize()"), + Line(), + Assign("_lpy_knl", f"_lpy_cuda_functions['{subkernel_name}']"), + Line(), + Assign("_lpy_args_on_dev", f"allocator({len(skai.passed_names)*8})"), + Assign("_lpy_args_on_host", + f"_lpy_struct.pack('{''.join(struct_format)}'," + f"{','.join(_arg_cast(arg) for arg in skai.passed_names)})"), + Statement("_lpy_cuda.memcpy_htod(_lpy_args_on_dev, _lpy_args_on_host)"), + Line(), + Statement("_lpy_knl.prepared_async_call(" + f"{grid_str}, {block_str}, " + "stream, _lpy_args_on_dev)"), + Assign("_lpy_evt", "_lpy_cuda.Event().record(stream)"), + Assign("wait_for", "[_lpy_evt]"), + Line(), + Comment("}}}"), + Line(), + ]) + # }}} @@ -413,6 +512,149 @@ class PyCudaCASTBuilder(CudaCASTBuilder): def get_expression_to_c_expression_mapper(self, codegen_state): return ExpressionToPyCudaCExpressionMapper(codegen_state) + +class PyCudaWithPackedArgsCASTBuilder(PyCudaCASTBuilder): + def arg_struct_name(self, kernel_name: str): + return f"_lpy_{kernel_name}_packed_args" + + def get_function_declaration(self, codegen_state, codegen_result, + schedule_index): + from loopy.target.c import FunctionDeclarationWrapper + from cgen import FunctionDeclaration, Value, Pointer, Extern + from cgen.cuda import CudaGlobal, CudaDevice, CudaLaunchBounds + + kernel = codegen_state.kernel + + assert kernel.linearization is not None + name = codegen_result.current_program(codegen_state).name + arg_type = self.arg_struct_name(name) + + if self.target.fortran_abi: + name += "_" + + fdecl = FunctionDeclaration( + Value("void", name), + [Pointer(Value(arg_type, "_lpy_args"))]) + + if codegen_state.is_entrypoint: + fdecl = CudaGlobal(fdecl) + if self.target.extern_c: + fdecl = Extern("C", fdecl) + + from loopy.schedule import get_insn_ids_for_block_at + _, lsize = kernel.get_grid_sizes_for_insn_ids_as_exprs( + get_insn_ids_for_block_at(kernel.linearization, schedule_index), + codegen_state.callables_table) + + from loopy.symbolic import get_dependencies + if not get_dependencies(lsize): + # Sizes can't have parameter dependencies if they are + # to be used in static thread block size. + from pytools import product + nthreads = product(lsize) + + fdecl = CudaLaunchBounds(nthreads, fdecl) + + return [], FunctionDeclarationWrapper(fdecl) + else: + return [], CudaDevice(fdecl) + + def get_function_definition( + self, codegen_state: CodeGenerationState, + codegen_result: CodeGenerationResult, + schedule_index: int, function_decl: Generable, function_body: Generable + ) -> Generable: + from typing import cast + from loopy.target.c import generate_array_literal + from loopy.schedule import CallKernel + from loopy.schedule.tools import get_subkernel_arg_info + from loopy.kernel.data import ValueArg, AddressSpace + from cgen import (FunctionBody, + Module as Collection, + Initializer, + Line, Value, Pointer, Struct as GenerableStruct) + kernel = codegen_state.kernel + assert kernel.linearization is not None + assert isinstance(kernel.linearization[schedule_index], CallKernel) + kernel_name = (cast(CallKernel, + kernel.linearization[schedule_index]) + .kernel_name) + + skai = get_subkernel_arg_info(kernel, kernel_name) + + result = [] + + # We only need to write declarations for global variables with + # the first device program. `is_first_dev_prog` determines + # whether this is the first device program in the schedule. + is_first_dev_prog = codegen_state.is_generating_device_code + for i in range(schedule_index): + if isinstance(kernel.linearization[i], CallKernel): + is_first_dev_prog = False + break + if is_first_dev_prog: + for tv in sorted( + kernel.temporary_variables.values(), + key=lambda key_tv: key_tv.name): + + if tv.address_space == AddressSpace.GLOBAL and ( + tv.initializer is not None): + assert tv.read_only + + decl = self.wrap_global_constant( + self.get_temporary_var_declarator(codegen_state, tv)) + + if tv.initializer is not None: + decl = Initializer(decl, generate_array_literal( + codegen_state, tv, tv.initializer)) + + result.append(decl) + + # {{{ declare+unpack the struct type + + struct_fields = [] + + for arg_name in skai.passed_names: + if arg_name in skai.passed_inames: + struct_fields.append( + Value(self.target.dtype_to_typename(kernel.index_dtype), + f"{arg_name}, __padding_{arg_name}")) + elif arg_name in skai.passed_temporaries: + tv = kernel.temporary_variables[arg_name] + struct_fields.append(Pointer( + Value(self.target.dtype_to_typename(tv.dtype), arg_name))) + else: + knl_arg = kernel.arg_dict[arg_name] + if isinstance(knl_arg, ValueArg): + struct_fields.append( + Value(self.target.dtype_to_typename(knl_arg.dtype), + f"{arg_name}, __padding_{arg_name}")) + else: + struct_fields.append( + Pointer(Value(self.target.dtype_to_typename(knl_arg.dtype), + arg_name))) + + function_body.insert(0, Line()) + for arg_name in skai.passed_names[::-1]: + function_body.insert(0, Initializer( + self.arg_to_cgen_declarator( + kernel, arg_name, + arg_name in kernel.get_written_variables()), + f"_lpy_args->{arg_name}" + )) + + # }}} + + fbody = FunctionBody(function_decl, function_body) + + return Collection([*result, + Line(), + GenerableStruct(self.arg_struct_name(kernel_name), + struct_fields), + Line(), + fbody]) + + # }}} # vim: foldmethod=marker diff --git a/loopy/target/pycuda_execution.py b/loopy/target/pycuda_execution.py index fc8135bc8..1da99157b 100644 --- a/loopy/target/pycuda_execution.py +++ b/loopy/target/pycuda_execution.py @@ -168,6 +168,7 @@ class PyCudaExecutionWrapperGenerator(ExecutionWrapperGeneratorBase): gen.add_to_preamble("import pycuda.driver as _lpy_cuda") gen.add_to_preamble("import pycuda.gpuarray as _lpy_cuda_array") gen.add_to_preamble("import pycuda.tools as _lpy_cuda_tools") + gen.add_to_preamble("import struct as _lpy_struct") from loopy.target.c.c_execution import DEF_EVEN_DIV_FUNCTION gen.add_to_preamble(DEF_EVEN_DIV_FUNCTION) @@ -249,29 +250,6 @@ class _KernelInfo: # {{{ kernel executor - -def _get_arg_dtypes(knl, subkernel_name): - from loopy.schedule.tools import get_subkernel_arg_info - from loopy.kernel.data import ValueArg - - skai = get_subkernel_arg_info(knl, subkernel_name) - arg_dtypes = [] - for arg in skai.passed_names: - if arg in skai.passed_inames: - arg_dtypes.append(knl.index_dtype.numpy_dtype) - elif arg in skai.passed_temporaries: - arg_dtypes.append("P") - else: - assert arg in knl.arg_dict - if isinstance(knl.arg_dict[arg], ValueArg): - arg_dtypes.append(knl.arg_dict[arg].dtype.numpy_dtype) - else: - # Array Arg - arg_dtypes.append("P") - - return arg_dtypes - - class PyCudaKernelExecutor(KernelExecutorBase): """ An object connecting a kernel to a :mod:`pycuda` @@ -288,6 +266,27 @@ class PyCudaKernelExecutor(KernelExecutorBase): def get_wrapper_generator(self): return PyCudaExecutionWrapperGenerator() + def _get_arg_dtypes(self, knl, subkernel_name): + from loopy.schedule.tools import get_subkernel_arg_info + from loopy.kernel.data import ValueArg + + skai = get_subkernel_arg_info(knl, subkernel_name) + arg_dtypes = [] + for arg in skai.passed_names: + if arg in skai.passed_inames: + arg_dtypes.append(knl.index_dtype.numpy_dtype) + elif arg in skai.passed_temporaries: + arg_dtypes.append("P") + else: + assert arg in knl.arg_dict + if isinstance(knl.arg_dict[arg], ValueArg): + arg_dtypes.append(knl.arg_dict[arg].dtype.numpy_dtype) + else: + # Array Arg + arg_dtypes.append("P") + + return arg_dtypes + @memoize_method def translation_unit_info(self, arg_to_dtype: Optional[Map[str, LoopyType]] = None @@ -329,7 +328,7 @@ class PyCudaKernelExecutor(KernelExecutorBase): cuda_functions = Map({name: (src_module .get_function(name) - .prepare(_get_arg_dtypes(epoint_knl, name)) + .prepare(self._get_arg_dtypes(epoint_knl, name)) ) for name in get_subkernels(epoint_knl)}) return _KernelInfo( @@ -385,6 +384,12 @@ class PyCudaKernelExecutor(KernelExecutorBase): translation_unit_info.cuda_functions, stream, allocator, wait_for, out_host, **kwargs) + +class PyCudaWithPackedArgsKernelExecutor(PyCudaKernelExecutor): + + def _get_arg_dtypes(self, knl, subkernel_name): + return ["P"] + # }}} # vim: foldmethod=marker diff --git a/test/test_pycuda_invoker.py b/test/test_pycuda_invoker.py index 1c4449fe3..930336636 100644 --- a/test/test_pycuda_invoker.py +++ b/test/test_pycuda_invoker.py @@ -58,7 +58,9 @@ def get_random_array(rng, shape: Tuple[int, ...], dtype: np.dtype[Any]): return rng.random(shape, dtype=dtype) -def test_pycuda_invoker(): +@pytest.mark.parametrize("target", [lp.PyCudaTarget(), + lp.PyCudaWithPackedArgsTarget()]) +def test_pycuda_invoker(target): m = 5 n = 6 @@ -67,7 +69,7 @@ def test_pycuda_invoker(): """ y[i, j] = i+j """, - target=lp.PyCudaTarget()) + target=target) knl = lp.split_iname(knl, "i", 5, inner_tag="l.0", outer_tag="g.0") evt, (out,) = knl(n=n, m=m) @@ -78,7 +80,9 @@ def test_pycuda_invoker(): ) -def test_gbarrier(): +@pytest.mark.parametrize("target", [lp.PyCudaTarget(), + lp.PyCudaWithPackedArgsTarget()]) +def test_gbarrier(target): n = 5 knl = lp.make_kernel( "{[i, j]: 0<=i,j tmp[i] = sin(x[i]) z[i] = 2 * tmp[i] """, - target=lp.PyCudaTarget()) + target=target) knl = lp.set_temporary_address_space(knl, "tmp", lp.AddressSpace.GLOBAL) evt, (out,) = knl(x=x, out_host=False) np.testing.assert_allclose(2*np.sin(x), out.get(), rtol=1e-6) +@pytest.mark.parametrize("target", [lp.PyCudaTarget(), + lp.PyCudaWithPackedArgsTarget()]) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) -def test_multi_entrypoints(dtype): +def test_multi_entrypoints(target, dtype): rng = np.random.default_rng(seed=314) x = rng.random(42, dtype=dtype) @@ -172,7 +184,7 @@ def test_multi_entrypoints(dtype): z[i] = sin(x[i]) """, name="mysin", - target=lp.PyCudaTarget()) + target=target) knl1 = lp.add_dtypes(knl1, {"x": dtype}) knl2 = lp.make_kernel( @@ -181,7 +193,7 @@ def test_multi_entrypoints(dtype): z[i] = cos(x[i]) """, name="mycos", - target=lp.PyCudaTarget()) + target=target) knl2 = lp.add_dtypes(knl2, {"x": dtype}) knl = lp.merge([knl1, knl2]) @@ -193,7 +205,9 @@ def test_multi_entrypoints(dtype): np.testing.assert_allclose(np.sin(x), out, rtol=1e-6) -def test_global_arg_with_offsets(): +@pytest.mark.parametrize("target", [lp.PyCudaTarget(), + lp.PyCudaWithPackedArgsTarget()]) +def test_global_arg_with_offsets(target): rng = np.random.default_rng(seed=314) x = rng.random(104) @@ -208,18 +222,20 @@ def test_global_arg_with_offsets(): [lp.GlobalArg("x,y", offset=lp.auto, shape=lp.auto), ...], - target=lp.PyCudaTarget()) + target=target) knl = lp.set_temporary_address_space(knl, "tmp", lp.AddressSpace.GLOBAL) evt, (out,) = knl(x=x, y=y) np.testing.assert_allclose(42*np.sin(x) + 1729*np.cos(y), out) +@pytest.mark.parametrize("target", [lp.PyCudaTarget(), + lp.PyCudaWithPackedArgsTarget()]) @pytest.mark.parametrize("dtype,rtol", [(np.complex64, 1e-6), (np.complex128, 1e-14), (np.float32, 1e-6), (np.float64, 1e-14)]) -def test_sum_of_array(dtype, rtol): +def test_sum_of_array(target, dtype, rtol): # Reported by Mit Kotak rng = np.random.default_rng(seed=0) knl = lp.make_kernel( @@ -227,50 +243,77 @@ def test_sum_of_array(dtype, rtol): """ out = sum(i, x[i]) """, - target=lp.PyCudaTarget()) + target=target) x = get_random_array(rng, (42,), np.dtype(dtype)) evt, (out,) = knl(x=x) np.testing.assert_allclose(np.sum(x), out, rtol=rtol) +@pytest.mark.parametrize("target", [lp.PyCudaTarget(), + lp.PyCudaWithPackedArgsTarget()]) @pytest.mark.parametrize("dtype,rtol", [(np.complex64, 1e-6), (np.complex128, 1e-14), (np.float32, 1e-6), (np.float64, 1e-14)]) -def test_int_pow(dtype, rtol): +def test_int_pow(target, dtype, rtol): rng = np.random.default_rng(seed=0) knl = lp.make_kernel( "{[i]: 0 <= i < N}", """ out[i] = x[i] ** i """, - target=lp.PyCudaTarget()) + target=target) x = get_random_array(rng, (10,), np.dtype(dtype)) evt, (out,) = knl(x=x) np.testing.assert_allclose(x ** np.arange(10, dtype=np.int32), out, rtol=rtol) +@pytest.mark.parametrize("target", [lp.PyCudaTarget(), + lp.PyCudaWithPackedArgsTarget()]) @pytest.mark.parametrize("dtype", [np.complex64, np.complex128, np.float32, np.float64]) @pytest.mark.parametrize("func", ["abs", "sqrt", "sin", "cos", "tan", "sinh", "cosh", "tanh", "exp", "log", "log10"]) -def test_math_functions(dtype, func): +def test_math_functions(target, dtype, func): rng = np.random.default_rng(seed=0) knl = lp.make_kernel( "{[i]: 0 <= i < N}", f""" y[i] = {func}(x[i]) """, - target=lp.PyCudaTarget()) + target=target) x = get_random_array(rng, (42,), np.dtype(dtype)) _, (out,) = knl(x=x) np.testing.assert_allclose(getattr(np, func)(x), out, rtol=1e-6) +def test_pycuda_packargs_tgt_avoids_param_space_overflow(): + from pymbolic.primitives import Sum + from loopy.symbolic import parse + + nargs = 1_000 + rng = np.random.default_rng(32) + knl = lp.make_kernel( + "{[i]: 0<=i 1: exec(sys.argv[1]) -- GitLab