diff --git a/loopy/__init__.py b/loopy/__init__.py index 9c16200f37f18875f2ccfe6546d6e7cfc2ea0f2a..aac8bc67442cb2ee1964a03f773dc68168478e9d 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -22,7 +22,8 @@ class LoopyAdvisory(UserWarning): from loopy.kernel import ScalarArg, GlobalArg, ArrayArg, ConstantArg, ImageArg from loopy.kernel import (AutoFitLocalIndexTag, get_dot_dependency_graph, - LoopKernel, Instruction, default_function_mangler, single_arg_function_mangler, + LoopKernel, Instruction, + default_function_mangler, single_arg_function_mangler, opencl_function_mangler, default_preamble_generator) from loopy.creation import make_kernel from loopy.reduction import register_reduction_parser @@ -38,6 +39,7 @@ __all__ = ["ScalarArg", "GlobalArg", "ArrayArg", "ConstantArg", "ImageArg", "LoopKernel", "Instruction", "default_function_mangler", "single_arg_function_mangler", + "opencl_function_mangler", "opencl_symbol_mangler", "default_preamble_generator", "make_kernel", "register_reduction_parser", diff --git a/loopy/codegen/expression.py b/loopy/codegen/expression.py index 9e706e1ca8f6ccf9ad093416d47c4ae067a9f7e8..e522b792934bac5c1b77b42e7aa8626b25de0baf 100644 --- a/loopy/codegen/expression.py +++ b/loopy/codegen/expression.py @@ -61,8 +61,7 @@ class TypeInferenceMapper(CombineMapper): mangle_result = self.kernel.mangle_function(identifier, arg_dtypes) if mangle_result is not None: - result_dtype, c_name = mangle_result - return result_dtype + return mangle_result[0] raise RuntimeError("no type inference information on " "function '%s'" % identifier) @@ -107,6 +106,16 @@ class TypeInferenceMapper(CombineMapper): # }}} +def perform_cast(ccm, expr, expr_dtype, target_dtype): + # detect widen-to-complex, account for it. + if (ccm.allow_complex + and target_dtype.kind == "c" + and expr_dtype.kind != "c"): + from pymbolic import var + expr = var("%s_fromreal" % ccm.complex_type_name(target_dtype))(expr) + + return expr + # {{{ C code mapper class LoopyCCodeMapper(CCodeMapper): @@ -302,20 +311,34 @@ class LoopyCCodeMapper(CCodeMapper): identifier = identifier.name c_name = identifier - arg_dtypes = tuple(self.infer_type(par) for par in expr.parameters) + par_dtypes = tuple(self.infer_type(par) for par in expr.parameters) - mangle_result = self.kernel.mangle_function(identifier, arg_dtypes) + parameters = expr.parameters + + mangle_result = self.kernel.mangle_function(identifier, par_dtypes) if mangle_result is not None: - result_dtype, c_name = mangle_result + if len(mangle_result) == 2: + result_dtype, c_name = mangle_result + elif len(mangle_result) == 3: + result_dtype, c_name, arg_tgt_dtypes = mangle_result + + parameters = [ + perform_cast(self, par, par_dtype, tgt_dtype) + for par, par_dtype, tgt_dtype in zip( + parameters, par_dtypes, arg_tgt_dtypes)] + else: + raise RuntimeError("result of function mangler " + "for function '%s' not understood" + % identifier) - self.seen_functions.add((identifier, c_name, arg_dtypes)) + self.seen_functions.add((identifier, c_name, par_dtypes)) if c_name is None: raise RuntimeError("unable to find C name for function identifier '%s'" % identifier) return self.format("%s(%s)", - c_name, self.join_rec(", ", expr.parameters, PREC_NONE)) + c_name, self.join_rec(", ", parameters, PREC_NONE)) # {{{ deal with complex-valued variables diff --git a/loopy/codegen/instruction.py b/loopy/codegen/instruction.py index ecb500d901e7e101d78711e7ea22cb0dcd81c1d8..46a29fc50ce27d0d5168c4c1c33d9724a338f6f7 100644 --- a/loopy/codegen/instruction.py +++ b/loopy/codegen/instruction.py @@ -11,14 +11,9 @@ def generate_instruction_code(kernel, insn, codegen_state): expr = insn.expression - if ccm.allow_complex: - # detect widen-to-complex in assignment, account for it. - expr_dtype = ccm.infer_type(expr) - var_dtype = kernel.get_var_descriptor(insn.get_assignee_var_name()).dtype - - if var_dtype.kind == "c" and expr_dtype.kind != "c": - from pymbolic import var - expr = var("%s_fromreal" % ccm.complex_type_name(var_dtype))(expr) + from loopy.codegen.expression import perform_cast + expr = perform_cast(ccm, expr, expr_dtype=ccm.infer_type(expr), + target_dtype=kernel.get_var_descriptor(insn.get_assignee_var_name()).dtype) from cgen import Assign insn_code = Assign(ccm(insn.assignee), ccm(expr)) diff --git a/loopy/kernel.py b/loopy/kernel.py index e5e15f995b6bdc8e3e9aa4f0eceba3b636ab3b47..30509831529c6190d055a9df7bd061b2766e0c86 100644 --- a/loopy/kernel.py +++ b/loopy/kernel.py @@ -434,6 +434,28 @@ def default_function_mangler(name, arg_dtypes): return None +def opencl_function_mangler(name, arg_dtypes): + if name == "atan2" and len(arg_dtypes) == 2: + return arg_dtypes[0], name + + if len(arg_dtypes) == 1: + arg_dtype, = arg_dtypes + + if arg_dtype.kind == "c": + if arg_dtype == np.complex64: + tpname = "cfloat" + elif arg_dtype == np.complex128: + tpname = "cdouble" + else: + raise RuntimeError("unexpected complex type '%s'" % arg_dtype) + + if name in ["sqrt", "exp", "log", + "sin", "cos", "tan", + "sinh", "cosh", "tanh"]: + return arg_dtype, "%s_%s" % (tpname, name) + + return None + def single_arg_function_mangler(name, arg_dtypes): if len(arg_dtypes) == 1: dtype, = arg_dtypes @@ -528,8 +550,9 @@ class LoopKernel(Record): :ivar substitutions: a mapping from substitution names to :class:`SubstitutionRule` objects :ivar function_manglers: list of functions of signature (name, arg_dtypes) - returning a tuple (result_dtype, c_name), where c_name - is the C-level function to be called. + returning a tuple (result_dtype, c_name) + or a tuple (result_dtype, c_name, arg_dtypes), + where c_name is the C-level function to be called. :ivar symbol_manglers: list of functions of signature (name) returning a tuple (result_dtype, c_name), where c_name is the C-level symbol to be evaluated. @@ -566,7 +589,11 @@ class LoopKernel(Record): temporary_variables={}, iname_to_tag={}, substitutions={}, - function_manglers=[default_function_mangler, single_arg_function_mangler], + function_manglers=[ + default_function_mangler, + opencl_function_mangler, + single_arg_function_mangler, + ], symbol_manglers=[opencl_symbol_mangler], defines={},