# NOTE: web-viewer page chrome captured together with the file (not source code):
# Skip to content
# test_loopy.py 29.3 KiB
# Newer Older
def test_write_parameter(ctx_factory):
    """Construct a kernel whose instructions assign to its own size parameter.

    ``n`` bounds the ``i``/``j`` domain and is declared as a ``ValueArg``,
    yet the last instruction is ``n = 15``.

    NOTE(review): the kernel is only constructed here -- it is never compiled
    and nothing is asserted.  Confirm whether an expected-failure check is
    missing from this test.
    """
    dtype = np.float32
    ctx = ctx_factory()

    # Arguments are built in the same order in which they are passed below.
    device = ctx.devices[0]
    domains = ["{[i,j]: 0<=i,j<n }"]
    instructions = """
                a = sum((i,j), i*j)
                b = sum(i, sum(j, i*j))
                n = 15
                """
    kernel_args = [
        lp.GlobalArg("a", dtype, shape=()),
        lp.GlobalArg("b", dtype, shape=()),
        lp.ValueArg("n", np.int32, approximately=1000),
    ]

    knl = lp.make_kernel(
        device, domains, instructions, kernel_args,
        assumptions="n>=1")

def test_arg_shape_guessing(ctx_factory):
    """Build a kernel whose array arguments all use ``shape=lp.auto_shape``.

    The global arguments ``a``, ``b`` and ``c`` are declared without explicit
    shapes (``lp.auto_shape``) and ``n`` is a ``ValueArg`` without a dtype.
    The kernel and the highlighted code of its compiled form are printed;
    nothing is asserted, so the test passes as long as no exception is raised.
    """
    ctx = ctx_factory()

    knl = lp.make_kernel(
        ctx.devices[0],
        [
            "{[i,j]: 0<=i,j<n }",
        ],
        """
                a = 1.5 + sum((i,j), i*j)
                b[i, j] = i*j
                c[i+j, j] = b[j,i]
                """,
        [
            lp.GlobalArg("a", shape=lp.auto_shape),
            lp.GlobalArg("b", shape=lp.auto_shape),
            lp.GlobalArg("c", shape=lp.auto_shape),
            lp.ValueArg("n"),
        ],
        assumptions="n>=1")

    # Single-argument print(...) behaves the same on Python 2 and Python 3,
    # whereas the bare print statement is a SyntaxError on Python 3.
    print(knl)
    print(lp.CompiledKernel(ctx, knl).get_highlighted_code())

def test_arg_guessing(ctx_factory):
    """Build a kernel without passing any argument declarations.

    Only the domain, the instructions and the ``n>=1`` assumption are handed
    to ``lp.make_kernel``; no argument list is supplied.  The kernel and the
    highlighted code of its compiled form are printed; nothing is asserted,
    so the test passes as long as no exception is raised.
    """
    ctx = ctx_factory()

    knl = lp.make_kernel(
        ctx.devices[0],
        [
            "{[i,j]: 0<=i,j<n }",
        ],
        """
                a = 1.5 + sum((i,j), i*j)
                b[i, j] = i*j
                c[i+j, j] = b[j,i]
                """,
        assumptions="n>=1")

    # Single-argument print(...) behaves the same on Python 2 and Python 3,
    # whereas the bare print statement is a SyntaxError on Python 3.
    print(knl)
    print(lp.CompiledKernel(ctx, knl).get_highlighted_code())

def test_nonlinear_index(ctx_factory):
    """Build a kernel that writes to ``a`` through the index ``i*i``.

    ``a`` is declared with shape ``"n"`` and the only instruction is
    ``a[i*i] = 17``; the domain also declares ``j``, which the instruction
    does not use.  The kernel and the highlighted code of its compiled form
    are printed.

    NOTE(review): nothing is asserted, so this passes whenever construction
    and compilation do not raise -- confirm that this is the intended
    expectation for a non-linear index.
    """
    ctx = ctx_factory()

    knl = lp.make_kernel(
        ctx.devices[0],
        [
            "{[i,j]: 0<=i,j<n }",
        ],
        """
                a[i*i] = 17
                """,
        [
            lp.GlobalArg("a", shape="n"),
            lp.ValueArg("n"),
        ],
        assumptions="n>=1")

    # Single-argument print(...) behaves the same on Python 2 and Python 3,
    # whereas the bare print statement is a SyntaxError on Python 3.
    print(knl)
    print(lp.CompiledKernel(ctx, knl).get_highlighted_code())



if __name__ == "__main__":
    import sys

    if len(sys.argv) <= 1:
        # No argument given: hand this file to the py.test command-line runner.
        from py.test.cmdline import main
        main([__file__])
    else:
        # NOTE(review): executes the first command-line argument verbatim as
        # Python code in this module's namespace -- only pass trusted input.
        exec(sys.argv[1])