From 438fd1da29beb6f3ad900c14c39b00dcef609a33 Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Wed, 21 Mar 2018 05:06:14 -0500
Subject: [PATCH] Fixed with_types backed to the target

---
 loopy/kernel/function_interface.py | 182 ++++-------------------------
 loopy/library/random123.py         |  42 +++++++
 loopy/target/__init__.py           |   9 ++
 loopy/target/c/__init__.py         |  91 +++++++++++++++
 loopy/target/opencl.py             | 119 ++++++++++++++++++-
 loopy/target/pyopencl.py           |  49 ++++++++
 loopy/type_inference.py            |  14 +--
 7 files changed, 335 insertions(+), 171 deletions(-)

diff --git a/loopy/kernel/function_interface.py b/loopy/kernel/function_interface.py
index f2c24b293..13955f928 100644
--- a/loopy/kernel/function_interface.py
+++ b/loopy/kernel/function_interface.py
@@ -2,13 +2,11 @@ from __future__ import division, absolute_import
 
 import re
 import six
-import numpy as np
 
 from six.moves import zip
 
 from pytools import ImmutableRecord
 from loopy.diagnostic import LoopyError
-from loopy.types import NumpyType
 
 from loopy.kernel.instruction import (MultiAssignmentBase, CInstruction,
                 _DataObliviousInstruction)
@@ -85,115 +83,6 @@ class ArrayArgDescriptor(ArgDescriptor):
 # }}}
 
 
-# {{{ c with types
-
-def c_with_types(name, arg_id_to_dtype):
-
-    # Specializing the type of the math function once they agree upon the
-    # function signature.
-
-    if name in ["abs", "acos", "asin", "atan", "cos", "cosh", "sin", "sinh",
-            "tanh", "exp", "log", "log10", "sqrt", "ceil", "floor", "tan"]:
-        for id, dtype in arg_id_to_dtype.items():
-            if not -1 <= id <= 0:
-                raise LoopyError("%s can take only one argument." % name)
-
-        dtype = arg_id_to_dtype[0].numpy_dtype
-
-        if dtype.kind == 'f':
-            # generic type resolve we can go ahead and specialize
-            pass
-        elif dtype.kind in ['u', 'i']:
-            # int and unsigned are casted into float32
-            dtype = np.float32
-        else:
-            raise LoopyError("%s function cannot take arguments of the type %s"
-                    % (name, dtype))
-
-        # Done specializing. Returning the intended arg_id_to_dtype
-        dtype = NumpyType(dtype)
-        return {-1: dtype, 0: dtype}
-
-    # binary functions
-    elif name in ["max", "min"]:
-        for id, dtype in arg_id_to_dtype.items():
-            if not -1 <= id <= 1:
-                raise LoopyError("%s can take only two arguments." % name)
-
-        # finding the common type for all the dtypes involved
-        dtype = np.find_common_type(
-            [], [dtype.numpy_dtype for dtype in arg_id_to_dtype])
-
-        if dtype.kind == 'f':
-            # generic type resolve we can go ahead and specialize
-            pass
-        elif dtype.kind in ['u', 'i']:
-            # int and unsigned are implicitly casted into float32
-            dtype = np.float32
-        else:
-            raise LoopyError("%s function cannot take arguments of the type %s"
-                    % (name, dtype))
-
-        # Specialized into one of the known types
-        return {-1: NumpyType(dtype), 0: arg_id_to_dtype[0], 1: arg_id_to_dtype[1]}
-
-    else:
-        # could not specialize the function within the C namespace
-        # this would help when checking for OpenCL/CUDA function which are not
-        # present in C
-        return None
-
-# }}}
-
-
-# {{{ opencl with_types
-
-def opencl_with_types(name, arg_id_to_dtype):
-    new_arg_id_to_dtype = c_with_types(name, arg_id_to_dtype)
-    if new_arg_id_to_dtype is None:
-        # could not locate the function within C's namespace. Searching in
-        # OpenCL specific namespace
-
-        # FIXME: Need to add these functions over here
-        new_arg_id_to_dtype = None
-
-    return new_arg_id_to_dtype
-
-# }}}
-
-
-# {{{ pyopencl with_types
-
-def pyopencl_with_types(name, arg_id_to_dtype):
-    new_arg_id_to_dtype = opencl_with_types(name, arg_id_to_dtype)
-    if new_arg_id_to_dtype is None:
-        # could not locate the function within C's namespace. Searching in
-        # PyOpenCL specific namespace
-
-        # FIXME: Need to add these functions over here
-        new_arg_id_to_dtype = None
-
-    return new_arg_id_to_dtype
-
-# }}}
-
-
-# {{{ cuda with_types
-
-def cuda_with_types(name, arg_id_to_dtype):
-    new_arg_id_to_dtype = c_with_types(name, arg_id_to_dtype)
-    if new_arg_id_to_dtype is None:
-        # could not locate the function within C's namespace. Searching in
-        # CUDA specific namespace
-
-        # FIXME: Need to add these extra functions over here
-        new_arg_id_to_dtype = None
-
-    return new_arg_id_to_dtype
-
-# }}}
-
-
 # {{{ kw_to_pos
 
 def get_kw_pos_association(kernel):
