# [web-view residue] Skip to content
# [web-view residue] driver.py 1.4 KiB
# [web-view residue] Newer Older
from _driver import *




def _add_functionality():
    """Attach extra Python-level methods to the driver classes.

    Currently this installs ``Device.get_attributes``.
    """

    def device_get_attributes(dev):
        """Return a dict of every attribute value reported by *dev*.

        Each name in ``dir(device_attribute)`` that starts with an uppercase
        letter is looked up on ``device_attribute``; the resulting object is
        used both as the dictionary key and as the argument passed to
        ``dev.get_attribute``.
        """
        attributes = {}
        for name in dir(device_attribute):
            # Only the capitalised names are queried; everything else that
            # dir() reports (dunder and lowercase members) is skipped.
            if not name[0].isupper():
                continue
            key = getattr(device_attribute, name)
            attributes[key] = dev.get_attribute(key)
        return attributes

    Device.get_attributes = device_get_attributes



# [web-view residue] Andreas Klöckner's avatar
# [web-view residue] Andreas Klöckner committed
# Run once at import time so that Device.get_attributes is available as soon
# as this module has been loaded.
_add_functionality()




def pagelocked_zeros(shape, dtype, order="C"):
    """Return the array made by ``pagelocked_empty``, filled with zeros.

    *shape*, *dtype* and *order* are forwarded positionally and unchanged to
    ``pagelocked_empty``.
    """
    zeroed = pagelocked_empty(shape, dtype, order)
    # The return value of fill() is ignored: the array is expected to be
    # zeroed in place, and that same object is handed back.
    zeroed.fill(0)
    return zeroed




# [web-view residue] Andreas Klöckner's avatar
# [web-view residue] Andreas Klöckner committed
class SourceModule(object):
    """Compile CUDA source text with ``nvcc`` and load the resulting cubin.

    The source is written to ``kernel.cu`` inside a fresh temporary
    directory, compiled there with ``nvcc --cubin``, and the contents of the
    produced ``kernel.cubin`` are passed to ``module_from_buffer``.  The
    loaded module is stored as ``self.module`` and its ``get_function``,
    ``get_global`` and ``get_texref`` are re-exported on the instance.

    :param source: text written verbatim to ``kernel.cu``.
    :param options: extra command-line arguments placed between
        ``nvcc --cubin`` and the file name.
    :param keep: if true, the temporary directory is left on disk;
        otherwise it is removed, including when an error occurs.
    :raises RuntimeError: if ``nvcc`` exits with a non-zero status.
    """

    def __init__(self, source, options=[], keep=False):
        # The shared default list is harmless here: ``options`` is only
        # copied below, never mutated.
        from os.path import join
        from shutil import rmtree
        from subprocess import call
        from tempfile import mkdtemp

        tempdir = mkdtemp()
        # Announce the working directory; with keep=True this is how the
        # caller can find the generated files afterwards.
        print(tempdir)

        try:
            with open(join(tempdir, "kernel.cu"), "w") as outf:
                outf.write(source)

            # Run the compiler inside the temporary directory so that the
            # relative file name resolves and the output lands next to it.
            result = call(
                ["nvcc", "--cubin"] + list(options) + ["kernel.cu"],
                cwd=tempdir)
            if result != 0:
                # Fail here with a clear message instead of failing later
                # on a missing kernel.cubin.
                raise RuntimeError(
                    "nvcc failed with exit status %d" % result)

            # Binary mode passes the file contents through untouched (no
            # newline translation or text decoding).
            # NOTE(review): assumes module_from_buffer accepts the raw bytes
            # of the cubin file -- confirm against the _driver extension.
            with open(join(tempdir, "kernel.cubin"), "rb") as inf:
                data = inf.read()

            self.module = module_from_buffer(data)
        finally:
            # Clean up on both the success and the failure path so that a
            # failed compile does not leak the temporary directory.
            if not keep:
                rmtree(tempdir)

        self.get_function = self.module.get_function
        self.get_global = self.module.get_global
        self.get_texref = self.module.get_texref