From 4da8612cccfc89d04497923ffbeca048cb7c3001 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 27 Mar 2013 20:17:25 -0400
Subject: [PATCH] Template renderer: don't replace type aliases textually, dumb
 idea.  Also, rename type_values -> type_aliases.

---
 pyopencl/algorithm.py   |  8 ++++----
 pyopencl/elementwise.py |  6 +++---
 pyopencl/scan.py        |  8 ++++----
 pyopencl/tools.py       | 40 ++++++++++++++++++++++++----------------
 4 files changed, 35 insertions(+), 27 deletions(-)

diff --git a/pyopencl/algorithm.py b/pyopencl/algorithm.py
index 350f9d2e..585dd4d7 100644
--- a/pyopencl/algorithm.py
+++ b/pyopencl/algorithm.py
@@ -75,7 +75,7 @@ def copy_if(ary, predicate, extra_args=[], queue=None, preamble=""):
     extra_args_values = tuple(val for name, val in extra_args)
 
     knl = _copy_if_template.build(ary.context,
-            type_values=(("scan_t", scan_dtype), ("item_t", ary.dtype)),
+            type_aliases=(("scan_t", scan_dtype), ("item_t", ary.dtype)),
             var_values=(("predicate", predicate),),
             more_preamble=preamble, more_arguments=extra_args_types)
     out = cl.array.empty_like(ary)
@@ -152,7 +152,7 @@ def partition(ary, predicate, extra_args=[], queue=None, preamble=""):
 
     knl = _partition_template.build(
             ary.context,
-            type_values=(("item_t", ary.dtype), ("scan_t", scan_dtype)),
+            type_aliases=(("item_t", ary.dtype), ("scan_t", scan_dtype)),
             var_values=(("predicate", predicate),),
             more_preamble=preamble, more_arguments=extra_args_types)
 
@@ -214,7 +214,7 @@ def unique(ary, is_equal_expr="a == b", extra_args=[], queue=None, preamble=""):
 
     knl = _unique_template.build(
             ary.context,
-            type_values=(("item_t", ary.dtype), ("scan_t", scan_dtype)),
+            type_aliases=(("item_t", ary.dtype), ("scan_t", scan_dtype)),
             var_values=(("macro_is_equal_expr", is_equal_expr),),
             more_preamble=preamble, more_arguments=extra_args_types)
 
