Newer
Older
from _driver import *
def _add_functionality():
    """Attach pure-Python convenience methods to the driver classes.

    At present this installs a single method, ``Device.get_attributes``.

    NOTE(review): no call to this function is visible in this part of the
    file -- confirm that it is actually invoked at import time.
    """

    def device_get_attributes(dev):
        """Return a dict of every attribute of *dev*.

        Each public name of ``device_attribute`` that starts with an
        upper-case letter is looked up; its value becomes the key and the
        result of ``dev.get_attribute(value)`` becomes the entry.
        """
        attributes = {}
        for name in dir(device_attribute):
            # Names not starting with a capital (dunder/helper members)
            # are not attribute constants and are skipped.
            if not name[0].isupper():
                continue
            key = getattr(device_attribute, name)
            attributes[key] = dev.get_attribute(key)
        return attributes

    Device.get_attributes = device_get_attributes
def pagelocked_zeros(shape, dtype, order="C"):
    """Return a zero-filled array obtained from ``pagelocked_empty``.

    *shape*, *dtype* and *order* are forwarded unchanged to
    ``pagelocked_empty``; the array it returns is filled with ``0`` in
    place and then handed back to the caller.
    """
    zeroed = pagelocked_empty(shape, dtype, order)
    zeroed.fill(0)
    return zeroed
class SourceModule(object):
    """Compile source text with ``nvcc`` and load the resulting module.

    The source is written to ``kernel.cu`` inside a fresh temporary
    directory and compiled with ``nvcc --cubin``. The contents of the
    produced ``kernel.cubin`` are passed to ``module_from_buffer``. The
    loaded module's ``get_function``, ``get_global`` and ``get_texref``
    are re-exported directly on the instance.
    """

    def __init__(self, source, options=[], keep=False):
        """Build and load a module from *source*.

        :param source: text written verbatim to ``kernel.cu``.
        :param options: extra command-line arguments placed between
            ``nvcc --cubin`` and the file name. The default list is only
            read, never mutated.
        :param keep: if true, the temporary build directory is left in
            place and its path is printed; otherwise the directory is
            removed, even when the build fails.
        :raises RuntimeError: if ``nvcc`` exits with a non-zero status.
        :raises OSError: if ``nvcc`` cannot be started or a build file
            cannot be read or written.
        """
        from os.path import join
        from shutil import rmtree
        from subprocess import call
        from tempfile import mkdtemp

        tempdir = mkdtemp()
        if keep:
            # The path is only useful to the caller when the directory
            # outlives this constructor.
            print(tempdir)

        try:
            with open(join(tempdir, "kernel.cu"), "w") as outf:
                outf.write(source)

            # nvcc runs with tempdir as its working directory, so the
            # relative names "kernel.cu"/"kernel.cubin" resolve there.
            result = call(
                ["nvcc", "--cubin"] + list(options) + ["kernel.cu"],
                cwd=tempdir)
            if result != 0:
                raise RuntimeError(
                    "nvcc failed with exit status %d while compiling "
                    "kernel.cu" % result)

            # NOTE(review): the file is read in text mode as in the
            # original code -- confirm module_from_buffer does not need
            # the raw bytes ("rb").
            with open(join(tempdir, "kernel.cubin"), "r") as inf:
                data = inf.read()
        finally:
            # Clean up on success *and* on failure so that a broken build
            # does not leak temporary directories.
            if not keep:
                rmtree(tempdir)

        self.module = module_from_buffer(data)

        # Expose the module's lookup methods directly on this object.
        self.get_function = self.module.get_function
        self.get_global = self.module.get_global
        self.get_texref = self.module.get_texref