From 19345492cda6fc30ed43904e32221fb8a16c297c Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Tue, 12 Feb 2019 15:46:28 -0600
Subject: [PATCH] Refactor option treatment, don't pass include dir options to
 linker (Closes #13 on Gitlab)

---
 pyopencl/__init__.py | 120 +++++++++++++++++++++++++------------------
 1 file changed, 70 insertions(+), 50 deletions(-)

diff --git a/pyopencl/__init__.py b/pyopencl/__init__.py
index 7f77154f..810ff00e 100644
--- a/pyopencl/__init__.py
+++ b/pyopencl/__init__.py
@@ -265,6 +265,66 @@ def _find_pyopencl_include_path():
 # }}}
 
 
+# {{{ build option munging
+
+def _split_options_if_necessary(options):
+    if isinstance(options, six.string_types):
+        import shlex
+        if six.PY2:
+            # shlex.split takes bytes (py2 str) on py2
+            if isinstance(options, six.text_type):
+                options = options.encode("utf-8")
+        else:
+            # shlex.split takes unicode (py3 str) on py3
+            if isinstance(options, six.binary_type):
+                options = options.decode("utf-8")
+
+        options = shlex.split(options)
+
+    return options
+
+
+def _find_include_path(options):
+    def unquote(path):
+        if path.startswith('"') and path.endswith('"'):
+            return path[1:-1]
+        else:
+            return path
+
+    include_path = ["."]
+
+    option_idx = 0
+    while option_idx < len(options):
+        option = options[option_idx].strip()
+        if option.startswith("-I") or option.startswith("/I"):
+            if len(option) == 2:
+                if option_idx+1 < len(options):
+                    include_path.append(unquote(options[option_idx+1]))
+                option_idx += 2
+            else:
+                include_path.append(unquote(option[2:].lstrip()))
+                option_idx += 1
+        else:
+            option_idx += 1
+
+    # }}}
+
+    return include_path
+
+
+def _options_to_bytestring(options):
+    def encode_if_necessary(s):
+        if isinstance(s, six.text_type):
+            return s.encode("utf-8")
+        else:
+            return s
+
+    return b" ".join(encode_if_necessary(s) for s in options)
+
+
+# }}}
+
+
 # {{{ Program (wrapper around _Program, adds caching support)
 
 _DEFAULT_BUILD_OPTIONS = []
@@ -390,25 +450,8 @@ class Program(object):
     # {{{ build
 
     @classmethod
-    def _process_build_options(cls, context, options):
-        if isinstance(options, six.string_types):
-            import shlex
-            if six.PY2:
-                # shlex.split takes bytes (py2 str) on py2
-                if isinstance(options, six.text_type):
-                    options = options.encode("utf-8")
-            else:
-                # shlex.split takes unicode (py3 str) on py3
-                if isinstance(options, six.binary_type):
-                    options = options.decode("utf-8")
-
-            options = shlex.split(options)
-
-        def encode_if_necessary(s):
-            if isinstance(s, six.text_type):
-                return s.encode("utf-8")
-            else:
-                return s
+    def _process_build_options(cls, context, options, _add_include_path=False):
+        options = _split_options_if_necessary(options)
 
         options = (options
                 + _DEFAULT_BUILD_OPTIONS
@@ -421,35 +464,9 @@ class Program(object):
         if forced_options:
             options = options + forced_options.split()
 
-        # {{{ find include path
-
-        def unquote(path):
-            if path.startswith('"') and path.endswith('"'):
-                return path[1:-1]
-            else:
-                return path
-
-        include_path = ["."]
-
-        option_idx = 0
-        while option_idx < len(options):
-            option = options[option_idx].strip()
-            if option.startswith("-I") or option.startswith("/I"):
-                if len(option) == 2:
-                    if option_idx+1 < len(options):
-                        include_path.append(unquote(options[option_idx+1]))
-                    option_idx += 2
-                else:
-                    include_path.append(unquote(option[2:].lstrip()))
-                    option_idx += 1
-            else:
-                option_idx += 1
-
-        # }}}
-
-        options = [encode_if_necessary(s) for s in options]
-
-        return b" ".join(options), include_path
+        return (
+                _options_to_bytestring(options),
+                _find_include_path(options))
 
     def build(self, options=[], devices=None, cache_dir=None):
         options_bytes, include_path = self._process_build_options(
@@ -559,8 +576,11 @@ def create_program_with_built_in_kernels(context, devices, kernel_names):
         context, devices, kernel_names))
 
 
-def link_program(context, programs, options=[], devices=None):
-    options_bytes, _ = Program._process_build_options(context, options)
+def link_program(context, programs, options=None, devices=None):
+    if options is None:
+        options = []
+
+    options_bytes = _options_to_bytestring(_split_options_if_necessary(options))
     programs = [prg._get_prg() for prg in programs]
     raw_prg = _Program.link(context, programs, options_bytes, devices)
     return Program(raw_prg)
-- 
GitLab