@@ -243,7 +132,7 @@ class InKernelCallable(ImmutableRecord):
     """
 
     def __init__(self, name, subkernel=None, arg_id_to_dtype=None,
-            arg_id_to_descr=None):
+            arg_id_to_descr=None, name_in_target=None):
 
         # {{{ sanity checks
 
@@ -252,10 +141,14 @@ class InKernelCallable(ImmutableRecord):
 
         # }}}
 
+        if name_in_target is not None and subkernel is not None:
+            subkernel = subkernel.copy(name=name_in_target)
+
         super(InKernelCallable, self).__init__(name=name,
                 subkernel=subkernel,
                 arg_id_to_dtype=arg_id_to_dtype,
-                arg_id_to_descr=arg_id_to_descr)
+                arg_id_to_descr=arg_id_to_descr,
+                name_in_target=name_in_target)
 
     def with_types(self, arg_id_to_dtype, target):
         """
@@ -285,37 +178,15 @@ class InKernelCallable(ImmutableRecord):
                     raise LoopyError("Overwriting a specialized"
                             " function is illegal--maybe start with new instance of"
                             " InKernelCallable?")
-            # TODO: Check if the arguments match. If yes then just
-            # return self.copy()
 
         # {{{ attempt to specialize using scalar functions
 
         if self.name in target.get_device_ast_builder().function_identifiers():
-            from loopy.target.c import CTarget
-            from loopy.target.opencl import OpenCLTarget
-            from loopy.target.pyopencl import PyOpenCLTarget
-            from loopy.target.cuda import CudaTarget
-
-            # FIXME: Push this into the target
-            if isinstance(target, CTarget):
-                new_arg_id_to_dtype = c_with_types(self.name, arg_id_to_dtype)
-
-            elif isinstance(target, OpenCLTarget):
-                new_arg_id_to_dtype = opencl_with_types(self.name, arg_id_to_dtype)
-
-            elif isinstance(target, PyOpenCLTarget):
-                new_arg_id_to_dtype = pyopencl_with_types(self.name, arg_id_to_dtype)
-
-            elif isinstance(target, CudaTarget):
-                new_arg_id_to_dtype = cuda_with_types(self.name, arg_id_to_dtype)
-
-            else:
-                raise NotImplementedError("InKernelCallable.with_types() for"
-                        " %s target" % target)
-
-            if new_arg_id_to_dtype is not None:
-                # got our speciliazed function
-                return self.copy(arg_id_to_dtype=new_arg_id_to_dtype)
+            new_in_knl_callable = target.get_device_ast_builder().with_types(
+                    self, arg_id_to_dtype)
+            if new_in_knl_callable is None:
+                new_in_knl_callable = self.copy()
+            return new_in_knl_callable
 
         # }}}
 
@@ -444,7 +315,8 @@ class InKernelCallable(ImmutableRecord):
     def is_ready_for_code_gen(self):
 
         return (self.arg_id_to_dtype is not None and
-                self.arg_id_to_descr is not None)
+                self.arg_id_to_descr is not None and
+                self.name_in_target is not None)
 
     # {{{ code generation
 