@@ -1075,7 +1075,7 @@ class KeyValueSorter(object):
                     key_group_starts[my_key] = i;
                 """,
                 name="find_starts").build(self.context,
-                        type_values=(
+                        type_aliases=(
                             ("key_t", starts_dtype),
                             ("starts_t", starts_dtype),
                             ),
diff --git a/pyopencl/elementwise.py b/pyopencl/elementwise.py
index 6fb33c6d..7b7f5340 100644
--- a/pyopencl/elementwise.py
+++ b/pyopencl/elementwise.py
@@ -312,10 +312,10 @@ class ElementwiseTemplate(KernelTemplateBase):
         self.name = name
         self.preamble = preamble
 
-    def build_inner(self, context, type_values, var_values,
+    def build_inner(self, context, type_aliases, var_values,
             more_preamble="", more_arguments=(), declare_types=(),
             options=()):
-        renderer = self.get_renderer(type_values, var_values, context, options)
+        renderer = self.get_renderer(type_aliases, var_values, context, options)
 
         arg_list = renderer.render_argument_list(self.arguments, more_arguments)
         type_decl_preamble = renderer.get_type_decl_preamble(
@@ -327,7 +327,7 @@ class ElementwiseTemplate(KernelTemplateBase):
             preamble=(
                 type_decl_preamble
                 + "\n" + renderer(self.preamble + "\n" + more_preamble)),
-            auto_preamble=True)
+            auto_preamble=False)
 
 # }}}
 
diff --git a/pyopencl/scan.py b/pyopencl/scan.py
index a6d69a3a..a56f6a07 100644
--- a/pyopencl/scan.py
+++ b/pyopencl/scan.py
@@ -1538,18 +1538,18 @@ class ScanTemplate(KernelTemplateBase):
         self.name_prefix = name_prefix
         self.preamble = preamble
 
-    def build_inner(self, context, type_values, var_values,
+    def build_inner(self, context, type_aliases, var_values,
             more_preamble="", more_arguments=(), declare_types=(),
             options=(), devices=None, scan_cls=GenericScanKernel):
-        renderer = self.get_renderer(type_values, var_values, context, options)
+        renderer = self.get_renderer(type_aliases, var_values, context, options)
 
-        return scan_cls(context, renderer.type_dict["scan_t"],
+        return scan_cls(context, renderer.type_aliases["scan_t"],
             renderer.render_argument_list(self.arguments, more_arguments),
             renderer(self.input_expr), renderer(self.scan_expr), renderer(self.neutral),
             renderer(self.output_statement),
             is_segment_start_expr=renderer(self.is_segment_start_expr),
             input_fetch_exprs=self.input_fetch_exprs,
-            index_dtype=renderer.type_dict.get("index_t", np.int32),
+            index_dtype=renderer.type_aliases.get("index_t", np.int32),
             name_prefix=renderer(self.name_prefix), options=list(options),
             preamble=renderer(more_preamble+"\n"+self.preamble), devices=devices)
 
diff --git a/pyopencl/tools.py b/pyopencl/tools.py
index fda7f5af..07e1b266 100644
--- a/pyopencl/tools.py
+++ b/pyopencl/tools.py
@@ -371,7 +371,12 @@ class _CDeclList:
         if dtype in self.declared_dtypes:
             return
 
-        for name, (field_dtype, offset) in dtype.fields.iteritems():
+        from pyopencl.array import vec
+        if dtype in vec.type_to_scalar_and_count:
+            return
+
+        for name, field_data in dtype.fields.iteritems():
+            field_dtype, offset = field_data[:2]
             self.add_dtype(field_dtype)
 
         _, cdecl = match_dtype_to_c_struct(self.device, dtype_to_ctype(dtype), dtype)
@@ -391,15 +396,15 @@ class _CDeclList:
     def get_declarations(self):
         result = "\n\n".join(self.declarations)
 
-        if self.saw_double:
+        if self.saw_complex:
             result = (
-                    "#pragma OPENCL EXTENSION cl_khr_fp64: enable\n"
-                    "#define PYOPENCL_DEFINE_CDOUBLE\n"
+                    "#include <pyopencl-complex.h>\n\n"
                     + result)
 
-        if self.saw_complex:
+        if self.saw_double:
             result = (
-                    "#include <pyopencl-complex.h>\n\n"
+                    "#pragma OPENCL EXTENSION cl_khr_fp64: enable\n"
+                    "#define PYOPENCL_DEFINE_CDOUBLE\n"
                     + result)
 
         return result
@@ -607,9 +612,9 @@ class _ScalarArgPlaceholder(_ArgumentPlaceholder):
 
 
 class _TemplateRenderer(object):
-    def __init__(self, template, type_values, var_values, context=None, options=[]):
+    def __init__(self, template, type_aliases, var_values, context=None, options=[]):
         self.template = template
-        self.type_dict = dict(type_values)
+        self.type_aliases = dict(type_aliases)
         self.var_dict = dict(var_values)
 
         for name in self.var_dict:
@@ -625,10 +630,6 @@ class _TemplateRenderer(object):
 
         result = self.template.get_text_template(txt).render(self.var_dict)
 
-        # substitute in types
-        for name, dtype in self.type_dict.iteritems():
-            result = re.sub(r"\b%s\b" % name, dtype_to_ctype(dtype), result)
-
         return str(result)
 
     def get_rendered_kernel(self, txt, kernel_name):
@@ -643,7 +644,7 @@ class _TemplateRenderer(object):
     def parse_type(self, typename):
         if isinstance(typename, str):
             try:
-                return self.type_dict[typename]
+                return self.type_aliases[typename]
             except KeyError:
                 from pyopencl.compyte.dtypes import NAME_TO_DTYPE
                 return NAME_TO_DTYPE[typename]
@@ -702,8 +703,15 @@ class _TemplateRenderer(object):
         if arguments is not None:
             cdl.visit_arguments(arguments)
 
-        return cdl.get_declarations()
+        for tv in self.type_aliases.itervalues():
+            cdl.add_dtype(tv)
+
+        type_alias_decls = [
+                "typedef %s %s;" % (dtype_to_ctype(val), name)
+                for name, val in self.type_aliases.iteritems()
+                ]
 
+        return cdl.get_declarations() + "\n" + "\n".join(type_alias_decls)
 
 
 
@@ -741,8 +749,8 @@ class KernelTemplateBase(object):
         else:
             raise RuntimeError("unknown template processor '%s'" % proc_match.group(1))
 
-    def get_renderer(self, type_values, var_values, context=None, options=[]):
-        return _TemplateRenderer(self, type_values, var_values)
+    def get_renderer(self, type_aliases, var_values, context=None, options=[]):
+        return _TemplateRenderer(self, type_aliases, var_values)
 
     def build(self, context, *args, **kwargs):
         """Provide caching for an :meth:`build_inner`."""
-- 
GitLab