Skip to content
Snippets Groups Projects
test.py 1.69 KiB
Newer Older
  • Learn to ignore specific revisions
  • import numpy.linalg as la
    
    import pyopencl as cl
    
    import pyopencl.array  # noqa
    import pyopencl.clrandom  # noqa
    
    import loopy as lp
    import sys
    
    from pyopencl.tools import (  # noqa
            pytest_generate_tests_for_pyopencl
            as pytest_generate_tests)
    
    
    # Single-slot cache for the parsed WENO program (empty until first use).
    _WENO_PRG = []


    def get_weno_program():
        """Parse ``WENO.F90`` into a loopy program, caching the result.

        The parsed program is stored in the module-level ``_WENO_PRG`` list, so
        the Fortran source is read and parsed at most once per process.

        The source is looked up next to this test module first, so the test
        does not depend on the current working directory. If no such file
        exists, the bare name ``"WENO.F90"`` is opened relative to the working
        directory, as before.

        :returns: the program produced by ``lp.parse_transformed_fortran``.
        """
        if _WENO_PRG:
            return _WENO_PRG[0]

        import os

        fn = "WENO.F90"
        # Prefer the copy beside this file so that running pytest (or
        # ``python path/to/test.py``) from another directory still finds it.
        local_fn = os.path.join(os.path.dirname(os.path.abspath(__file__)), fn)
        if os.path.exists(local_fn):
            fn = local_fn

        with open(fn, "r") as infile:
            infile_content = infile.read()

        prg = lp.parse_transformed_fortran(infile_content, filename=fn)
        _WENO_PRG.append(prg)
        return prg
    
    
    
    def with_root_kernel(prg, root_name):
        """Return a copy of *prg* whose host entry point is kernel *root_name*.

        The program is copied and renamed to *root_name*. Then every
        ``lp.LoopKernel`` in it that is marked ``is_called_from_host`` is
        re-marked as not called from the host. Finally, kernel *root_name*
        (taken from the original *prg*) is marked as called from the host.

        :arg prg: a loopy program. It is iterated for callable names, indexed
            by name, and updated via ``copy``/``with_kernel``.
        :arg root_name: name of the kernel to expose as the entry point.
        :returns: the rebuilt program.
        """
        # FIXME This is a little less beautiful than it could be

        new_prg = prg.copy(name=root_name)

        for name in prg:
            clbl = new_prg[name]
            # Callables that are not LoopKernels are left as they are.
            if isinstance(clbl, lp.LoopKernel) and clbl.is_called_from_host:
                new_prg = new_prg.with_kernel(
                    clbl.copy(is_called_from_host=False))

        new_prg = new_prg.with_kernel(
            prg[root_name].copy(is_called_from_host=True))

        # Callers rebind their program to this result, so it must be returned
        # (falling off the end would hand them None).
        return new_prg
    
    def test_matvec(ctx_factory):
        """Check the Fortran ``mult_mat_vec`` routine against a NumPy matvec.

        ``mult_mat_vec`` from ``WENO.F90`` is made the host entry point. It is
        run on a random 10x10 ``float32`` matrix and length-10 vector, and the
        device result is compared with ``a @ b`` by relative 2-norm error.
        """
        # The module-level imports only bind ``numpy.linalg`` (as ``la``), so
        # bring ``np`` into scope here.
        import numpy as np

        ctx = ctx_factory()
        queue = cl.CommandQueue(ctx)

        prg = get_weno_program()
        prg = with_root_kernel(prg, "mult_mat_vec")

        # Column-major copy, matching Fortran's array layout.
        a = np.random.rand(10, 10).astype(np.float32).copy(order="F")
        b = np.random.rand(10).astype(np.float32)

        c_dev = cl.array.empty(queue, 10, dtype=np.float32)

        # Have loopy write out the generated OpenCL code (debugging aid).
        prg = lp.set_options(prg, write_cl=True)

        # NOTE(review): presumably the kernel scales the product by alpha;
        # with alpha=1.0 the reference below is a plain matrix-vector product.
        prg(queue, a=a, b=b, c=c_dev, alpha=1.0)

        c = c_dev.get()

        # Measure relative error against the reference, not the value under
        # test, so a wrong device result cannot shrink its own denominator.
        ref = a @ b
        assert la.norm(ref - c, 2)/la.norm(ref, 2) < 1e-5
    
    
    # Direct execution. With an argument, that argument is executed as Python
    # source in this module's namespace, which lets you run a single test
    # without pytest, for example:
    #
    #     python test.py 'test_matvec(cl._csc)'
    #
    # NOTE: the argument is passed straight to exec(); only use trusted input.
    # With no argument, the whole file is handed to pytest.

    if __name__ == "__main__":
        if len(sys.argv) <= 1:
            import pytest
            pytest.main([__file__])
        else:
            exec(sys.argv[1])