@@ -453,16 +325,10 @@ class InKernelCallable(ImmutableRecord):
         """
         raise NotImplementedError()
 
-    def get_target_specific_name(self, target):
-
-        if self.subkernel is None:
-            return self.name
-        else:
-            return self.subkernel.name
+    def emit_call(self, expression_to_code_mapper, expression, target):
 
-        raise NotImplementedError()
+        assert self.is_ready_for_code_gen()
 
-    def emit_call(self, expression_to_code_mapper, expression, target):
         if self.subkernel:
             raise NotImplementedError()
 
@@ -484,10 +350,12 @@ class InKernelCallable(ImmutableRecord):
                     expression.parameters, par_dtypes, arg_dtypes))
 
         from pymbolic import var
-        return var(self.get_target_specific_name(target))(*processed_parameters)
+        return var(self.name_in_target)(*processed_parameters)
 
     def emit_call_insn(self, insn, target, expression_to_code_mapper):
 
+        assert self.is_ready_for_code_gen()
+
         from loopy.kernel.instruction import CallInstruction
         from pymbolic.primitives import CallWithKwargs
 
@@ -507,7 +375,7 @@ class InKernelCallable(ImmutableRecord):
             parameters.append(kw_parameters[pos_to_kw[i]])
             par_dtypes.append(self.arg_id_to_dtype[pos_to_kw[i]])
 
-        # TODO: currently no suppport for insn keywords.
+        # TODO: currently no support for assignee keywords.
         parameters = parameters + list(assignees)
         par_dtypes = par_dtypes + [self.arg_id_to_dtype[-i-1] for i, _ in
                 enumerate(assignees)]
@@ -523,7 +391,7 @@ class InKernelCallable(ImmutableRecord):
                     parameters, par_dtypes)]
 
         from pymbolic import var
-        return var(self.get_target_specific_name(target))(*c_parameters)
+        return var(self.name_in_target)(*c_parameters)
 
     # }}}
 
@@ -718,12 +586,10 @@ def register_pymbolic_calls_to_knl_callables(kernel,
 
             # book-keeping of the functions and names mappings for later use
             if in_knl_callable.subkernel is not None:
-                # changing the name of the subkenrel so that it emits a function
-                # with the name same as the name being used in the
-                # scoped_function.
-                new_subkernel = in_knl_callable.subkernel.copy(
-                        name=unique_name)
-                in_knl_callable = in_knl_callable.copy(subkernel=new_subkernel)
+                # for array calls the name in the target is the name of the
+                # scoped function
+                in_knl_callable = in_knl_callable.copy(
+                        name_in_target=unique_name)
             scoped_names_to_functions[unique_name] = in_knl_callable
             scoped_functions_to_names[in_knl_callable] = unique_name
 
diff --git a/loopy/library/random123.py b/loopy/library/random123.py
index 871dde0a6..b28d11ba6 100644
--- a/loopy/library/random123.py
+++ b/loopy/library/random123.py
@@ -223,4 +223,46 @@ def random123_function_mangler(kernel, name, arg_dtypes):
     else:
         return None
 
+
+def random123_with_types(in_knl_callable, arg_id_to_dtype, target):
+    """Specialize a Random123 callable, or return *None* for unknown names.
+
+    The returned copy of *in_knl_callable* has ``name_in_target`` set and an
+    ``arg_id_to_dtype`` with the two arguments (ids 0 and 1) and the two
+    results (ids -1 and -2). Vector types are obtained from *target*;
+    *arg_id_to_dtype* itself is not consulted.
+    """
+    name = in_knl_callable.name
+
+    if name not in FUNC_NAMES_TO_RNG:
+        return None
+
+    rng_variant = FUNC_NAMES_TO_RNG[name]
+
+    from loopy.types import NumpyType
+    base_dtype = {32: np.uint32, 64: np.uint64}[rng_variant.bits]
+    ctr_dtype = target.vector_dtype(NumpyType(base_dtype), rng_variant.width)
+    key_dtype = target.vector_dtype(NumpyType(base_dtype), rng_variant.key_width)
+
+    fn = rng_variant.full_name
+    if name == fn:
+        # integer variant: the generated values have the counter's type
+        name_in_target = fn + "_gen"
+        result_dtype = ctr_dtype
+    elif name == fn + "_f32":
+        name_in_target = name
+        result_dtype = target.vector_dtype(
+                NumpyType(np.float32), rng_variant.width)
+    elif name == fn + "_f64":
+        name_in_target = name
+        result_dtype = target.vector_dtype(
+                NumpyType(np.float64), rng_variant.width)
+    else:
+        return None
+
+    return in_knl_callable.copy(
+            name_in_target=name_in_target,
+            arg_id_to_dtype={
+                -1: result_dtype, -2: ctr_dtype, 0: ctr_dtype, 1: key_dtype})
+
 # vim: foldmethod=marker
diff --git a/loopy/target/__init__.py b/loopy/target/__init__.py
index fe6daf12c..336985ede 100644
--- a/loopy/target/__init__.py
+++ b/loopy/target/__init__.py
@@ -162,6 +162,15 @@ class ASTBuilderBase(object):
     def preamble_generators(self):
         return []
 
+    def with_types(self, in_knl_callable, arg_id_to_dtype):
+        """
+        Match *in_knl_callable* against the target-specific functions and
+        return a new, type-specialized instance of :class:`InKernelCallable`,
+        or *None* when no match is found. This base implementation knows no
+        functions and therefore always returns *None*.
+        """
+        return None
+
     # }}}
 
     # {{{ code generation guts
diff --git a/loopy/target/c/__init__.py b/loopy/target/c/__init__.py
index b79e6ca48..5ebcd67e1 100644
--- a/loopy/target/c/__init__.py
+++ b/loopy/target/c/__init__.py
@@ -426,6 +426,90 @@ def c_math_mangler(target, name, arg_dtypes, modify_name=True):
 
     return None
 
+
+def c_with_types(in_knl_callable, arg_id_to_dtype, modify_name=False):
+    """Specialize a math function of the C standard library, or return *None*.
+
+    ``abs``, ``min``, ``max`` are mapped to ``fabs``, ``fmin``, ``fmax``. If
+    *modify_name* is True, the name gets a type suffix (cosf(float),
+    cos(double)): set it for C and Cuda, leave it False for OpenCL."""
+    name = in_knl_callable.name
+
+    if name in ["abs", "min", "max"]:
+        name = "f" + name
+
+    # unary functions
+    if name in ["fabs", "acos", "asin", "atan", "cos", "cosh", "sin", "sinh",
+                "tanh", "exp", "log", "log10", "sqrt", "ceil", "floor", "tan"]:
+
+        for id in arg_id_to_dtype:
+            if not -1 <= id <= 0:
+                raise LoopyError("%s can take only one argument." % name)
+
+        if 0 not in arg_id_to_dtype or arg_id_to_dtype[0] is None:
+            # the types provided aren't mature enough to specialize the
+            # callable
+            return None
+
+        dtype = arg_id_to_dtype[0]
+        dtype = dtype.numpy_dtype
+
+        if dtype.kind in ('u', 'i'):
+            # ints and unsigned casted to float32
+            dtype = np.float32
+        elif dtype.kind == 'c':
+            raise LoopyTypeError("%s does not support type %s" % (name, dtype))
+
+        if modify_name:
+            if dtype == np.float64:
+                pass  # fabs
+            elif dtype == np.float32:
+                name = name + "f"  # fabsf
+            elif dtype == np.float128:
+                name = name + "l"  # fabsl
+            else:
+                raise LoopyTypeError("%s does not support type %s" % (name, dtype))
+
+        return in_knl_callable.copy(name_in_target=name,
+                arg_id_to_dtype={0: NumpyType(dtype), -1: NumpyType(dtype)})
+
+    # binary functions
+    if name in ["fmax", "fmin"]:
+
+        for id in arg_id_to_dtype:
+            if not -1 <= id <= 1:
+                raise LoopyError("%s can take only two arguments." % name)
+
+        if 0 not in arg_id_to_dtype or 1 not in arg_id_to_dtype or (
+                arg_id_to_dtype[0] is None or arg_id_to_dtype[1] is None):
+            # the types provided aren't mature enough to specialize the
+            # callable
+            return None
+
+        dtype = np.find_common_type(
+            [], [dtype.numpy_dtype for id, dtype in arg_id_to_dtype.items()
+                 if id >= 0])
+
+        if dtype.kind == "c":
+            raise LoopyTypeError("%s does not support complex numbers" % name)
+
+        elif dtype.kind == "f":
+            if modify_name:
+                if dtype == np.float64:
+                    pass  # fmin
+                elif dtype == np.float32:
+                    name = name + "f"  # fminf
+                elif dtype == np.float128:
+                    name = name + "l"  # fminl
+                else:
+                    raise LoopyTypeError("%s does not support type %s"
+                                         % (name, dtype))
+        dtype = NumpyType(dtype)
+        return in_knl_callable.copy(name_in_target=name,
+                arg_id_to_dtype={-1: dtype, 0: dtype, 1: dtype})
+
+    return None
+
 # }}}
 
 
@@ -455,6 +539,13 @@ class CASTBuilder(ASTBuilderBase):
                     _preamble_generator,
                     ])
 
+    def with_types(self, in_knl_callable, arg_id_to_dtype):
+        # modify_name=True: C has no overloading, use suffixed names (fabsf, ...)
+        new_callable = c_with_types(in_knl_callable, arg_id_to_dtype, True)
+        if new_callable is not None:
+            return new_callable
+        return super(CASTBuilder, self).with_types(in_knl_callable, arg_id_to_dtype)
+
     # }}}
 
     # {{{ code generation
diff --git a/loopy/target/opencl.py b/loopy/target/opencl.py
index 94870907b..7aec34a22 100644
--- a/loopy/target/opencl.py
+++ b/loopy/target/opencl.py
@@ -31,10 +31,12 @@ from loopy.target.c.codegen.expression import ExpressionToCExpressionMapper
 from pytools import memoize_method
 from loopy.diagnostic import LoopyError
 from loopy.types import NumpyType
-from loopy.target.c import DTypeRegistryWrapper, c_math_identifiers
+from loopy.target.c import (DTypeRegistryWrapper, c_math_identifiers,
+        c_math_mangler, c_with_types)
 from loopy.kernel.data import temp_var_scope, CallMangleInfo
 from pymbolic import var
 
+from functools import partial
 
 # {{{ dtype registry wrappers
 
@@ -156,8 +158,8 @@ def opencl_function_identifiers():
 
 # }}}
 
-# {{{ function mangler
 
+# {{{ function mangler
 
 _CL_SIMPLE_MULTI_ARG_FUNCTIONS = {
         "clamp": 3,
@@ -239,6 +241,95 @@ def opencl_function_mangler(kernel, name, arg_dtypes):
 
     return None
 
+
+def opencl_with_types(in_knl_callable, arg_id_to_dtype):
+    """Specialize an OpenCL builtin function, or return *None*."""
+    name = in_knl_callable.name
+
+    if name in ["max", "min"]:
+        for id in arg_id_to_dtype:
+            if not -1 <= id <= 1:
+                raise LoopyError("%s can take only 2 arguments." % name)
+        if 0 not in arg_id_to_dtype or 1 not in arg_id_to_dtype or (
+                arg_id_to_dtype[0] is None or arg_id_to_dtype[1] is None):
+            # argument types are not known yet: cannot specialize
+            return None
+
+        dtype = np.find_common_type(
+                [], [dtype.numpy_dtype for id, dtype in
+                    arg_id_to_dtype.items() if id >= 0])
+
+        # only the integer case is handled here; other types fall through
+        if dtype.kind == "i":
+            dtype = NumpyType(dtype)
+            return in_knl_callable.copy(name_in_target=name,
+                    arg_id_to_dtype={-1: dtype, 0: dtype, 1: dtype})
+
+    if name == "dot":
+        for id in arg_id_to_dtype:
+            if not -1 <= id <= 1:
+                raise LoopyError("%s can take only 2 arguments." % name)
+
+        if 0 not in arg_id_to_dtype or 1 not in arg_id_to_dtype or (
+                arg_id_to_dtype[0] is None or arg_id_to_dtype[1] is None):
+            # argument types are not known yet: cannot specialize
+            return None
+
+        dtype = arg_id_to_dtype[0]
+        scalar_dtype, offset, field_name = dtype.numpy_dtype.fields["s0"]
+        return in_knl_callable.copy(name_in_target=name,
+                arg_id_to_dtype={-1: NumpyType(scalar_dtype), 0: dtype, 1: dtype})
+
+    if name in _CL_SIMPLE_MULTI_ARG_FUNCTIONS:
+        num_args = _CL_SIMPLE_MULTI_ARG_FUNCTIONS[name]
+        for id in arg_id_to_dtype:
+            if not -1 <= id < num_args:
+                raise LoopyError("%s can take only %d arguments." % (name,
+                        num_args))
+
+        for i in range(num_args):
+            if i not in arg_id_to_dtype or arg_id_to_dtype[i] is None:
+                # argument types are not known yet: cannot specialize
+                return None
+
+        dtype = np.find_common_type(
+                [], [dtype.numpy_dtype for id, dtype in
+                    arg_id_to_dtype.items() if id >= 0])
+
+        if dtype.kind == "c":
+            raise LoopyError("%s does not support complex numbers"
+                    % name)
+
+        updated_arg_id_to_dtype = dict(
+                (id, NumpyType(dtype)) for id in range(-1, num_args))
+
+        return in_knl_callable.copy(name_in_target=name,
+                arg_id_to_dtype=updated_arg_id_to_dtype)
+
+    if name in VECTOR_LITERAL_FUNCS:
+        base_tp_name, dtype, count = VECTOR_LITERAL_FUNCS[name]
+
+        for id in arg_id_to_dtype:
+            if not -1 <= id < count:
+                raise LoopyError("%s can take only %d arguments." % (name,
+                        count))
+
+        for i in range(count):
+            if i not in arg_id_to_dtype or arg_id_to_dtype[i] is None:
+                # argument types are not known yet: cannot specialize
+                return None
+
+        updated_arg_id_to_dtype = dict((id, NumpyType(dtype)) for id in
+                range(count))
+        updated_arg_id_to_dtype[-1] = OpenCLTarget().vector_dtype(
+                    NumpyType(dtype), count)
+
+        return in_knl_callable.copy(name_in_target="(%s%d) " % (base_tp_name, count),
+                arg_id_to_dtype=updated_arg_id_to_dtype)
+
+    return None
+
+
 # }}}
 
 
@@ -382,6 +473,14 @@ class OpenCLTarget(CTarget):
 class OpenCLCASTBuilder(CASTBuilder):
     # {{{ library
 
+    def function_manglers(self):
+        """Return the OpenCL manglers followed by those of the base class."""
+        # the C math mangler is registered with modify_name=False for OpenCL
+        cl_manglers = [
+                opencl_function_mangler,
+                partial(c_math_mangler, modify_name=False)]
+        return cl_manglers + super(OpenCLCASTBuilder, self).function_manglers()
+
     def function_identifiers(self):
         return (opencl_function_identifiers() | c_math_identifiers() |
                 super(OpenCLCASTBuilder, self).function_identifiers())
@@ -401,6 +500,17 @@ class OpenCLCASTBuilder(CASTBuilder):
                     reduction_preamble_generator,
                     ])
 
+    def with_types(self, in_knl_callable, arg_id_to_dtype):
+        """Try the OpenCL builtins, then the C math functions (with unmodified
+        names), then defer to the base class."""
+        for specialize in (opencl_with_types, c_with_types):
+            new_callable = specialize(in_knl_callable, arg_id_to_dtype)
+            if new_callable is not None:
+                return new_callable
+
+        return super(OpenCLCASTBuilder, self).with_types(in_knl_callable,
+                arg_id_to_dtype)
+
     # }}}
 
     # {{{ top-level codegen
@@ -412,6 +522,11 @@ class OpenCLCASTBuilder(CASTBuilder):
 
         from loopy.target.c import FunctionDeclarationWrapper
         assert isinstance(fdecl, FunctionDeclarationWrapper)
+        if not codegen_state.is_generating_master_kernel:
+            # auxiliary kernels need not mention OpenCL-specific qualifiers
+            # in a function's signature
+            return fdecl
+
         fdecl = fdecl.subdecl
 
         from cgen.opencl import CLKernel, CLRequiredWorkGroupSize
diff --git a/loopy/target/pyopencl.py b/loopy/target/pyopencl.py
index 1451cf9e7..4dace7ec2 100644
--- a/loopy/target/pyopencl.py
+++ b/loopy/target/pyopencl.py
@@ -236,6 +236,43 @@ def pyopencl_function_mangler(target, name, arg_dtypes):
     return None
 
 
+def pyopencl_with_types(in_knl_callable, arg_id_to_dtype):
+    """Specialize a function of one complex argument, or return *None*."""
+    name = in_knl_callable.name
+
+    same_type_funcs = ["sqrt", "exp", "log", "sin", "cos", "tan",
+            "sinh", "cosh", "tanh", "conj"]
+    real_result_funcs = ["real", "imag", "abs"]
+    if name not in same_type_funcs and name not in real_result_funcs:
+        # not handled here: do not apply the one-argument check below
+        return None
+
+    for id in arg_id_to_dtype:
+        if not -1 <= id <= 0:
+            raise LoopyError("%s can take only one argument." % name)
+
+    if 0 not in arg_id_to_dtype or arg_id_to_dtype[0] is None:  # not known yet
+        return None
+
+    dtype = arg_id_to_dtype[0]
+    if not dtype.is_complex():
+        return None
+
+    if dtype.numpy_dtype == np.complex64:
+        tpname = "cfloat"
+    elif dtype.numpy_dtype == np.complex128:
+        tpname = "cdouble"
+    else:
+        raise RuntimeError("unexpected complex type '%s'" % dtype)
+
+    result_dtype = dtype
+    if name in real_result_funcs:  # these yield the real scalar type
+        from loopy.types import NumpyType
+        result_dtype = NumpyType(np.dtype(dtype.numpy_dtype.type(0).real))
+    return in_knl_callable.copy(name_in_target="%s_%s" % (tpname, name),
+            arg_id_to_dtype={0: dtype, -1: result_dtype})
+
+
 # {{{ preamble generator
 
 def pyopencl_preamble_generator(preamble_info):
@@ -764,6 +801,18 @@ class PyOpenCLCASTBuilder(OpenCLCASTBuilder):
             random123_preamble_generator,
             ] + super(PyOpenCLCASTBuilder, self).preamble_generators())
 
+    def with_types(self, in_knl_callable, arg_id_to_dtype):
+        """Try the OpenCL/C functions, then complex math, then Random123."""
+        from loopy.library.random123 import random123_with_types
+        # NOTE(review): super() raises on complex sqrt etc. -- confirm the order
+        new_callable = super(PyOpenCLCASTBuilder, self).with_types(
+                in_knl_callable, arg_id_to_dtype)
+        if new_callable is None:
+            new_callable = pyopencl_with_types(in_knl_callable, arg_id_to_dtype)
+        if new_callable is not None:
+            return new_callable
+        return random123_with_types(in_knl_callable, arg_id_to_dtype, self.target)
+
     # }}}
 
 # }}}
diff --git a/loopy/type_inference.py b/loopy/type_inference.py
index ee4bf38be..f974e3fab 100644
--- a/loopy/type_inference.py
+++ b/loopy/type_inference.py
@@ -120,11 +120,6 @@ class TypeInferenceMapper(CombineMapper):
                 0 <= len(dtype_set) <= 1
                 for dtype_set in dtype_sets)
 
-        # Can't infer types if one of the dtypes is unknown
-        for dtype_set in dtype_sets:
-            if dtype_set == []:
-                return []
-
         from pytools import is_single_valued
 
         dtypes = [dtype
@@ -291,15 +286,12 @@ class TypeInferenceMapper(CombineMapper):
         self.specialized_functions[expr] = in_knl_callable
 
         new_arg_id_to_dtype = in_knl_callable.arg_id_to_dtype
-        result_dtypes = []
 
         # collecting result dtypes in order of the assignees
+        if -1 in new_arg_id_to_dtype and new_arg_id_to_dtype[-1] is not None:
+            return [new_arg_id_to_dtype[-1]]
 
-        for i in range(len(new_arg_id_to_dtype)):
-            if -i-1 in new_arg_id_to_dtype:
-                result_dtypes.append(new_arg_id_to_dtype[-i-1])
-            else:
-                return result_dtypes
+        return []
 
         """
         # Letting this stay over here, as it maybe needed later for maintaining
-- 
GitLab