diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py
index 5a642322021136de6f84325118311437761d98d5..343c850149df72f90c6860aa0b5db16fa1528718 100644
--- a/loopy/kernel/creation.py
+++ b/loopy/kernel/creation.py
@@ -1898,7 +1898,7 @@ class FunctionScoper(IdentityMapper):
         from loopy.symbolic import Reduction
 
         return Reduction(
-                ScopedFunction(expr.operation.name),
+                ScopedFunction(expr.function.name),
                 tuple(new_inames),
                 self.rec(expr.expr),
                 allow_simultaneous=expr.allow_simultaneous)
@@ -1918,9 +1918,10 @@ class ScopedFunctionCollector(CombineMapper):
 
     def map_reduction(self, expr):
         from loopy.kernel.reduction_callable import CallableReduction
+        from loopy.kernel.function_interface import CallableOnScalar
         from loopy.symbolic import Reduction
 
-        callable_reduction = CallableReduction(expr.operation.name)
+        callable_reduction = CallableReduction(expr.function.name)
 
         # sanity checks
 
@@ -1949,8 +1950,14 @@ class ScopedFunctionCollector(CombineMapper):
             elif isinstance(expr, Reduction) and callable_reduction.is_tuple_typed:
                 raise LoopyError("got a tuple typed argument to a scalar reduction")
 
-        return frozenset([(expr.operation.name,
-            callable_reduction)])
+        hidden_function = callable_reduction.operation.hidden_function()
+        if hidden_function is not None:
+            return frozenset([(expr.function.name,
+                callable_reduction), (hidden_function,
+                    CallableOnScalar(hidden_function))])
+        else:
+            return frozenset([(expr.function.name,
+                callable_reduction)])
 
     def map_constant(self, expr):
         return frozenset()
diff --git a/loopy/kernel/function_interface.py b/loopy/kernel/function_interface.py
index bc5d178b13bdd66f37f5f0b59238d6ba248512bd..fb80c5876a37110442d56234a2c20e9f78ab40b6 100644
--- a/loopy/kernel/function_interface.py
+++ b/loopy/kernel/function_interface.py
@@ -134,23 +134,17 @@ class InKernelCallable(ImmutableRecord):
 
     """
 
-    fields = set(["name", "arg_id_to_dtype", "arg_id_to_descr"])
-    init_arg_names = ("name", "arg_id_to_dtype", "arg_id_to_descr")
+    fields = set(["arg_id_to_dtype", "arg_id_to_descr"])
+    init_arg_names = ("arg_id_to_dtype", "arg_id_to_descr")
 
-    def __init__(self, name, arg_id_to_dtype=None, arg_id_to_descr=None):
+    def __init__(self, arg_id_to_dtype=None, arg_id_to_descr=None):
 
-        # sanity checks
-
-        if not isinstance(name, str):
-            raise LoopyError("name of an InKernelCallable should be a string")
-
-        super(InKernelCallable, self).__init__(name=name,
+        super(InKernelCallable, self).__init__(
                 arg_id_to_dtype=arg_id_to_dtype,
                 arg_id_to_descr=arg_id_to_descr)
 
     def __getinitargs__(self):
-        return (self.name, self.arg_id_to_dtype, self.arg_id_to_descr,
-                self.name_in_target)
+        return (self.arg_id_to_dtype, self.arg_id_to_descr)
 
     def with_types(self, arg_id_to_dtype, target):
         """
@@ -245,10 +239,11 @@ class CallableOnScalar(InKernelCallable):
     def __init__(self, name, arg_id_to_dtype=None,
             arg_id_to_descr=None, name_in_target=None):
 
-        super(InKernelCallable, self).__init__(name=name,
+        super(InKernelCallable, self).__init__(
                 arg_id_to_dtype=arg_id_to_dtype,
                 arg_id_to_descr=arg_id_to_descr)
 
+        self.name = name
         self.name_in_target = name_in_target
 
     def __getinitargs__(self):
@@ -265,7 +260,7 @@ class CallableOnScalar(InKernelCallable):
                 if self.arg_id_to_dtype[id] != arg_id_to_dtype[id]:
                     raise LoopyError("Overwriting a specialized"
                             " function is illegal--maybe start with new instance of"
-                            " CallableScalar?")
+                            " CallableOnScalar?")
 
         # {{{ attempt to specialize using scalar functions present in target
 
