From 27c281b22bd570544a0649b20861ef0adaa8520e Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 10 Jul 2019 20:10:36 -0500 Subject: [PATCH] Try to discover CUDA binary path, to add to Py3.8's DLL path (Closes gh-213) --- pycuda/driver.py | 64 ++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 59 insertions(+), 5 deletions(-) diff --git a/pycuda/driver.py b/pycuda/driver.py index 3fa4c5b2..fd042a75 100644 --- a/pycuda/driver.py +++ b/pycuda/driver.py @@ -1,6 +1,63 @@ -from __future__ import absolute_import -from __future__ import print_function +from __future__ import absolute_import, print_function + +import os +import sys + import six + +import numpy as np + + +# {{{ add cuda lib dir to Python DLL path + +def _search_on_path(filenames): + """Find file on system path.""" + # http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/52224 + + from os.path import exists, abspath, join + from os import pathsep, environ + + search_path = environ["PATH"] + + paths = search_path.split(pathsep) + for path in paths: + for filename in filenames: + if exists(join(path, filename)): + return abspath(join(path, filename)) + + +def _add_cuda_libdir_to_dll_path(): + from os.path import join, dirname + + cuda_path = os.environ.get("CUDA_PATH") + + if cuda_path is not None: + os.add_dll_directory(join(cuda_path, 'bin')) + return + + nvcc_path = _search_on_path(["nvcc.exe"]) + if nvcc_path is not None: + os.add_dll_directory(dirname(nvcc_path)) + + from warnings import warn + warn("Unable to discover CUDA installation directory " + "while attempting to add it to Python's DLL path. " + "Either set the 'CUDA_PATH' environment variable " + "or ensure that 'nvcc.exe' is on the path.") + + +try: + os.add_dll_directory +except AttributeError: + # likely not on Py3.8 and Windows + # https://github.com/inducer/pycuda/issues/213 + pass +else: + _add_cuda_libdir_to_dll_path() + +# }}} + + try: from pycuda._driver import * # noqa except ImportError as e: @@ -11,9 +68,6 @@ except ImportError as e: "does not match the version of your CUDA driver.") raise -import numpy as np -import sys - if sys.version_info >= (3,): _memoryview = memoryview -- GitLab