From 19345492cda6fc30ed43904e32221fb8a16c297c Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Tue, 12 Feb 2019 15:46:28 -0600 Subject: [PATCH] Refactor option treatment, don't pass include dir options to linker (Closes #13 on Gitlab) --- pyopencl/__init__.py | 120 +++++++++++++++++++++++++------------------ 1 file changed, 70 insertions(+), 50 deletions(-) diff --git a/pyopencl/__init__.py b/pyopencl/__init__.py index 7f77154f..810ff00e 100644 --- a/pyopencl/__init__.py +++ b/pyopencl/__init__.py @@ -265,6 +265,66 @@ def _find_pyopencl_include_path(): # }}} +# {{{ build option munging + +def _split_options_if_necessary(options): + if isinstance(options, six.string_types): + import shlex + if six.PY2: + # shlex.split takes bytes (py2 str) on py2 + if isinstance(options, six.text_type): + options = options.encode("utf-8") + else: + # shlex.split takes unicode (py3 str) on py3 + if isinstance(options, six.binary_type): + options = options.decode("utf-8") + + options = shlex.split(options) + + return options + + +def _find_include_path(options): + def unquote(path): + if path.startswith('"') and path.endswith('"'): + return path[1:-1] + else: + return path + + include_path = ["."] + + option_idx = 0 + while option_idx < len(options): + option = options[option_idx].strip() + if option.startswith("-I") or option.startswith("/I"): + if len(option) == 2: + if option_idx+1 < len(options): + include_path.append(unquote(options[option_idx+1])) + option_idx += 2 + else: + include_path.append(unquote(option[2:].lstrip())) + option_idx += 1 + else: + option_idx += 1 + + # }}} + + return include_path + + +def _options_to_bytestring(options): + def encode_if_necessary(s): + if isinstance(s, six.text_type): + return s.encode("utf-8") + else: + return s + + return b" ".join(encode_if_necessary(s) for s in options) + + +# }}} + + # {{{ Program (wrapper around _Program, adds caching support) _DEFAULT_BUILD_OPTIONS = [] @@ -390,25 +450,8 @@ class Program(object): # {{{ build @classmethod - def _process_build_options(cls, context, options): - if isinstance(options, six.string_types): - import shlex - if six.PY2: - # shlex.split takes bytes (py2 str) on py2 - if isinstance(options, six.text_type): - options = options.encode("utf-8") - else: - # shlex.split takes unicode (py3 str) on py3 - if isinstance(options, six.binary_type): - options = options.decode("utf-8") - - options = shlex.split(options) - - def encode_if_necessary(s): - if isinstance(s, six.text_type): - return s.encode("utf-8") - else: - return s + def _process_build_options(cls, context, options, _add_include_path=False): + options = _split_options_if_necessary(options) options = (options + _DEFAULT_BUILD_OPTIONS @@ -421,35 +464,9 @@ class Program(object): if forced_options: options = options + forced_options.split() - # {{{ find include path - - def unquote(path): - if path.startswith('"') and path.endswith('"'): - return path[1:-1] - else: - return path - - include_path = ["."] - - option_idx = 0 - while option_idx < len(options): - option = options[option_idx].strip() - if option.startswith("-I") or option.startswith("/I"): - if len(option) == 2: - if option_idx+1 < len(options): - include_path.append(unquote(options[option_idx+1])) - option_idx += 2 - else: - include_path.append(unquote(option[2:].lstrip())) - option_idx += 1 - else: - option_idx += 1 - - # }}} - - options = [encode_if_necessary(s) for s in options] - - return b" ".join(options), include_path + return ( + _options_to_bytestring(options), + _find_include_path(options)) def build(self, options=[], devices=None, cache_dir=None): options_bytes, include_path = self._process_build_options( @@ -559,8 +576,11 @@ def create_program_with_built_in_kernels(context, devices, kernel_names): context, devices, kernel_names)) -def link_program(context, programs, options=[], devices=None): - options_bytes, _ = Program._process_build_options(context, options) +def link_program(context, programs, options=None, devices=None): + if options is None: + options = [] + + options_bytes = _options_to_bytestring(_split_options_if_necessary(options)) programs = [prg._get_prg() for prg in programs] raw_prg = _Program.link(context, programs, options_bytes, devices) return Program(raw_prg) -- GitLab