@@ -406,12 +401,13 @@ class CallableKernel(InKernelCallable):
     def __init__(self, name, subkernel, arg_id_to_dtype=None,
             arg_id_to_descr=None, name_in_target=None):
 
-        super(InKernelCallable, self).__init__(name=name,
+        super(InKernelCallable, self).__init__(
                 arg_id_to_dtype=arg_id_to_dtype,
                 arg_id_to_descr=arg_id_to_descr)
         if name_in_target is not None:
             subkernel = subkernel.copy(name=name_in_target)
 
+        self.name = name
         self.name_in_target = name_in_target
         self.subkernel = subkernel
 
@@ -628,7 +624,7 @@ def register_pymbolic_calls_to_knl_callables(kernel,
                 unique_name = next_indexed_name(unique_name)
 
             # book-keeping of the functions and names mappings for later use
-            if in_knl_callable.subkernel is not None:
+            if isinstance(in_knl_callable, CallableKernel):
                 # for array calls the name in the target is the name of the
                 # scoped funciton
                 in_knl_callable = in_knl_callable.copy(
diff --git a/loopy/kernel/reduction_callable.py b/loopy/kernel/reduction_callable.py
index 1682f71608dc0e066cf6930cc2e6b07649bbc8d5..1ad2acd8db6891490f5b7e7fff68d9c3ffd9268b 100644
--- a/loopy/kernel/reduction_callable.py
+++ b/loopy/kernel/reduction_callable.py
@@ -28,7 +28,7 @@ class CallableReduction(InKernelCallable):
 
         self.operation = operation
 
-        super(InKernelCallable, self).__init__(name="",
+        super(InKernelCallable, self).__init__(
                 arg_id_to_dtype=arg_id_to_dtype,
                 arg_id_to_descr=arg_id_to_descr)
 
@@ -47,39 +47,32 @@ class CallableReduction(InKernelCallable):
 
             for id, dtype in arg_id_to_dtype.items():
                 # only checking for the ones which have been provided
-                if self.arg_id_to_dtype[id] != arg_id_to_dtype[id]:
+                if id in self.arg_id_to_dtype and (
+                        self.arg_id_to_dtype[id] != arg_id_to_dtype[id]):
                     raise LoopyError("Overwriting a specialized"
                             " function is illegal--maybe start with new instance of"
-                            " CallableScalar?")
-
-        if self.name in target.get_device_ast_builder().function_identifiers():
-            new_in_knl_callable = target.get_device_ast_builder().with_types(
-                    self, arg_id_to_dtype)
-            if new_in_knl_callable is None:
-                new_in_knl_callable = self.copy()
-            return new_in_knl_callable
-
-        # did not find a scalar function and function prototype does not
-        # even have  subkernel registered => no match found
-        raise LoopyError("Function %s not present within"
-                " the %s namespace" % (self.name, target))
+                            " CallableReduction?")
+        updated_arg_id_to_dtype = self.operation.with_types(arg_id_to_dtype,
+                target)
+        return self.copy(arg_id_to_dtype=updated_arg_id_to_dtype)
 
     def with_descrs(self, arg_id_to_descr):
+        # not sure what would be the reson of having this over here
 
         # This is a scalar call
         # need to assert that the name is in funtion indentifiers
         arg_id_to_descr[-1] = ValueArgDescriptor()
         return self.copy(arg_id_to_descr=arg_id_to_descr)
 
-    def with_iname_tag_usage(self, unusable, concurrent_shape):
-
-        raise NotImplementedError()
+    def inline(self, kernel):
+        # Replaces the job of realize_reduction
+        raise NotImplementedError
 
     def is_ready_for_code_gen(self):
 
         return (self.arg_id_to_dtype is not None and
                 self.arg_id_to_descr is not None and
-                self.name_in_target is not None)
+                self.operation is not None)
 
 
 # vim: foldmethod=marker
diff --git a/loopy/library/reduction.py b/loopy/library/reduction.py
index 5daa1528a7d67c0dc35644dc40c6d179dc01527e..f4444c8864a44840f624c63ab282769e0f321a9a 100644
--- a/loopy/library/reduction.py
+++ b/loopy/library/reduction.py
@@ -36,7 +36,7 @@ class ReductionOperation(object):
     equality-comparable.
     """
 
-    def result_dtypes(self, target, *arg_dtypes):
+    def with_types(self, arg_id_to_dtype, target):
         """
         :arg arg_dtypes: may be None if not known
         :returns: None if not known, otherwise the returned type
@@ -51,6 +51,9 @@ class ReductionOperation(object):
     def neutral_element(self, *dtypes):
         raise NotImplementedError
 
+    def hidden_function(self):
+        return None
+
     def __hash__(self):
         # Force subclasses to override
         raise NotImplementedError
@@ -95,15 +98,22 @@ class ScalarReductionOperation(ReductionOperation):
     def arg_count(self):
         return 1
 
-    def result_dtypes(self, kernel, arg_dtype):
+    def with_types(self, arg_id_to_dtype, target):
+        if 0 not in arg_id_to_dtype or arg_id_to_dtype[0] is None:
+            # do not have enough info to figure out the type.
+            return arg_id_to_dtype.copy()
+
+        arg_dtype = arg_id_to_dtype[0]
+
+        updated_arg_id_to_dtype = arg_id_to_dtype.copy()
         if self.forced_result_type is not None:
-            return (self.parse_result_type(
-                    kernel.target, self.forced_result_type),)
+            updated_arg_id_to_dtype[-1] = (self.parse_result_type(
+                target, self.forced_result_type),)
+            return updated_arg_id_to_dtype
 
-        if arg_dtype is None:
-            return None
+        updated_arg_id_to_dtype[-1] = arg_dtype
 
-        return (arg_dtype,)
+        return updated_arg_id_to_dtype
 
     def __hash__(self):
         return hash((type(self), self.forced_result_type))
@@ -180,7 +190,11 @@ class MaxReductionOperation(ScalarReductionOperation):
         return get_ge_neutral(dtype)
 
     def __call__(self, dtype, operand1, operand2):
-        return var("max")(operand1, operand2)
+        from loopy.symbolic import ScopedFunction
+        return ScopedFunction("max")(operand1, operand2)
+
+    def hidden_function(self):
+        return "max"
 
 
 class MinReductionOperation(ScalarReductionOperation):
@@ -188,7 +202,11 @@ class MinReductionOperation(ScalarReductionOperation):
         return get_le_neutral(dtype)
 
     def __call__(self, dtype, operand1, operand2):
-        return var("min")(operand1, operand2)
+        from loopy.symbolic import ScopedFunction
+        return ScopedFunction("min")(operand1, operand2)
+
+    def hidden_function(self):
+        return "min"
 
 
 # {{{ base class for symbolic reduction ops
@@ -233,9 +251,22 @@ class _SegmentedScalarReductionOperation(ReductionOperation):
         return var("make_tuple")(scalar_neutral_element,
                 segment_flag_dtype.numpy_dtype.type(0))
 
-    def result_dtypes(self, kernel, scalar_dtype, segment_flag_dtype):
-        return (self.inner_reduction.result_dtypes(kernel, scalar_dtype)
-                + (segment_flag_dtype,))
+    def with_types(self,  arg_id_to_dtype, target):
+        for id in range(self.arg_count):
+            if id not in arg_id_to_dtype or arg_id_to_dtype[id] is None:
+                # types of arguemnts not known => result type cannot be
+                # determined.
+                return arg_id_to_dtype.copy()
+
+        scalar_dtype = arg_id_to_dtype[0]
+        segment_flag_dtype = arg_id_to_dtype[1]
+
+        updated_arg_id_to_dtype = arg_id_to_dtype.copy()
+        updated_arg_id_to_dtype[-1] = self.inner_reduction.with_types(
+                {0: scalar_dtype}, target)[-1]
+        updated_arg_id_to_dtype[-2] = segment_flag_dtype
+
+        return updated_arg_id_to_dtype
 
     def __str__(self):
         return "segmented(%s)" % self.which
@@ -299,8 +330,22 @@ class _ArgExtremumReductionOperation(ReductionOperation):
                 scalar_dtype.numpy_dtype.type.__name__,
                 index_dtype.numpy_dtype.type.__name__)
 
-    def result_dtypes(self, kernel, scalar_dtype, index_dtype):
-        return (scalar_dtype, index_dtype)
+    def with_types(self, arg_id_to_dtype, target):
+        for id in range(self.arg_count):
+            if id not in arg_id_to_dtype or arg_id_to_dtype[id] is None:
+                # types of arguemnts not known => result type cannot be
+                # determined.
+                return self.copy(arg_id_to_dtype=arg_id_to_dtype)
+
+        scalar_dtype = arg_id_to_dtype[0]
+        index_dtype = arg_id_to_dtype[1]
+
+        updated_arg_id_to_dtype = arg_id_to_dtype.copy()
+
+        updated_arg_id_to_dtype[-1] = scalar_dtype
+        updated_arg_id_to_dtype[-2] = index_dtype
+
+        return updated_arg_id_to_dtype
 
     def neutral_element(self, scalar_dtype, index_dtype):
         scalar_neutral_func = (
@@ -331,12 +376,18 @@ class ArgMaxReductionOperation(_ArgExtremumReductionOperation):
     update_comparison = ">="
     neutral_sign = -1
 
+    def hidden_function(self):
+        return "max"
+
 
 class ArgMinReductionOperation(_ArgExtremumReductionOperation):
     which = "min"
     update_comparison = "<="
     neutral_sign = +1
 
+    def hidden_function(self):
+        return "min"
+
 
 def get_argext_preamble(kernel, func_id, arg_dtypes):
     op = func_id.reduction_op
@@ -377,8 +428,8 @@ def get_argext_preamble(kernel, func_id, arg_dtypes):
 _REDUCTION_OPS = {
         "sum": SumReductionOperation,
         "product": ProductReductionOperation,
-        "max": MaxReductionOperation,
-        "min": MinReductionOperation,
+        "maximum": MaxReductionOperation,
+        "minimum": MinReductionOperation,
         "argmax": ArgMaxReductionOperation,
         "argmin": ArgMinReductionOperation,
         "segmented(sum)": SegmentedSumReductionOperation,
@@ -429,6 +480,12 @@ def reduction_function_identifiers():
     return set(op for op in _REDUCTION_OPS)
 
 
+def reduction_function_mangler(kernel, func_id, arg_dtypes):
+    raise NotImplementedError("Reduction Function Mangler!")
+
+
+'''
+# KK -- we will replace this with the new interface
 def reduction_function_mangler(kernel, func_id, arg_dtypes):
     if isinstance(func_id, ArgExtOp):
         from loopy.target.opencl import CTarget
@@ -475,6 +532,7 @@ def reduction_function_mangler(kernel, func_id, arg_dtypes):
                 )
 
     return None
+'''
 
 
 def reduction_preamble_generator(preamble_info):
diff --git a/loopy/preprocess.py b/loopy/preprocess.py
index 34fe6e830c3f3a3cf42d6b77a12fda54c727276f..51389f4f56669e1cebaf7e6b04e5f3e8d9cde0ae 100644
--- a/loopy/preprocess.py
+++ b/loopy/preprocess.py
@@ -39,7 +39,6 @@ from loopy.kernel.data import make_assignment
 # for the benefit of loopy.statistics, for now
 from loopy.type_inference import infer_unknown_types
 from loopy.symbolic import ScopedFunction, CombineMapper
-from pymbolic.mapper import Collector
 
 from loopy.kernel.instruction import (MultiAssignmentBase, CInstruction,
         CallInstruction,  _DataObliviousInstruction)
@@ -893,7 +892,6 @@ def _insert_subdomain_into_domain_tree(kernel, domains, subdomain):
 # }}}
 
 
-
 def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True,
                       automagic_scans_ok=False, force_scan=False,
                       force_outer_iname_for_scan=None):
@@ -1041,13 +1039,16 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True,
         init_id = insn_id_gen(
                 "%s_%s_init" % (insn.id, "_".join(expr.inames)))
 
+        reduction_operation = kernel.scoped_functions[
+                expr.function.name].operation
+
         init_insn = make_assignment(
                 id=init_id,
                 assignees=acc_vars,
                 within_inames=outer_insn_inames - frozenset(expr.inames),
                 within_inames_is_final=insn.within_inames_is_final,
                 depends_on=init_insn_depends_on,
-                expression=expr.operation.neutral_element(*arg_dtypes),
+                expression=reduction_operation.neutral_element(*arg_dtypes),
                 predicates=insn.predicates,)
 
         generated_insns.append(init_insn)
@@ -1082,10 +1083,12 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True,
         else:
             reduction_expr = expr.expr
 
+        reduction_operation = kernel.scoped_functions[
+                expr.function.name].operation
         reduction_insn = make_assignment(
                 id=update_id,
                 assignees=acc_vars,
-                expression=expr.operation(
+                expression=reduction_operation(
                     arg_dtypes,
                     _strip_if_scalar(acc_vars, acc_vars),
                     reduction_expr),
@@ -1094,8 +1097,6 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True,
                 within_inames_is_final=insn.within_inames_is_final,
                 predicates=insn.predicates,)
 
-        reduction_insn = scope_function_in_insn(reduction_insn, kenrel)
-
         generated_insns.append(reduction_insn)
 
         new_insn_add_depends_on.add(reduction_insn.id)
@@ -1944,6 +1945,8 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True,
 
     kernel = lp.tag_inames(kernel, new_iname_tags)
 
+    # making changes to the scoped function that are arising
+
     # TODO: remove unused inames...
 
     kernel = (
@@ -2381,10 +2384,6 @@ def preprocess_kernel(kernel, device=None):
     from loopy.kernel.creation import apply_single_writer_depencency_heuristic
     kernel = apply_single_writer_depencency_heuristic(kernel)
 
-    # inferring the shape and dim_tags of the arguments involved in a function
-    # call.
-    kernel = infer_arg_descr(kernel)
-
     # Ordering restrictions:
     #
     # - realize_reduction must happen after type inference because it needs
@@ -2396,6 +2395,10 @@ def preprocess_kernel(kernel, device=None):
 
     kernel = realize_reduction(kernel, unknown_types_ok=False)
 
+    # inferring the shape and dim_tags of the arguments involved in a function
+    # call.
+    kernel = infer_arg_descr(kernel)
+
     # Ordering restriction:
     # add_axes_to_temporaries_for_ilp because reduction accumulators
     # need to be duplicated by this.
diff --git a/loopy/symbolic.py b/loopy/symbolic.py
index e8e39a24f35f0744046a420750f4768c02e5c0ca..32670c1cc26bef1dacc18c0a46f70c92a00d9405 100644
--- a/loopy/symbolic.py
+++ b/loopy/symbolic.py
@@ -96,7 +96,7 @@ class IdentityMapperMixin(object):
             new_inames.append(new_sym_iname.name)
 
         return Reduction(
-                expr.operation, tuple(new_inames),
+                expr.function, tuple(new_inames),
                 self.rec(expr.expr, *args),
                 allow_simultaneous=expr.allow_simultaneous)
 
@@ -226,7 +226,7 @@ class StringifyMapper(StringifyMapperBase):
 
         return "%sreduce(%s, [%s], %s)" % (
                 "simul_" if expr.allow_simultaneous else "",
-                expr.operation, ", ".join(expr.inames),
+                expr.function, ", ".join(expr.inames),
                 self.rec(expr.expr, PREC_NONE))
 
     def map_tagged_variable(self, expr, prec):
@@ -266,7 +266,7 @@ class UnidirectionalUnifier(UnidirectionalUnifierBase):
         if not isinstance(other, type(expr)):
             return self.treat_mismatch(expr, other, unis)
         if (expr.inames != other.inames
-                or type(expr.operation) != type(other.operation)  # noqa
+                or type(expr.function) != type(other.function)  # noqa
                 ):
             return []
 
@@ -537,7 +537,7 @@ class Reduction(p.Expression):
     """Represents a reduction operation on :attr:`exprs`
     across :attr:`inames`.
 
-    ..attribute:: operation
+    ..attribute:: function
 
         an instance of :class:`pymbolic.primitives.Variable` which indicates
         the reduction callable that the reduction would point to in the dict
@@ -562,10 +562,10 @@ class Reduction(p.Expression):
         in precisely one reduction, to avoid mis-nesting errors.
     """
 
-    init_arg_names = ("operation", "inames", "expr", "allow_simultaneous")
+    init_arg_names = ("function", "inames", "expr", "allow_simultaneous")
 
-    def __init__(self, operation, inames, expr, allow_simultaneous=False):
-        assert isinstance(operation, p.Variable)
+    def __init__(self, function, inames, expr, allow_simultaneous=False):
+        assert isinstance(function, p.Variable)
 
         if isinstance(inames, str):
             inames = tuple(iname.strip() for iname in inames.split(","))
@@ -610,20 +610,20 @@ class Reduction(p.Expression):
                 raise LoopyError("got a tuple typed argument to a scalar reduction")
         """
 
-        self.operation = operation
+        self.function = function
         self.inames = inames
         self.expr = expr
         self.allow_simultaneous = allow_simultaneous
 
     def __getinitargs__(self):
-        return (self.operation, self.inames, self.expr, self.allow_simultaneous)
+        return (self.funciton, self.inames, self.expr, self.allow_simultaneous)
 
     def get_hash(self):
-        return hash((self.__class__, self.operation, self.inames, self.expr))
+        return hash((self.__class__, self.function, self.inames, self.expr))
 
     def is_equal(self, other):
         return (other.__class__ == self.__class__
-                and other.operation == self.operation
+                and other.function == self.function
                 and other.inames == self.inames
                 and other.expr == self.expr)
 
@@ -1146,10 +1146,10 @@ class FunctionToPrimitiveMapper(IdentityMapper):
     turns those into the actual pymbolic primitives used for that.
     """
 
-    def _parse_reduction(self, operation, inames, red_exprs,
+    def _parse_reduction(self, function, inames, red_exprs,
             allow_simultaneous=False):
-        assert isinstance(operation, str)
-        operation = p.Variable(operation)
+        assert isinstance(function, str)
+        function = p.Variable(function)
         if isinstance(inames, p.Variable):
             inames = (inames,)
 
@@ -1168,7 +1168,7 @@ class FunctionToPrimitiveMapper(IdentityMapper):
         if len(red_exprs) == 1:
             red_exprs = red_exprs[0]
 
-        return Reduction(operation, tuple(processed_inames), red_exprs,
+        return Reduction(function, tuple(processed_inames), red_exprs,
                 allow_simultaneous=allow_simultaneous)
 
     def map_call(self, expr):
@@ -1194,10 +1194,10 @@ class FunctionToPrimitiveMapper(IdentityMapper):
 
         elif name in set(["reduce, simul_reduce"]):
             if len(expr.parameters) >= 3:
-                operation, inames = expr.parameters[:2]
+                function, inames = expr.parameters[:2]
                 red_exprs = expr.parameters[2:]
 
-                return self._parse_reduction(str(operation), inames,
+                return self._parse_reduction(str(function), inames,
                         tuple(self.rec(red_expr) for red_expr in red_exprs),
                         allow_simultaneous=(name == "simul_reduce"))
             else:
diff --git a/loopy/target/opencl.py b/loopy/target/opencl.py
index 7aec34a22a4a4a6282ab6065abd412f16612bcdc..7ffd91309b9b82a86c17357abca9fe381a145e33 100644
--- a/loopy/target/opencl.py
+++ b/loopy/target/opencl.py
@@ -255,7 +255,7 @@ def opencl_with_types(in_knl_callable, arg_id_to_dtype):
 
         dtype = np.find_common_type(
                 [], [dtype.numpy_dtype for id, dtype in
-                    arg_id_to_dtype.values() if id >= 0])
+                    arg_id_to_dtype.items() if id >= 0])
 
         if dtype.kind == "i":
             dtype = NumpyType(dtype)
diff --git a/loopy/type_inference.py b/loopy/type_inference.py
index 11113538eeadb4fc103bdc15cacc54899bcb864d..8df9773a98dcafe809f81de88f068e4948ace410 100644
--- a/loopy/type_inference.py
+++ b/loopy/type_inference.py
@@ -396,7 +396,10 @@ class TypeInferenceMapper(CombineMapper):
         from loopy.symbolic import Reduction
         from pymbolic.primitives import Call
 
-        if not return_tuple and expr.is_tuple_typed:
+        reduction_callable = self.scoped_functions[
+                expr.function.name]
+
+        if not return_tuple and reduction_callable.is_tuple_typed:
             raise LoopyError("reductions with more or fewer than one "
                              "return value may only be used in direct "
                              "assignments")
@@ -416,12 +419,23 @@ class TypeInferenceMapper(CombineMapper):
             else:
                 rec_results = self.rec(expr.expr)
 
-        if return_tuple:
-            return [expr.operation.result_dtypes(self.kernel, *rec_result)
-                    for rec_result in rec_results]
-        else:
-            return [expr.operation.result_dtypes(self.kernel, rec_result)[0]
-                    for rec_result in rec_results]
+        arg_id_to_dtype = dict(enumerate(rec_results))
+
+        in_knl_callable = (
+                self.scoped_functions[expr.function.name].with_types(
+                    arg_id_to_dtype, self.kernel.target))
+
+        # storing the type specialized function so that it can be used for
+        # later use
+        self.specialized_functions[expr] = in_knl_callable
+
+        new_arg_id_to_dtype = in_knl_callable.arg_id_to_dtype
+
+        # collecting result dtypes in order of the assignees
+        if -1 in new_arg_id_to_dtype and new_arg_id_to_dtype[-1] is not None:
+            return [new_arg_id_to_dtype[-1]]
+
+        return []
 
     def map_sub_array_ref(self, expr):
         return self.rec(expr.get_begin_subscript())
@@ -691,8 +705,9 @@ def infer_arg_and_reduction_dtypes_for_reduction_expression(
         kernel, expr, unknown_types_ok):
     type_inf_mapper = TypeInferenceMapper(kernel)
     import loopy as lp
+    callable_reduction = kernel.scoped_functions[expr.function.name]
 
-    if expr.is_tuple_typed:
+    if callable_reduction.is_tuple_typed:
         arg_dtypes_result = type_inf_mapper(
                 expr, return_tuple=True, return_dtype_set=True)
 
@@ -700,7 +715,7 @@ def infer_arg_and_reduction_dtypes_for_reduction_expression(
             arg_dtypes = arg_dtypes_result[0]
         else:
             if unknown_types_ok:
-                arg_dtypes = [lp.auto] * expr.operation.arg_count
+                arg_dtypes = [lp.auto] * callable_reduction.operation.arg_count
             else:
                 raise LoopyError("failed to determine types of accumulators for "
                         "reduction '%s'" % expr)
@@ -714,13 +729,22 @@ def infer_arg_and_reduction_dtypes_for_reduction_expression(
                 raise LoopyError("failed to determine type of accumulator for "
                         "reduction '%s'" % expr)
 
-    reduction_dtypes = expr.operation.result_dtypes(kernel, *arg_dtypes)
-    reduction_dtypes = tuple(
-            dt.with_target(kernel.target)
-            if dt is not lp.auto else dt
-            for dt in reduction_dtypes)
+    # TODODODODODODODODODO
+
+    new_arg_id_to_dtype = callable_reduction.with_types(
+            dict(enumerate(arg_dtypes)), kernel.target).arg_id_to_dtype
+
+    num_result = len([id for id in new_arg_id_to_dtype if id < 0])
+    reduction_dtypes = []
+
+    for id in range(num_result):
+        dt = new_arg_id_to_dtype[-id-1]
+        if dt is not lp.auto:
+            reduction_dtypes.append(dt.with_target(kernel.target))
+        else:
+            reduction_dtypes.append(dt)
 
-    return tuple(arg_dtypes), reduction_dtypes
+    return tuple(arg_dtypes), tuple(reduction_dtypes)
 
 # }}}