diff --git a/.gitmodules b/.gitmodules
index 504e23cf344e2d5ae35f6f6abe97458b8c7a39b8..41cf31d9cfcf85de99569ea131c3f367e68a3436 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,3 +1,3 @@
-[submodule "loopy/target/opencl/compyte"]
-	path = loopy/target/opencl/compyte
+[submodule "loopy/target/c/compyte"]
+	path = loopy/target/c/compyte
 	url = https://github.com/inducer/compyte
diff --git a/loopy/codegen/__init__.py b/loopy/codegen/__init__.py
index c14c1ffa24157a51c58b31a166913e94a558e062..57393bd6defd4bc3a6bfcdc7e4b90f1a907fba24 100644
--- a/loopy/codegen/__init__.py
+++ b/loopy/codegen/__init__.py
@@ -160,23 +160,26 @@ class CodeGenerationState(object):
         A :class:`frozenset` of predicates for which checks have been
         implemented.
 
-    .. attribute:: c_code_mapper
+    .. attribute:: expression_to_code_mapper
 
         A :class:`loopy.codegen.expression.CCodeMapper` that does not take
         per-ILP assignments into account.
     """
-    def __init__(self, implemented_domain, implemented_predicates, c_code_mapper):
+    def __init__(self, implemented_domain, implemented_predicates,
+            expression_to_code_mapper):
         self.implemented_domain = implemented_domain
         self.implemented_predicates = implemented_predicates
-        self.c_code_mapper = c_code_mapper
+        self.expression_to_code_mapper = expression_to_code_mapper
 
     def copy(self, implemented_domain=None, implemented_predicates=frozenset(),
-            c_code_mapper=None):
+            expression_to_code_mapper=None):
         return CodeGenerationState(
                 implemented_domain=implemented_domain or self.implemented_domain,
                 implemented_predicates=(
                     implemented_predicates or self.implemented_predicates),
-                c_code_mapper=c_code_mapper or self.c_code_mapper)
+                expression_to_code_mapper=(
+                    expression_to_code_mapper
+                    or self.expression_to_code_mapper))
 
     def intersect(self, other):
         new_impl, new_other = isl.align_two(self.implemented_domain, other)
@@ -205,7 +208,8 @@ class CodeGenerationState(object):
         new_impl_domain = new_impl_domain.add_constraint(cns)
         return self.copy(
                 implemented_domain=new_impl_domain,
-                c_code_mapper=self.c_code_mapper.copy_and_assign(iname, expr))
+                expression_to_code_mapper=(
+                    self.expression_to_code_mapper.copy_and_assign(iname, expr)))
 
 # }}}
 
@@ -380,15 +384,11 @@ def generate_code(kernel, device=None):
         if var.dtype.kind == "c":
             allow_complex = True
 
+    mod = []
+
     seen_dtypes = set()
     seen_functions = set()
 
-    from loopy.codegen.expression import LoopyCCodeMapper
-    ccm = (LoopyCCodeMapper(kernel, seen_dtypes, seen_functions,
-        allow_complex=allow_complex))
-
-    mod = []
-
     body = Block()
 
     # {{{ examine arg list
@@ -446,7 +446,8 @@ def generate_code(kernel, device=None):
     codegen_state = CodeGenerationState(
             implemented_domain=initial_implemented_domain,
             implemented_predicates=frozenset(),
-            c_code_mapper=ccm)
+            expression_to_code_mapper=kernel.target.get_expression_to_code_mapper(
+                kernel, seen_dtypes, seen_functions, allow_complex))
 
     from loopy.codegen.loop import set_up_hw_parallel_loops
     gen_code = set_up_hw_parallel_loops(kernel, 0, codegen_state)
diff --git a/loopy/codegen/bounds.py b/loopy/codegen/bounds.py
index c9e79f6bcc17f3ed3166dfc18cc012721410af6c..19ac4106ba58821e5d4bf5231eb530977739c7f3 100644
--- a/loopy/codegen/bounds.py
+++ b/loopy/codegen/bounds.py
@@ -28,14 +28,14 @@ from islpy import dim_type
 from pymbolic.mapper.stringifier import PREC_NONE
 
 
-def constraint_to_code(ccm, cns):
+def constraint_to_code(ecm, cns):
     if cns.is_equality():
         comp_op = "=="
     else:
         comp_op = ">="
 
     from loopy.symbolic import constraint_to_expr
-    return "%s %s 0" % (ccm(constraint_to_expr(cns), PREC_NONE, "i"), comp_op)
+    return "%s %s 0" % (ecm(constraint_to_expr(cns), PREC_NONE, "i"), comp_op)
 
 
 # {{{ bounds check generator
diff --git a/loopy/codegen/control.py b/loopy/codegen/control.py
index 3927293ae82dc521327a8c17afae5e93be1e3aa9..b28244a8c01e72611244f33901793c37804b00ad 100644
--- a/loopy/codegen/control.py
+++ b/loopy/codegen/control.py
@@ -77,7 +77,7 @@ def generate_code_for_sched_index(kernel, sched_index, codegen_state):
 
     elif isinstance(sched_item, Barrier):
         from loopy.codegen import GeneratedInstruction
-        from cgen import Statement as S
+        from cgen import Statement as S  # noqa
 
         if sched_item.comment:
             comment = " /* %s */" % sched_item.comment
@@ -362,7 +362,8 @@ def build_loop_nest(kernel, sched_index, codegen_state):
                 from loopy.codegen.bounds import constraint_to_code
 
                 conditionals = [
-                        constraint_to_code(codegen_state.c_code_mapper, cns)
+                        constraint_to_code(
+                            codegen_state.expression_to_code_mapper, cns)
                         for cns in bounds_checks] + list(pred_checks)
 
                 result = [wrap_in_if(conditionals, gen_code_block(result))]
diff --git a/loopy/codegen/instruction.py b/loopy/codegen/instruction.py
index 1bd977f3ea2e8e15e63424e89b1e347063da70e2..b98716a4e6281ff559abc670a8e1b0fe5d67976d 100644
--- a/loopy/codegen/instruction.py
+++ b/loopy/codegen/instruction.py
@@ -45,7 +45,9 @@ def wrap_in_conditionals(codegen_state, domain, check_inames, required_preds, st
     if bounds_check_set.is_empty():
         return None, None
 
-    condition_codelets = [constraint_to_code(codegen_state.c_code_mapper, cns)
+    condition_codelets = [
+            constraint_to_code(
+                codegen_state.expression_to_code_mapper, cns)
             for cns in bounds_checks]
 
     condition_codelets.extend(
@@ -86,7 +88,7 @@ def generate_instruction_code(kernel, insn, codegen_state):
 
 
 def generate_expr_instruction_code(kernel, insn, codegen_state):
-    ccm = codegen_state.c_code_mapper
+    ecm = codegen_state.expression_to_code_mapper
 
     expr = insn.expression
 
@@ -94,16 +96,16 @@ def generate_expr_instruction_code(kernel, insn, codegen_state):
     target_dtype = kernel.get_var_descriptor(assignee_var_name).dtype
 
     from cgen import Assign
-    from loopy.codegen.expression import dtype_to_type_context
-    lhs_code = ccm(insn.assignee, prec=PREC_NONE, type_context=None)
+    from loopy.expression import dtype_to_type_context
+    lhs_code = ecm(insn.assignee, prec=PREC_NONE, type_context=None)
     result = Assign(
             lhs_code,
-            ccm(expr, prec=PREC_NONE,
+            ecm(expr, prec=PREC_NONE,
                 type_context=dtype_to_type_context(kernel.target, target_dtype),
                 needed_dtype=target_dtype))
 
     if kernel.options.trace_assignments or kernel.options.trace_assignment_values:
-        from cgen import Statement as S
+        from cgen import Statement as S  # noqa
 
         gs, ls = kernel.get_grid_sizes()
 
@@ -123,7 +125,7 @@ def generate_expr_instruction_code(kernel, insn, codegen_state):
         if assignee_indices:
             printf_format += "[%s]" % ",".join(len(assignee_indices) * ["%d"])
             printf_args.extend(
-                    ccm(i, prec=PREC_NONE, type_context="i")
+                    ecm(i, prec=PREC_NONE, type_context="i")
                     for i in assignee_indices)
 
         if kernel.options.trace_assignment_values:
@@ -158,7 +160,7 @@ def generate_expr_instruction_code(kernel, insn, codegen_state):
 
 
 def generate_c_instruction_code(kernel, insn, codegen_state):
-    ccm = codegen_state.c_code_mapper
+    ecm = codegen_state.expression_to_code_mapper
 
     body = []
 
@@ -168,14 +170,14 @@ def generate_c_instruction_code(kernel, insn, codegen_state):
     from pymbolic.primitives import Variable
     for name, iname_expr in insn.iname_exprs:
         if (isinstance(iname_expr, Variable)
-                and name not in ccm.var_subst_map):
+                and name not in ecm.var_subst_map):
             # No need, the bare symbol will work
             continue
 
         body.append(
                 Initializer(
                     POD(kernel.target, kernel.index_dtype, name),
-                    codegen_state.c_code_mapper(
+                    codegen_state.expression_to_code_mapper(
                         iname_expr, prec=PREC_NONE, type_context="i")))
 
     if body:
diff --git a/loopy/codegen/loop.py b/loopy/codegen/loop.py
index 40f433a90f90855aa554261b8d6fb5cbdd5fb0ec..f27458fec99ddabc97df1fa35fcf384000f8c73f 100644
--- a/loopy/codegen/loop.py
+++ b/loopy/codegen/loop.py
@@ -250,8 +250,9 @@ def set_up_hw_parallel_loops(kernel, sched_index, codegen_state,
         # slabbing conditionals.
         slabbed_kernel = intersect_kernel_with_slab(kernel, slab, iname)
         new_codegen_state = codegen_state.copy(
-                c_code_mapper=codegen_state.c_code_mapper.copy_and_assign(
-                    iname, hw_axis_expr))
+                expression_to_code_mapper=(
+                    codegen_state.expression_to_code_mapper.copy_and_assign(
+                        iname, hw_axis_expr)))
 
         inner = set_up_hw_parallel_loops(
                 slabbed_kernel, sched_index,
@@ -268,7 +269,7 @@ def set_up_hw_parallel_loops(kernel, sched_index, codegen_state,
 # {{{ sequential loop
 
 def generate_sequential_loop_dim_code(kernel, sched_index, codegen_state):
-    ccm = codegen_state.c_code_mapper
+    ecm = codegen_state.expression_to_code_mapper
     loop_iname = kernel.schedule[sched_index].iname
 
     slabs = get_slab_decomposition(
@@ -362,7 +363,7 @@ def generate_sequential_loop_dim_code(kernel, sched_index, codegen_state):
             # single-trip, generate just a variable assignment, not a loop
             result.append(gen_code_block([
                 Initializer(Const(POD(kernel.index_dtype, loop_iname)),
-                    ccm(aff_to_expr(static_lbound), PREC_NONE, "i")),
+                    ecm(aff_to_expr(static_lbound), PREC_NONE, "i")),
                 Line(),
                 inner,
                 ]))
@@ -373,9 +374,9 @@ def generate_sequential_loop_dim_code(kernel, sched_index, codegen_state):
             result.append(wrap_in(For,
                     "%s %s = %s"
                     % (kernel.target.dtype_to_typename(kernel.index_dtype),
-                        loop_iname, ccm(aff_to_expr(static_lbound), PREC_NONE, "i")),
+                        loop_iname, ecm(aff_to_expr(static_lbound), PREC_NONE, "i")),
                     "%s <= %s" % (
-                        loop_iname, ccm(aff_to_expr(static_ubound), PREC_NONE, "i")),
+                        loop_iname, ecm(aff_to_expr(static_ubound), PREC_NONE, "i")),
                     "++%s" % loop_iname,
                     inner))
 
diff --git a/loopy/expression.py b/loopy/expression.py
new file mode 100644
index 0000000000000000000000000000000000000000..2afb803b97b0e5b9f9ff1510da92ad10a50711dc
--- /dev/null
+++ b/loopy/expression.py
@@ -0,0 +1,257 @@
+from __future__ import division, absolute_import
+
+__copyright__ = "Copyright (C) 2012-15 Andreas Kloeckner"
+
+__license__ = """
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+"""
+
+
+import numpy as np
+
+from pymbolic.mapper import CombineMapper
+
+from loopy.tools import is_integer
+from loopy.diagnostic import TypeInferenceFailure, DependencyTypeInferenceFailure
+
+
+# type_context may be:
+# - 'i' for integer -
+# - 'f' for single-precision floating point
+# - 'd' for double-precision floating point
+# or None for 'no known context'.
+
+def dtype_to_type_context(target, dtype):
+    dtype = np.dtype(dtype)
+
+    if dtype.kind == 'i':
+        return 'i'
+    if dtype in [np.float64, np.complex128]:
+        return 'd'
+    if dtype in [np.float32, np.complex64]:
+        return 'f'
+    if target.is_vector_dtype(dtype):
+        return dtype_to_type_context(target, dtype.fields["x"][0])
+
+    return None
+
+
+# {{{ type inference
+
+class TypeInferenceMapper(CombineMapper):
+    def __init__(self, kernel, new_assignments=None):
+        """
+        :arg new_assignments: mapping from names to either
+            :class:`loopy.kernel.data.TemporaryVariable`
+            or
+            :class:`loopy.kernel.data.KernelArgument`
+            instances
+        """
+        self.kernel = kernel
+        if new_assignments is None:
+            new_assignments = {}
+        self.new_assignments = new_assignments
+
+    # /!\ Introduce caches with care--numpy.float32(x) and numpy.float64(x)
+    # are Python-equal (for many common constants such as integers).
+
+    @staticmethod
+    def combine(dtypes):
+        dtypes = list(dtypes)
+
+        result = dtypes.pop()
+        while dtypes:
+            other = dtypes.pop()
+
+            if result.isbuiltin and other.isbuiltin:
+                if (result, other) in [
+                        (np.int32, np.float32), (np.int32, np.float32)]:
+                    # numpy makes this a double. I disagree.
+                    result = np.dtype(np.float32)
+                else:
+                    result = (
+                            np.empty(0, dtype=result)
+                            + np.empty(0, dtype=other)
+                            ).dtype
+            elif result.isbuiltin and not other.isbuiltin:
+                # assume the non-native type takes over
+                result = other
+            elif not result.isbuiltin and other.isbuiltin:
+                # assume the non-native type takes over
+                pass
+            else:
+                if result is not other:
+                    raise TypeInferenceFailure(
+                            "nothing known about result of operation on "
+                            "'%s' and '%s'" % (result, other))
+
+        return result
+
+    def map_sum(self, expr):
+        dtypes = []
+        small_integer_dtypes = []
+        for child in expr.children:
+            dtype = self.rec(child)
+            if is_integer(child) and abs(child) < 1024:
+                small_integer_dtypes.append(dtype)
+            else:
+                dtypes.append(dtype)
+
+        from pytools import all
+        if all(dtype.kind == "i" for dtype in dtypes):
+            dtypes.extend(small_integer_dtypes)
+
+        return self.combine(dtypes)
+
+    map_product = map_sum
+
+    def map_quotient(self, expr):
+        n_dtype = self.rec(expr.numerator)
+        d_dtype = self.rec(expr.denominator)
+
+        if n_dtype.kind in "iu" and d_dtype.kind in "iu":
+            # both integers
+            return np.dtype(np.float64)
+
+        else:
+            return self.combine([n_dtype, d_dtype])
+
+    def map_constant(self, expr):
+        if is_integer(expr):
+            for tp in [np.int32, np.int64]:
+                iinfo = np.iinfo(tp)
+                if iinfo.min <= expr <= iinfo.max:
+                    return np.dtype(tp)
+
+            else:
+                raise TypeInferenceFailure("integer constant '%s' too large" % expr)
+
+        dt = np.asarray(expr).dtype
+        if hasattr(expr, "dtype"):
+            return expr.dtype
+        elif isinstance(expr, np.number):
+            # Numpy types are sized
+            return np.dtype(type(expr))
+        elif dt.kind == "f":
+            # deduce the smaller type by default
+            return np.dtype(np.float32)
+        elif dt.kind == "c":
+            if np.complex64(expr) == np.complex128(expr):
+                # (COMPLEX_GUESS_LOGIC)
+                # No precision is lost by 'guessing' single precision, use that.
+                # This at least covers simple cases like '1j'.
+                return np.dtype(np.complex64)
+
+            # Codegen for complex types depends on exactly correct types.
+            # Refuse temptation to guess.
+            raise TypeInferenceFailure("Complex constant '%s' needs to "
+                    "be sized for type inference " % expr)
+        else:
+            raise TypeInferenceFailure("Cannot deduce type of constant '%s'" % expr)
+
+    def map_subscript(self, expr):
+        return self.rec(expr.aggregate)
+
+    def map_linear_subscript(self, expr):
+        return self.rec(expr.aggregate)
+
+    def map_call(self, expr):
+        from pymbolic.primitives import Variable
+
+        identifier = expr.function
+        if isinstance(identifier, Variable):
+            identifier = identifier.name
+
+        arg_dtypes = tuple(self.rec(par) for par in expr.parameters)
+
+        mangle_result = self.kernel.mangle_function(identifier, arg_dtypes)
+        if mangle_result is not None:
+            return mangle_result[0]
+
+        raise RuntimeError("no type inference information on "
+                "function '%s'" % identifier)
+
+    def map_variable(self, expr):
+        if expr.name in self.kernel.all_inames():
+            return self.kernel.index_dtype
+
+        result = self.kernel.mangle_symbol(expr.name)
+        if result is not None:
+            result_dtype, _ = result
+            return result_dtype
+
+        obj = self.new_assignments.get(expr.name)
+
+        if obj is None:
+            obj = self.kernel.arg_dict.get(expr.name)
+
+        if obj is None:
+            obj = self.kernel.temporary_variables.get(expr.name)
+
+        if obj is None:
+            raise TypeInferenceFailure("name not known in type inference: %s"
+                    % expr.name)
+
+        from loopy.kernel.data import TemporaryVariable, KernelArgument
+        import loopy as lp
+        if isinstance(obj, TemporaryVariable):
+            result = obj.dtype
+            if result is lp.auto:
+                raise DependencyTypeInferenceFailure(
+                        "temporary variable '%s'" % expr.name,
+                        expr.name)
+            else:
+                return result
+
+        elif isinstance(obj, KernelArgument):
+            result = obj.dtype
+            if result is None:
+                raise DependencyTypeInferenceFailure(
+                        "argument '%s'" % expr.name,
+                        expr.name)
+            else:
+                return result
+
+        else:
+            raise RuntimeError("unexpected type inference "
+                    "object type for '%s'" % expr.name)
+
+    map_tagged_variable = map_variable
+
+    def map_lookup(self, expr):
+        agg_result = self.rec(expr.aggregate)
+        dtype, offset = agg_result.fields[expr.name]
+        return dtype
+
+    def map_comparison(self, expr):
+        # "bool" is unusable because OpenCL's bool has indeterminate memory
+        # format.
+        return np.dtype(np.int32)
+
+    map_logical_not = map_comparison
+    map_logical_and = map_comparison
+    map_logical_or = map_comparison
+
+    def map_reduction(self, expr):
+        return expr.operation.result_dtype(
+                self.kernel.target, self.rec(expr.expr), expr.inames)
+
+# }}}
+
+# vim: fdm=marker
diff --git a/loopy/preprocess.py b/loopy/preprocess.py
index f9cd3f3eb2140354decf61b475ba1742d5ea6832..9de31233b831215bba72ba1e226b0df555b42286 100644
--- a/loopy/preprocess.py
+++ b/loopy/preprocess.py
@@ -162,7 +162,7 @@ def infer_unknown_types(kernel, expect_completion=False):
 
     # }}}
 
-    from loopy.codegen.expression import TypeInferenceMapper
+    from loopy.expression import TypeInferenceMapper
     type_inf_mapper = TypeInferenceMapper(kernel,
             _DictUnionView([
                 new_temp_vars,
@@ -401,7 +401,7 @@ def realize_reduction(kernel, insn_id_filter=None):
     var_name_gen = kernel.get_var_name_generator()
     new_temporary_variables = kernel.temporary_variables.copy()
 
-    from loopy.codegen.expression import TypeInferenceMapper
+    from loopy.expression import TypeInferenceMapper
     type_inf_mapper = TypeInferenceMapper(kernel)
 
     def map_reduction(expr, rec):
diff --git a/loopy/target/c/__init__.py b/loopy/target/c/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb142773929cb157523fad4f9e08b57235e73b28
--- /dev/null
+++ b/loopy/target/c/__init__.py
@@ -0,0 +1,58 @@
+"""OpenCL target independent of PyOpenCL."""
+
+from __future__ import division, absolute_import
+
+__copyright__ = "Copyright (C) 2015 Andreas Kloeckner"
+
+__license__ = """
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+"""
+
+import numpy as np  # noqa
+from loopy.target import TargetBase
+
+from pytools import memoize_method
+
+
+class CTarget(TargetBase):
+    @memoize_method
+    def get_dtype_registry(self):
+        from loopy.target.c.compyte import (
+                DTypeRegistry, fill_with_registry_with_c_types)
+        result = DTypeRegistry()
+        fill_with_registry_with_c_types(result)
+        return result
+
+    def is_vector_dtype(self, dtype):
+        return False
+
+    def get_vector_dtype(self, base, count):
+        raise KeyError()
+
+    def get_or_register_dtype(self, names, dtype=None):
+        return self.get_dtype_registry().get_or_register_dtype(names, dtype)
+
+    def dtype_to_typename(self, dtype):
+        return self.get_dtype_registry().dtype_to_ctype(dtype)
+
+    def get_expression_to_code_mapper(self, kernel,
+            seen_dtypes, seen_functions, allow_complex):
+        from loopy.target.c.codegen.expression import LoopyCCodeMapper
+        return (LoopyCCodeMapper(kernel, seen_dtypes, seen_functions,
+            allow_complex=allow_complex))
diff --git a/loopy/target/c/codegen/__init__.py b/loopy/target/c/codegen/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/loopy/codegen/expression.py b/loopy/target/c/codegen/expression.py
similarity index 78%
rename from loopy/codegen/expression.py
rename to loopy/target/c/codegen/expression.py
index f834c8fa50287c14aec1417ebef093dffd5062d9..ccfa88040e5def8e1a3035d236a5bd2a0f7b7d07 100644
--- a/loopy/codegen/expression.py
+++ b/loopy/target/c/codegen/expression.py
@@ -31,238 +31,12 @@ import numpy as np
 from pymbolic.mapper import RecursiveMapper
 from pymbolic.mapper.stringifier import (PREC_NONE, PREC_CALL, PREC_PRODUCT,
         PREC_POWER)
-from pymbolic.mapper import CombineMapper
 import islpy as isl
 from pytools import Record
 
-from loopy.tools import is_integer
-from loopy.diagnostic import TypeInferenceFailure, DependencyTypeInferenceFailure
-
-
-# {{{ type inference
-
-class TypeInferenceMapper(CombineMapper):
-    def __init__(self, kernel, new_assignments=None):
-        """
-        :arg new_assignments: mapping from names to either
-            :class:`loopy.kernel.data.TemporaryVariable`
-            or
-            :class:`loopy.kernel.data.KernelArgument`
-            instances
-        """
-        self.kernel = kernel
-        if new_assignments is None:
-            new_assignments = {}
-        self.new_assignments = new_assignments
-
-    # /!\ Introduce caches with care--numpy.float32(x) and numpy.float64(x)
-    # are Python-equal (for many common constants such as integers).
-
-    @staticmethod
-    def combine(dtypes):
-        dtypes = list(dtypes)
-
-        result = dtypes.pop()
-        while dtypes:
-            other = dtypes.pop()
-
-            if result.isbuiltin and other.isbuiltin:
-                if (result, other) in [
-                        (np.int32, np.float32), (np.int32, np.float32)]:
-                    # numpy makes this a double. I disagree.
-                    result = np.dtype(np.float32)
-                else:
-                    result = (
-                            np.empty(0, dtype=result)
-                            + np.empty(0, dtype=other)
-                            ).dtype
-            elif result.isbuiltin and not other.isbuiltin:
-                # assume the non-native type takes over
-                result = other
-            elif not result.isbuiltin and other.isbuiltin:
-                # assume the non-native type takes over
-                pass
-            else:
-                if result is not other:
-                    raise TypeInferenceFailure(
-                            "nothing known about result of operation on "
-                            "'%s' and '%s'" % (result, other))
-
-        return result
-
-    def map_sum(self, expr):
-        dtypes = []
-        small_integer_dtypes = []
-        for child in expr.children:
-            dtype = self.rec(child)
-            if is_integer(child) and abs(child) < 1024:
-                small_integer_dtypes.append(dtype)
-            else:
-                dtypes.append(dtype)
-
-        from pytools import all
-        if all(dtype.kind == "i" for dtype in dtypes):
-            dtypes.extend(small_integer_dtypes)
-
-        return self.combine(dtypes)
-
-    map_product = map_sum
-
-    def map_quotient(self, expr):
-        n_dtype = self.rec(expr.numerator)
-        d_dtype = self.rec(expr.denominator)
-
-        if n_dtype.kind in "iu" and d_dtype.kind in "iu":
-            # both integers
-            return np.dtype(np.float64)
+from loopy.expression import dtype_to_type_context, TypeInferenceMapper
 
-        else:
-            return self.combine([n_dtype, d_dtype])
-
-    def map_constant(self, expr):
-        if is_integer(expr):
-            for tp in [np.int32, np.int64]:
-                iinfo = np.iinfo(tp)
-                if iinfo.min <= expr <= iinfo.max:
-                    return np.dtype(tp)
-
-            else:
-                raise TypeInferenceFailure("integer constant '%s' too large" % expr)
-
-        dt = np.asarray(expr).dtype
-        if hasattr(expr, "dtype"):
-            return expr.dtype
-        elif isinstance(expr, np.number):
-            # Numpy types are sized
-            return np.dtype(type(expr))
-        elif dt.kind == "f":
-            # deduce the smaller type by default
-            return np.dtype(np.float32)
-        elif dt.kind == "c":
-            if np.complex64(expr) == np.complex128(expr):
-                # (COMPLEX_GUESS_LOGIC)
-                # No precision is lost by 'guessing' single precision, use that.
-                # This at least covers simple cases like '1j'.
-                return np.dtype(np.complex64)
-
-            # Codegen for complex types depends on exactly correct types.
-            # Refuse temptation to guess.
-            raise TypeInferenceFailure("Complex constant '%s' needs to "
-                    "be sized for type inference " % expr)
-        else:
-            raise TypeInferenceFailure("Cannot deduce type of constant '%s'" % expr)
-
-    def map_subscript(self, expr):
-        return self.rec(expr.aggregate)
-
-    def map_linear_subscript(self, expr):
-        return self.rec(expr.aggregate)
-
-    def map_call(self, expr):
-        from pymbolic.primitives import Variable
-
-        identifier = expr.function
-        if isinstance(identifier, Variable):
-            identifier = identifier.name
-
-        arg_dtypes = tuple(self.rec(par) for par in expr.parameters)
-
-        mangle_result = self.kernel.mangle_function(identifier, arg_dtypes)
-        if mangle_result is not None:
-            return mangle_result[0]
-
-        raise RuntimeError("no type inference information on "
-                "function '%s'" % identifier)
-
-    def map_variable(self, expr):
-        if expr.name in self.kernel.all_inames():
-            return self.kernel.index_dtype
-
-        result = self.kernel.mangle_symbol(expr.name)
-        if result is not None:
-            result_dtype, _ = result
-            return result_dtype
-
-        obj = self.new_assignments.get(expr.name)
-
-        if obj is None:
-            obj = self.kernel.arg_dict.get(expr.name)
-
-        if obj is None:
-            obj = self.kernel.temporary_variables.get(expr.name)
-
-        if obj is None:
-            raise TypeInferenceFailure("name not known in type inference: %s"
-                    % expr.name)
-
-        from loopy.kernel.data import TemporaryVariable, KernelArgument
-        import loopy as lp
-        if isinstance(obj, TemporaryVariable):
-            result = obj.dtype
-            if result is lp.auto:
-                raise DependencyTypeInferenceFailure(
-                        "temporary variable '%s'" % expr.name,
-                        expr.name)
-            else:
-                return result
-
-        elif isinstance(obj, KernelArgument):
-            result = obj.dtype
-            if result is None:
-                raise DependencyTypeInferenceFailure(
-                        "argument '%s'" % expr.name,
-                        expr.name)
-            else:
-                return result
-
-        else:
-            raise RuntimeError("unexpected type inference "
-                    "object type for '%s'" % expr.name)
-
-    map_tagged_variable = map_variable
-
-    def map_lookup(self, expr):
-        agg_result = self.rec(expr.aggregate)
-        dtype, offset = agg_result.fields[expr.name]
-        return dtype
-
-    def map_comparison(self, expr):
-        # "bool" is unusable because OpenCL's bool has indeterminate memory
-        # format.
-        return np.dtype(np.int32)
-
-    map_logical_not = map_comparison
-    map_logical_and = map_comparison
-    map_logical_or = map_comparison
-
-    def map_reduction(self, expr):
-        return expr.operation.result_dtype(
-                self.kernel.target, self.rec(expr.expr), expr.inames)
-
-# }}}
-
-
-# {{{ C code mapper
-
-# type_context may be:
-# - 'i' for integer -
-# - 'f' for single-precision floating point
-# - 'd' for double-precision floating point
-# or None for 'no known context'.
-
-def dtype_to_type_context(target, dtype):
-    dtype = np.dtype(dtype)
-
-    if dtype.kind == 'i':
-        return 'i'
-    if dtype in [np.float64, np.complex128]:
-        return 'd'
-    if dtype in [np.float32, np.complex64]:
-        return 'f'
-    if target.is_vector_dtype(dtype):
-        return dtype_to_type_context(target, dtype.fields["x"][0])
-
-    return None
+from loopy.tools import is_integer
 
 
 def get_opencl_vec_member(idx):
@@ -293,6 +67,8 @@ class SeenFunction(Record):
                 + tuple((f, getattr(self, f)) for f in type(self).fields))
 
 
+# {{{ C code mapper
+
 class LoopyCCodeMapper(RecursiveMapper):
     def __init__(self, kernel, seen_dtypes, seen_functions, var_subst_map={},
             allow_complex=False):
diff --git a/loopy/target/c/compyte b/loopy/target/c/compyte
new file mode 160000
index 0000000000000000000000000000000000000000..fb6ba114d9d906403d47b0aaf69e2fe4cef382f2
--- /dev/null
+++ b/loopy/target/c/compyte
@@ -0,0 +1 @@
+Subproject commit fb6ba114d9d906403d47b0aaf69e2fe4cef382f2
diff --git a/loopy/target/opencl/__init__.py b/loopy/target/opencl/__init__.py
index 04efdedabb6e281b99c8ffae3a5be53e084524ed..b297eb9ffbaa72538a9e49b2420322ae69c9ff87 100644
--- a/loopy/target/opencl/__init__.py
+++ b/loopy/target/opencl/__init__.py
@@ -26,44 +26,21 @@ THE SOFTWARE.
 
 import numpy as np
 
-from loopy.target import TargetBase
-
-
-# {{{ type registry
-
-def _register_types():
-    from loopy.target.opencl.compyte.dtypes import (
-            _fill_dtype_registry, get_or_register_dtype)
-    import struct
-
-    _fill_dtype_registry(respect_windows=False, include_bool=False)
-
-    # complex number support left out
-
-    is_64_bit = struct.calcsize('@P') * 8 == 64
-    if not is_64_bit:
-        get_or_register_dtype(
-                ["unsigned long", "unsigned long int"], np.uint64)
-        get_or_register_dtype(
-                ["signed long", "signed long int", "long int"], np.int64)
-
-_register_types()
-
-# }}}
+from loopy.target.c import CTarget
+from pytools import memoize_method
 
 
 # {{{ vector types
 
-class vec:
+class vec:  # noqa
     pass
 
 
 def _create_vector_types():
     field_names = ["x", "y", "z", "w"]
 
-    from loopy.target.opencl.compyte.dtypes import get_or_register_dtype
-
     vec.types = {}
+    vec.names_and_dtypes = []
     vec.type_to_scalar_and_count = {}
 
     counts = [2, 3, 4, 8, 16]
@@ -109,40 +86,20 @@ def _create_vector_types():
                     dtype = np.dtype([(n, base_type) for (n, title)
                                       in zip(names, titles)])
 
-            get_or_register_dtype(name, dtype)
-
             setattr(vec, name, dtype)
 
-            def create_array(dtype, count, padded_count, *args, **kwargs):
-                if len(args) < count:
-                    from warnings import warn
-                    warn("default values for make_xxx are deprecated;"
-                            " instead specify all parameters or use"
-                            " array.vec.zeros_xxx", DeprecationWarning)
-                padded_args = tuple(list(args)+[0]*(padded_count-len(args)))
-                array = eval("array(padded_args, dtype=dtype)",
-                        dict(array=np.array, padded_args=padded_args,
-                        dtype=dtype))
-                for key, val in kwargs.items():
-                    array[key] = val
-                return array
-
-            setattr(vec, "make_"+name, staticmethod(eval(
-                    "lambda *args, **kwargs: create_array(dtype, %i, %i, "
-                    "*args, **kwargs)" % (count, padded_count),
-                    dict(create_array=create_array, dtype=dtype))))
-            setattr(vec, "filled_"+name, staticmethod(eval(
-                    "lambda val: vec.make_%s(*[val]*%i)" % (name, count))))
-            setattr(vec, "zeros_"+name,
-                    staticmethod(eval("lambda: vec.filled_%s(0)" % (name))))
-            setattr(vec, "ones_"+name,
-                    staticmethod(eval("lambda: vec.filled_%s(1)" % (name))))
+            vec.names_and_dtypes.append((name, dtype))
 
             vec.types[np.dtype(base_type), count] = dtype
             vec.type_to_scalar_and_count[dtype] = np.dtype(base_type), count
 
 _create_vector_types()
 
+
+def _register_vector_types(dtype_registry):
+    for name, dtype in vec.names_and_dtypes:
+        dtype_registry.get_or_register_dtype(name, dtype)
+
 # }}}
 
 
@@ -234,7 +191,7 @@ def opencl_preamble_generator(target, seen_dtypes, seen_functions):
 
 # {{{ target
 
-class OpenCLTarget(TargetBase):
+class OpenCLTarget(CTarget):
     def function_manglers(self):
         return (
                 super(OpenCLTarget, self).function_manglers() + [
@@ -255,13 +212,24 @@ class OpenCLTarget(TargetBase):
                     reduction_preamble_generator
                     ])
 
-    def get_or_register_dtype(self, names, dtype=None):
-        from loopy.target.opencl.compyte.dtypes import get_or_register_dtype
-        return get_or_register_dtype(names, dtype)
+    @memoize_method
+    def get_dtype_registry(self):
+        from loopy.target.c.compyte import (
+                DTypeRegistry, fill_with_registry_with_c_types)
+        result = DTypeRegistry()
+        fill_with_registry_with_c_types(result)
+
+        # complex number support left out
+
+        # CL defines 'long' as 64-bit
+        result.get_or_register_dtype(
+                ["unsigned long", "unsigned long int"], np.uint64)
+        result.get_or_register_dtype(
+                ["signed long", "signed long int", "long int"], np.int64)
+
+        _register_vector_types(result)
 
-    def dtype_to_typename(self, dtype):
-        from loopy.target.opencl.compyte.dtypes import dtype_to_ctype
-        return dtype_to_ctype(dtype)
+        return result
 
     def is_vector_dtype(self, dtype):
         return list(vec.types.values())
diff --git a/loopy/target/opencl/compyte b/loopy/target/opencl/compyte
deleted file mode 160000
index 5d54e1b2b7f28d3e779029ac0b4aa5f957829f23..0000000000000000000000000000000000000000
--- a/loopy/target/opencl/compyte
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit 5d54e1b2b7f28d3e779029ac0b4aa5f957829f23
diff --git a/loopy/target/pyopencl/__init__.py b/loopy/target/pyopencl/__init__.py
index a1b323d7cf0deaef308952d87b1523b0a020ad68..a0a119dce42e1095ff327f38f6b836f80091b159 100644
--- a/loopy/target/pyopencl/__init__.py
+++ b/loopy/target/pyopencl/__init__.py
@@ -259,13 +259,9 @@ class PyOpenCLTarget(OpenCLTarget):
     def pre_codegen_check(self, kernel):
         check_sizes(kernel, self.device)
 
-    def get_or_register_dtype(self, names, dtype=None):
-        from pyopencl.compyte.dtypes import get_or_register_dtype
-        return get_or_register_dtype(names, dtype)
-
-    def dtype_to_typename(self, dtype):
-        from pyopencl.compyte.dtypes import dtype_to_ctype
-        return dtype_to_ctype(dtype)
+    def get_dtype_registry(self):
+        from pyopencl.compyte.dtypes import TYPE_REGISTRY
+        return TYPE_REGISTRY
 
     def is_vector_dtype(self, dtype):
         from pyopencl.array import vec