Skip to content
test_loopy.py 93.5 KiB
Newer Older
         [lp.ConstantArg('T', shape=(200,), dtype=np.float32),
         '...'])

    knl = lp.fix_parameters(knl, n=200)

Andreas Klöckner's avatar
Andreas Klöckner committed
    lp.auto_test_vs_ref(knl, ctx, knl)
@pytest.mark.parametrize("src_order", ["C"])
@pytest.mark.parametrize("tmp_order", ["C", "F"])
def test_temp_initializer(ctx_factory, src_order, tmp_order):
    """Copy a private temporary with an initializer to the kernel output.

    The initializer is a random 3x3 array laid out in *src_order*; the
    temporary is declared with *tmp_order*.  The kernel output must equal
    the initializer.
    """
    init_data = np.random.randn(3, 3).copy(order=src_order)

    queue = cl.CommandQueue(ctx_factory())

    tmp_var = lp.TemporaryVariable(
            "tmp",
            initializer=init_data,
            shape=lp.auto,
            scope=lp.temp_var_scope.PRIVATE,
            order=tmp_order)

    kernel = lp.make_kernel(
            "{[i,j]: 0<=i,j<n}",
            "out[i,j] = tmp[i,j]",
            [tmp_var, "..."])

    kernel = lp.set_options(kernel, write_cl=True, highlight_cl=True)
    # The loop extent is taken from the initializer itself.
    kernel = lp.fix_parameters(kernel, n=init_data.shape[0])

    _, (result,) = kernel(queue, out_host=True)

    assert np.array_equal(init_data, result)


def test_const_temp_with_initializer_not_saved():
    """A read-only temporary with an initializer must not get a save slot."""
    const_tmp = lp.TemporaryVariable(
        "tmp",
        initializer=np.arange(10),
        shape=lp.auto,
        scope=lp.temp_var_scope.PRIVATE,
        read_only=True)

    kernel = lp.make_kernel(
        "{[i]: 0<=i<10}",
        """
        ... gbarrier
        out[i] = tmp[i]
        """,
        [const_tmp, "..."],
        seq_dependencies=True)

    kernel = lp.save_and_reload_temporaries(
        lp.get_one_scheduled_kernel(
            lp.preprocess_kernel(kernel)))

    # Only "tmp" itself may be present; a second entry would be a save slot.
    assert len(kernel.temporary_variables) == 1


def test_header_extract():
    """Check the kernel header generated for the C, CUDA and OpenCL targets.

    The same kernel is retargeted three times and the first entry returned
    by ``lp.generate_header`` is compared with the expected declaration.
    """
    # NOTE(review): the domain string ends in '}}' -- looks like a typo,
    # left untouched here; confirm before changing.
    knl = lp.make_kernel('{[k]: 0<=k<n}}',
         """
         for k
             T[k] = k**2
         end
         """,
         [lp.GlobalArg('T', shape=(200,), dtype=np.float32),
         '...'])

    knl = lp.fix_parameters(knl, n=200)

    # test C
    cknl = knl.copy(target=lp.CTarget())
    assert str(lp.generate_header(cknl)[0]) == (
            'void loopy_kernel(float *__restrict__ T);')

    # test CUDA
    cuknl = knl.copy(target=lp.CudaTarget())
    assert str(lp.generate_header(cuknl)[0]) == (
            'extern "C" __global__ void __launch_bounds__(1) '
            'loopy_kernel(float *__restrict__ T);')

    # test OpenCL
    oclknl = knl.copy(target=lp.PyOpenCLTarget())
    assert str(lp.generate_header(oclknl)[0]) == (
            '__kernel void __attribute__ ((reqd_work_group_size(1, 1, 1))) '
            'loopy_kernel(__global float *__restrict__ T);')

def test_scalars_with_base_storage(ctx_factory):
    """Regression test for !50: run a kernel whose scalar temporary
    is declared with ``base_storage``."""
    queue = cl.CommandQueue(ctx_factory())

    scalar_tmp = lp.TemporaryVariable(
            "a", dtype=np.float64, shape=(), base_storage="base")

    kernel = lp.make_kernel("{ [i]: 0<=i<1}", "a = 1", [scalar_tmp])

    # Only successful execution is checked; the result is discarded.
    kernel(queue, out_host=True)


def test_if_else(ctx_factory):
    """Run kernels using flat and nested ``if``/``elif``/``else`` blocks.

    Three kernels are executed: a flat conditional chain, a chain nested
    inside another conditional, and conditionals combined with a nested
    loop.  Each result is compared against a reference array built with
    numpy.
    """
    ctx = ctx_factory()
    queue = cl.CommandQueue(ctx)

    # Flat if / elif / else chain.
    knl = lp.make_kernel(
            "{ [i]: 0<=i<50}",
            """
            if i % 3 == 0
                a[i] = 15  {nosync_query=writes:a}
            elif i % 3 == 1
                a[i] = 11  {nosync_query=writes:a}
            else
                a[i] = 3  {nosync_query=writes:a}
            end
            """
            )

    evt, (out,) = knl(queue, out_host=True)

    out_ref = np.empty(50)
    out_ref[::3] = 15
    out_ref[1::3] = 11
    out_ref[2::3] = 3

    assert np.array_equal(out_ref, out)

    # Conditional chain nested in the "if" branch of an outer conditional.
    # The else/end lines follow from the reference values below: odd i -> 4,
    # even i -> 15 / 11 / 3 depending on i % 3.
    knl = lp.make_kernel(
            "{ [i]: 0<=i<50}",
            """
            for i
                if i % 2 == 0
                    if i % 3 == 0
                        a[i] = 15  {nosync_query=writes:a}
                    elif i % 3 == 1
                        a[i] = 11  {nosync_query=writes:a}
                    else
                        a[i] = 3  {nosync_query=writes:a}
                    end
                else
                    a[i] = 4  {nosync_query=writes:a}
                end
            end
            """
            )

    evt, (out,) = knl(queue, out_host=True)

    out_ref = np.zeros(50)
    out_ref[1::2] = 4
    out_ref[0::6] = 15
    out_ref[4::6] = 11
    out_ref[2::6] = 3

    # This reference array was previously built but never compared.
    assert np.array_equal(out_ref, out)

    # Conditionals outside and inside a nested loop.  The two halves of the
    # i range write complementary checkerboard patterns in j, matching the
    # reference array below.
    knl = lp.make_kernel(
            "{ [i,j]: 0<=i,j<50}",
            """
            for i
                if i < 25
                    for j
                        if j % 2 == 0
                            a[i, j] = 1  {nosync_query=writes:a}
                        else
                            a[i, j] = 0  {nosync_query=writes:a}
                        end
                    end
                else
                    for j
                        if j % 2 == 0
                            a[i, j] = 0  {nosync_query=writes:a}
                        else
                            a[i, j] = 1  {nosync_query=writes:a}
                        end
                    end
                end
            end
            """
            )

    evt, (out,) = knl(queue, out_host=True)

    out_ref = np.zeros((50, 50))
    out_ref[:25, 0::2] = 1
    out_ref[25:, 1::2] = 1

    assert np.array_equal(out_ref, out)


def test_tight_loop_bounds(ctx_factory):
    """Execute a kernel whose inner loop bounds depend on a split outer iname.

    Every ``out[j]`` must end up holding ``j`` for ``j`` in ``0..9``.
    """
    queue = cl.CommandQueue(ctx_factory())

    domains = [
        "{ [i] : 0 <= i <= 5 }",
        "[i] -> { [j] : 2 * i - 2 < j <= 2 * i and 0 <= j <= 9 }",
        ]

    kernel = lp.make_kernel(
        domains,
        """
        for i
          for j
            out[j] = j
          end
        end
        """,
        silenced_warnings="write_race(insn)")

    kernel = lp.split_iname(
        kernel, "i", 5, inner_tag="l.0", outer_tag="g.0")

    _, (result,) = kernel(queue, out_host=True)

    expected = np.arange(10)
    assert (result == expected).all()


def test_tight_loop_bounds_codegen():
    """Check the exact ``for`` loop header generated for tight loop bounds.

    The kernel is built for the OpenCL target, the outer iname is split
    across group and local axes, and the generated device code must contain
    the expected ``for (int j = ...)`` header.
    """
    knl = lp.make_kernel(
        ["{ [i] : 0 <= i <= 5 }",
         "[i] -> { [j] : 2 * i - 2 <= j <= 2 * i and 0 <= j <= 9 }"],
        """
        for i
          for j
            out[j] = j
          end
        end
        """,
        silenced_warnings="write_race(insn)",
        target=lp.OpenCLTarget())

    knl = lp.split_iname(knl, "i", 5, inner_tag="l.0", outer_tag="g.0")

    cgr = lp.generate_code_v2(knl)
    # print(cgr.device_code())

    # Expected loop header; the assignment to ``for_loop`` had been lost,
    # leaving dangling string lines and an undefined name in the assert.
    for_loop = (
        "for (int j = "
        "(gid(0) == 0 && lid(0) == 0 ? 0 : -2 + 2 * lid(0) + 10 * gid(0)); "
        "j <= (-1 + gid(0) == 0 && lid(0) == 0 ? 9 : 2 * lid(0)); ++j)")

    assert for_loop in cgr.device_code()
def test_unscheduled_insn_detection():
    """Code generation must reject an instruction added after scheduling."""
    from loopy.diagnostic import UnscheduledInstructionError

    kernel = lp.make_kernel(
        "{ [i]: 0 <= i < 10 }",
        """
        out[i] = i {id=insn1}
        """,
        "...")

    kernel = lp.preprocess_kernel(kernel)
    kernel = lp.get_one_scheduled_kernel(kernel)

    # Append a copy of insn1 under a new id *after* the schedule was made,
    # so the schedule does not mention it.
    (scheduled_insn,) = lp.find_instructions(kernel, "id:insn1")
    kernel.instructions.append(scheduled_insn.copy(id="insn2"))

    with pytest.raises(UnscheduledInstructionError):
        lp.generate_code(kernel)


def test_integer_reduction(ctx_factory):
    """Run max/min/sum/product/argmax/argmin reductions over integer data.

    For ``int32`` and ``int64`` a random array is stored in a read-only
    private temporary; each reduction is run as its own kernel and the
    output is checked against the corresponding numpy result.
    """
    from collections import namedtuple

    from loopy.kernel.data import temp_var_scope as scopes
    from loopy.types import to_loopy_type

    ctx = ctx_factory()
    queue = cl.CommandQueue(ctx)

    ReductionTest = namedtuple('ReductionTest', 'kind, check, args')

    n = 200
    for vtype in [np.int32, np.int64]:
        var_int = np.random.randint(1000, size=n).astype(vtype)
        var_lp = lp.TemporaryVariable('var', initializer=var_int,
                                   read_only=True,
                                   scope=scopes.PRIVATE,
                                   dtype=to_loopy_type(vtype),
                                   shape=lp.auto)

        # Each check receives the kernel output as ``x``.  The arg-reduction
        # checks previously read the enclosing ``out`` variable instead of
        # their own argument; they now use ``x`` throughout.
        reductions = [
            ReductionTest('max', lambda x: x == np.max(var_int), args='var[k]'),
            ReductionTest('min', lambda x: x == np.min(var_int), args='var[k]'),
            ReductionTest('sum', lambda x: x == np.sum(var_int), args='var[k]'),
            ReductionTest('product', lambda x: x == np.prod(var_int), args='var[k]'),
            ReductionTest('argmax',
                lambda x: (
                    x[0] == np.max(var_int) and var_int[x[1]] == np.max(var_int)),
                args='var[k], k'),
            ReductionTest('argmin',
                lambda x: (
                    x[0] == np.min(var_int) and var_int[x[1]] == np.min(var_int)),
                args='var[k], k')
        ]

        for reduction, function, args in reductions:
            # arg-reductions produce two values (value and index)
            kstr = ("out" if 'arg' not in reduction
                        else "out[0], out[1]")
            kstr += ' = {0}(k, {1})'.format(reduction, args)
            knl = lp.make_kernel('{[k]: 0<=k<n}',
                                kstr,
                                [var_lp, '...'])

            knl = lp.fix_parameters(knl, n=n)

            # Execute the kernel; this call had been lost, leaving ``out``
            # undefined and ``queue`` unused.
            _, (out,) = knl(queue, out_host=True)

            assert function(out)

def test_complicated_argmin_reduction(ctx_factory):
    """Auto-test a kernel that feeds loop-computed temporaries into argmin."""
    cl_ctx = ctx_factory()

    kernel = lp.make_kernel(
            "{[ictr,itgt,idim]: "
            "0<=itgt<ntargets "
            "and 0<=ictr<ncenters "
            "and 0<=idim<ambient_dim}",

            """
            for itgt
                for ictr
                    <> dist_sq = sum(idim,
                            (tgt[idim,itgt] - center[idim,ictr])**2)
                    <> in_disk = dist_sq < (radius[ictr]*1.05)**2
                    <> matches = (
                            (in_disk
                                and qbx_forced_limit == 0)
                            or (in_disk
                                    and qbx_forced_limit != 0
                                    and qbx_forced_limit * center_side[ictr] > 0)
                            )

                    <> post_dist_sq = if(matches, dist_sq, HUGE)
                end
                <> min_dist_sq, <> min_ictr = argmin(ictr, ictr, post_dist_sq)

                tgt_to_qbx_center[itgt] = if(min_dist_sq < HUGE, min_ictr, -1)
            end
            """)

    kernel = lp.fix_parameters(kernel, ambient_dim=2)

    arg_dtypes = {
            "tgt,center,radius,HUGE": np.float32,
            "center_side,qbx_forced_limit": np.int32,
            }
    kernel = lp.add_and_infer_dtypes(kernel, arg_dtypes)

    test_parameters = {
            "HUGE": 1e20, "ncenters": 200, "ntargets": 300,
            "qbx_forced_limit": 1}

    # The kernel serves as its own reference.
    lp.auto_test_vs_ref(kernel, cl_ctx, kernel, parameters=test_parameters)


def test_nosync_option_parsing():
    """Check how ``nosync`` / ``nosync_query`` options show up in str(kernel)."""
    kernel = lp.make_kernel(
        "{[i]: 0 <= i < 10}",
        """
        <>t = 1 {id=insn1,nosync=insn1}
        t = 2   {id=insn2,nosync=insn1:insn2}
        t = 3   {id=insn3,nosync=insn1@local:insn2@global:insn3@any}
        t = 4   {id=insn4,nosync_query=id:insn*@local}
        t = 5   {id=insn5,nosync_query=id:insn1}
        """,
        options=lp.Options(allow_terminal_colors=False))

    rendered = str(kernel)
    print(rendered)

    expected_fragments = [
        "id=insn1, no_sync_with=insn1@any",
        "id=insn2, no_sync_with=insn1@any:insn2@any",
        "id=insn3, no_sync_with=insn1@local:insn2@global:insn3@any",
        "id=insn4, no_sync_with=insn1@local:insn2@local:insn3@local:insn5@local",  # noqa
        "id=insn5, no_sync_with=insn1@any",
        ]

    for fragment in expected_fragments:
        assert fragment in rendered
def barrier_between(knl, id1, id2, ignore_barriers_in_levels=()):
    """Return whether a barrier is scheduled between two instructions.

    Walks ``knl.schedule`` in order.  Once the instruction *id1* has been
    seen, any barrier encountered counts, unless the current loop nesting
    depth is listed in *ignore_barriers_in_levels*.  The answer is produced
    when the instruction *id2* is reached.

    :arg knl: a scheduled kernel (must provide ``schedule``).
    :arg id1: id of the instruction after which barriers are watched for.
    :arg id2: id of the instruction at which the search stops.
    :arg ignore_barriers_in_levels: loop nesting depths (0 = outside all
        loops) whose barriers are not counted.
    :returns: *True* if *id1* was seen before *id2* and a counted barrier
        lies between them.
    :raises RuntimeError: if an unknown schedule item type is encountered
        or *id2* never occurs in the schedule.
    """
    from loopy.schedule import (RunInstruction, Barrier, EnterLoop, LeaveLoop,
            CallKernel, ReturnFromKernel)
    watch_for_barrier = False
    seen_barrier = False
    # Current loop nesting depth; it was previously used without ever being
    # initialised.
    loop_level = 0

    for sched_item in knl.schedule:
        if isinstance(sched_item, RunInstruction):
            if sched_item.insn_id == id1:
                watch_for_barrier = True
            elif sched_item.insn_id == id2:
                return watch_for_barrier and seen_barrier
        elif isinstance(sched_item, Barrier):
            if watch_for_barrier and loop_level not in ignore_barriers_in_levels:
                # The body of this branch was missing, so no barrier could
                # ever be recorded.
                seen_barrier = True
        elif isinstance(sched_item, EnterLoop):
            loop_level += 1
        elif isinstance(sched_item, LeaveLoop):
            loop_level -= 1
        elif isinstance(sched_item, (CallKernel, ReturnFromKernel)):
            pass
        else:
            raise RuntimeError("schedule item type '%s' not understood"
                    % type(sched_item).__name__)

    raise RuntimeError("id2 was not seen")


def test_barrier_insertion_near_top_of_loop():
    """Barriers must separate the local-memory accesses inside the j loop."""
    kernel = lp.make_kernel(
        "{[i,j]: 0 <= i,j < 10 }",
        """
        for i
         <>a[i] = i  {id=ainit}
         for j
          <>t = a[(i + 1) % 10]  {id=tcomp}
          <>b[i,j] = a[i] + t   {id=bcomp1}
          b[i,j] = b[i,j] + 1  {id=bcomp2}
         end
        end
        """,
        seq_dependencies=True)

    kernel = lp.tag_inames(kernel, dict(i="l.0"))
    for temp_name in ("a", "b"):
        kernel = lp.set_temporary_scope(kernel, temp_name, "local")

    kernel = lp.preprocess_kernel(kernel)
    kernel = lp.get_one_scheduled_kernel(kernel)

    print(kernel)

    # Each consecutive pair of instructions must have a barrier in between.
    for earlier, later in [
            ("ainit", "tcomp"),
            ("tcomp", "bcomp1"),
            ("bcomp1", "bcomp2")]:
        assert barrier_between(kernel, earlier, later)
def test_barrier_insertion_near_bottom_of_loop():
    """Check barrier placement around and after the inner j loop."""
    domains = [
        "{[i]: 0 <= i < 10 }",
        "[jmax] -> {[j]: 0 <= j < jmax}",
        ]

    # NOTE(review): "t" is read below but never assigned in this kernel
    # text -- confirm against the upstream version of this test.
    kernel = lp.make_kernel(
        domains,
        """
        for i
         <>a[i] = i  {id=ainit}
         for j
          <>b[i,j] = a[i] + t   {id=bcomp1}
          b[i,j] = b[i,j] + 1  {id=bcomp2}
         end
         a[i] = i + 1 {id=aupdate}
        end
        """,
        seq_dependencies=True)

    kernel = lp.tag_inames(kernel, dict(i="l.0"))
    for temp_name in ("a", "b"):
        kernel = lp.set_temporary_scope(kernel, temp_name, "local")

    kernel = lp.preprocess_kernel(kernel)
    kernel = lp.get_one_scheduled_kernel(kernel)

    print(kernel)

    assert barrier_between(kernel, "bcomp1", "bcomp2")
    # Barriers at loop depth 1 are not counted for this check.
    assert barrier_between(
        kernel, "ainit", "aupdate", ignore_barriers_in_levels=[1])
def test_barrier_in_overridden_get_grid_size_expanded_kernel():
    """Generate code for a barrier'd kernel with an overridden grid size.

    The iname is split with a vector size larger than its domain, the
    kernel's grid-size query is replaced via
    ``overridden_get_grid_sizes_for_insn_ids``, and code generation must
    still succeed.
    """
    from loopy.kernel.data import temp_var_scope as scopes

    # make simple barrier'd kernel
    knl = lp.make_kernel('{[i]: 0 <= i < 10}',
                   """
              for i
                    a[i] = i {id=a}
                    ... lbarrier {id=barrier}
                    b[i + 1] = a[i] {nosync=a}
              end
                   """,
                   [lp.TemporaryVariable("a", np.float32, shape=(10,), order='C',
                                         scope=scopes.LOCAL),
                    lp.GlobalArg("b", np.float32, shape=(11,), order='C')],
               seq_dependencies=True)

    # split into kernel w/ vecsize larger than iname domain
    vecsize = 16
    knl = lp.split_iname(knl, 'i', vecsize, inner_tag='l.0')

    # NOTE(review): GridOverride is not defined anywhere in the visible part
    # of this file; presumably it lives in the local ``testlib`` helper
    # module that other tests here import from -- confirm.
    from testlib import GridOverride

    # artificially expand via overridden_get_grid_sizes_for_insn_ids
    knl = knl.copy(overridden_get_grid_sizes_for_insn_ids=GridOverride(
        knl.copy(), vecsize))
    # make sure we can generate the code
    lp.generate_code_v2(knl)


def test_multi_argument_reduction_type_inference():
    """Infer the result types of a nested two-argument segmented-sum reduction.

    A ``SegmentedSumReductionOperation`` reduction over ``j`` of the tuple
    ``(1, 2)`` is wrapped in a reduction over ``i``; type inference must
    report a single ``(int32, int32)`` result tuple.
    """
    from loopy.type_inference import TypeInferenceMapper
    from loopy.library.reduction import SegmentedSumReductionOperation
    from loopy.types import to_loopy_type
    op = SegmentedSumReductionOperation()

    knl = lp.make_kernel("{[i,j]: 0<=i<10 and 0<=j<i}", "")

    int32 = to_loopy_type(np.int32)

    expr = lp.symbolic.Reduction(
            operation=op,
            inames=("i",),
            expr=lp.symbolic.Reduction(
                operation=op,
                inames="j",
                expr=(1, 2),
                allow_simultaneous=True),
            allow_simultaneous=True)

    t_inf_mapper = TypeInferenceMapper(knl)

    assert (
            t_inf_mapper(expr, return_tuple=True, return_dtype_set=True)
            == [(int32, int32)])
def test_multi_argument_reduction_parsing():
    """The inner expression of a nested ``reduce`` must parse as a Reduction."""
    from loopy.symbolic import parse, Reduction

    outer = parse("reduce(argmax, i, reduce(argmax, j, i, j))")

    assert isinstance(outer.expr, Reduction)


def test_global_barrier_order_finding():
    """Check the global barrier order and each instruction's latest barrier."""
    kernel = lp.make_kernel(
            "{[i,itrip]: 0<=i<n and 0<=itrip<ntrips}",
            """
            for i
                for itrip
                    ... gbarrier {id=top}
                    <> z[i] = z[i+1] + z[i]  {id=wr_z,dep=top}
                    <> v[i] = 11  {id=wr_v,dep=top}
                    ... gbarrier {dep=wr_z:wr_v,id=yoink}
                    z[i] = z[i] - z[i+1] + v[i] {id=iupd, dep=yoink}
                end
                ... nop {id=nop}
                ... gbarrier {dep=iupd,id=postloop}
                z[i] = z[i] - z[i+1] + v[i]  {id=zzzv,dep=postloop}
            end
            """)

    barrier_order = lp.get_global_barrier_order(kernel)
    assert barrier_order == ("top", "yoink", "postloop")

    # (instruction id, id of the most recent global barrier before it)
    expected_latest_barrier = [
            ("nop", None),
            ("top", None),
            ("wr_z", "top"),
            ("wr_v", "top"),
            ("yoink", "top"),
            ("postloop", "yoink"),
            ("zzzv", "postloop"),
            ]

    for insn_id, barrier_id in expected_latest_barrier:
        found = lp.find_most_recent_global_barrier(kernel, insn_id)
        assert found == barrier_id


def test_global_barrier_error_if_unordered():
    """Asking for the barrier order of two unordered gbarriers must fail."""
    from loopy.diagnostic import LoopyError

    # FIXME: Should be illegal to declare this
    unordered_knl = lp.make_kernel("{[i]: 0 <= i < 10}",
            """
            ... gbarrier
            ... gbarrier
            """)

    with pytest.raises(LoopyError):
        lp.get_global_barrier_order(unordered_knl)
def test_struct_assignment(ctx_factory):
    """Assign to the individual fields of a struct-typed global array."""
    ctx = ctx_factory()
    queue = cl.CommandQueue(ctx)

    struct_fields = [
        ("tmin", np.float32),
        ("tmax", np.float32),
        ("bi", np.int32),
        ("hit", np.int32)]

    # Match the numpy struct to a C struct for the first device and register
    # the result under the name used in the C declaration.
    struct_dtype, struct_c_decl = cl.tools.match_dtype_to_c_struct(
            ctx.devices[0], "bbhit", np.dtype(struct_fields))
    struct_dtype = cl.tools.get_or_register_dtype('bbhit', struct_dtype)

    kernel = lp.make_kernel(
        "{ [i]: 0<=i<N }",
        """
        for i
            result[i].hit = i % 2  {nosync_query=writes:result}
            result[i].tmin = i  {nosync_query=writes:result}
            result[i].tmax = i+10  {nosync_query=writes:result}
            result[i].bi = i  {nosync_query=writes:result}
        end
        """,
        [
            lp.GlobalArg("result", shape=("N",), dtype=struct_dtype),
            "..."],
        # The struct declaration is supplied as a preamble.
        preambles=[("000", struct_c_decl)])

    kernel = lp.set_options(kernel, write_cl=True)
    kernel(queue, N=200)


def test_inames_conditional_generation(ctx_factory):
    """Run a kernel with global barriers inside two sequential loops.

    The iname ``i`` is tagged ``g.0`` while the domains of ``k`` and ``j``
    are bounded in the same domain expression; the test only requires that
    the kernel executes.
    """
    ctx = ctx_factory()
    knl = lp.make_kernel(
            "{[i,j,k]: 0 < k < i and 0 < j < 10 and 0 < i < 10}",
            """
            for k
                ... gbarrier
                <>tmp1 = 0
            end
            for j
                ... gbarrier
                <>tmp2 = i
            end
            """,
            "...",
            seq_dependencies=True)

    knl = lp.tag_inames(knl, dict(i="g.0"))

    with cl.CommandQueue(ctx) as queue:
        knl(queue)


def test_kernel_var_name_generator():
    """The kernel's name generator must not hand back the requested names
    ``a_s0`` and ``b`` unchanged for a kernel defining ``a`` and ``b_s0``."""
    kernel = lp.make_kernel(
            "{[i]: 0 <= i <= 10}",
            """
            <>a = 0
            <>b_s0 = 0
            """)

    make_name = kernel.get_var_name_generator()

    first = make_name("a_s0")
    assert first != "a_s0"

    second = make_name("b")
    assert second != "b"


def test_fixed_parameters(ctx_factory):
    """Run a kernel whose domain parameter ``n`` is fixed at creation time."""
    ctx = ctx_factory()
    queue = cl.CommandQueue(ctx)

    knl = lp.make_kernel(
            "[n] -> {[i]: 0 <= i < n}",
            """
            <>tmp[i] = i  {id=init}
            tmp[0] = 0  {dep=init}
            """,
            # NOTE(review): the end of this call was missing, leaving the
            # instruction string unterminated.  The closing arguments are
            # reconstructed from the test name; confirm the value of n
            # against the upstream version of this test.
            fixed_parameters=dict(n=1))

    knl(queue)


def test_parameter_inference():
    """``n`` must be the only parameter inferred from the domain."""
    kernel = lp.make_kernel("{[i]: 0 <= i < n and i mod 2 = 0}", "")

    inferred = kernel.all_params()
    assert inferred == {"n"}


def test_execution_backend_can_cache_dtypes(ctx_factory):
    """Invoke a kernel whose dtypes were added without a target.

    When the kernel is invoked, the execution backend uses it as a cache key
    for the type inference and scheduling cache.  This makes sure that
    dtypes in the kernel can be cached, even though they may not have a
    target.
    """
    queue = cl.CommandQueue(ctx_factory())

    kernel = lp.add_dtypes(
        lp.make_kernel("{[i]: 0 <= i < 10}", "<>tmp[i] = i"),
        dict(tmp=int))

    kernel(queue)


def test_wildcard_dep_matching():
    """Check how wildcard patterns in ``dep=`` resolve to instruction ids.

    Covers ``?``, ``*``, a character set ``[12]`` and a negated set
    ``[!1]``; an instruction never ends up depending on itself.
    """
    knl = lp.make_kernel(
            "{[i]: 0 <= i < 10}",
            """
            <>a = 0 {id=insn1}
            <>b = 0 {id=insn2,dep=insn?}
            <>c = 0 {id=insn3,dep=insn*}
            <>d = 0 {id=insn4,dep=insn[12]}
            <>e = 0 {id=insn5,dep=insn[!1]}
            """,
            "...")

    all_insns = set("insn%d" % i for i in range(1, 6))

    assert knl.id_to_insn["insn1"].depends_on == set()
    assert knl.id_to_insn["insn2"].depends_on == all_insns - set(["insn2"])
    assert knl.id_to_insn["insn3"].depends_on == all_insns - set(["insn3"])
    assert knl.id_to_insn["insn4"].depends_on == set(["insn1", "insn2"])
    assert knl.id_to_insn["insn5"].depends_on == all_insns - set(["insn1", "insn5"])


def test_preamble_with_separate_temporaries(ctx_factory):
    """Run a kernel that calls a mangled function backed by a preamble.

    A helper from ``testlib`` supplies both the function mangler and the
    preamble generator for ``indirect``; the kernel uses the function's
    result to index ``data`` and the output is compared with a numpy
    lookup.
    """
    from loopy.kernel.data import temp_var_scope as scopes

    n = 10
    # for each entry come up with a random number of data points
    num_data = np.asarray(np.random.randint(2, 10, size=n), dtype=np.int32)
    # turn into offsets
    offsets = np.asarray(np.hstack(([0], np.cumsum(num_data))), dtype=np.int32)
    # create lookup data: 0..count-1 for every entry, back to back
    lookup = np.asarray(
        np.concatenate([np.arange(count) for count in num_data]),
        dtype=np.int32)
    # and create data array
    # NOTE(review): sized by the *product* of the counts, which can get very
    # large -- confirm that the sum was not intended.
    # np.prod replaces np.product, which NumPy 2.0 removed.
    data = np.random.rand(np.prod(num_data))

    # make kernel
    kernel = lp.make_kernel('{[i]: 0 <= i < n}',
    """
    for i
        <>ind = indirect(offsets[i], offsets[i + 1], 1)
        out[i] = data[ind]
    end
    """,
    [lp.GlobalArg('out', shape=('n',)),
     lp.TemporaryVariable(
        'offsets', shape=(offsets.size,), initializer=offsets, scope=scopes.GLOBAL,
        read_only=True),
     lp.GlobalArg('data', shape=(data.size,), dtype=np.float64)],
    )

    # fix params, and add manglers / preamble
    from testlib import SeparateTemporariesPreambleTestHelper
    preamble_with_sep_helper = SeparateTemporariesPreambleTestHelper(
            func_name='indirect',
            func_arg_dtypes=(np.int32, np.int32, np.int32),
            func_result_dtypes=(np.int32,),
            arr=lookup
            )

    kernel = lp.fix_parameters(kernel, n=n)
    kernel = lp.register_preamble_generators(
            kernel, [preamble_with_sep_helper.preamble_gen])
    kernel = lp.register_function_manglers(
            kernel, [preamble_with_sep_helper.mangler])

    print(lp.generate_code(kernel)[0])
    # and call (functionality unimportant, more that it compiles)
    ctx = ctx_factory()
    queue = cl.CommandQueue(ctx)
    # check that it actually performs the lookup correctly
    assert np.allclose(kernel(
        queue, data=data.flatten('C'))[1][0], data[offsets[:-1] + 1])


def test_arg_inference_for_predicates():
    """An array used only in an ``if`` predicate must be inferred as an
    argument with the shape implied by the loop domain."""
    kernel = lp.make_kernel("{[i]: 0 <= i < 10}",
            """
            if incr[i]
              a = a + 1
            end
            """)

    inferred_args = kernel.arg_dict

    assert "incr" in inferred_args
    assert inferred_args["incr"].shape == (10,)


def test_relaxed_stride_checks(ctx_factory):
    """Check that loopy is compatible with numpy's relaxed stride rules.

    A Fortran-ordered matrix with a single row is passed to a
    matrix-vector product kernel; the test only requires that the call
    succeeds.
    """
    ctx = ctx_factory()

    knl = lp.make_kernel("{[i,j]: 0 <= i <= n and 0 <= j <= m}",
             """
             a[i] = sum(j, A[i,j] * b[j])
             """)

    with cl.CommandQueue(ctx) as queue:
        mat = np.zeros((1, 10), order="F")
        b = np.zeros(10)

        evt, (a,) = knl(queue, A=mat, b=b)


def test_add_prefetch_works_in_lhs_index():
    """Prefetching ``a1_map`` must also replace its use in a left-hand-side
    index.

    After ``lp.add_prefetch`` no instruction may still mention ``a1_map``
    among the dependencies of its assignees.
    """
    knl = lp.make_kernel(
            "{ [n,k,l,k1,l1,k2,l2]: "
            "start<=n<end and 0<=k,k1,k2<3 and 0<=l,l1,l2<2 }",
            """
            for n
                <> a1_tmp[k,l] = a1[a1_map[n, k],l]
                a1_tmp[k1,l1] = a1_tmp[k1,l1] + 1
                a1_out[a1_map[n,k2], l2] = a1_tmp[k2,l2]
            end
            """,
            [
                lp.GlobalArg("a1,a1_out", None, "ndofs,2"),
                lp.GlobalArg("a1_map", None, "nelements,3"),
                "..."
            ])

    knl = lp.add_prefetch(knl, "a1_map", "k")

    from loopy.symbolic import get_dependencies
    for insn in knl.instructions:
        assert "a1_map" not in get_dependencies(insn.assignees)


def test_explicit_simd_shuffles(ctx_factory):
    """Exercise loads, stores and shuffles on ``vec``-tagged array axes.

    Each case builds a small kernel from a single instruction string, tags
    the split iname and array axes as vector axes, runs it and compares the
    flattened result with the expected array.  The final case expects
    vectorizing an atomic update to raise.
    """
    ctx = ctx_factory()

    def create_and_test(insn, answer=None, atomic=False, additional_check=None,
                        store=False):
        """Build, print, run and check one kernel.

        :arg insn: instruction string for the kernel.
        :arg answer: expected flattened output; if *None*, a shifted
            ``arange`` pattern is used (its placement depends on *store*).
        :arg atomic: passed as ``for_atomic`` to both global arguments.
        :arg additional_check: optional callable taking the kernel; its
            result must be truthy.
        :arg store: selects the default *answer* for store-style tests.
        """
        knl = lp.make_kernel(['{[i]: 0 <= i < 12}', '{[j]: 0 <= j < 1}'],
                             insn,
                             [lp.GlobalArg('a', shape=(1, 14,), dtype=np.int32,
                                           for_atomic=atomic),
                              lp.GlobalArg('b', shape=(1, 14,), dtype=np.int32,
                                           for_atomic=atomic)])

        knl = lp.split_iname(knl, 'i', 4, inner_tag='vec')
        knl = lp.tag_inames(knl, [('j', 'g.0')])
        knl = lp.split_array_axis(knl, ['a', 'b'], 1, 4)
        knl = lp.tag_array_axes(knl, ['a', 'b'], 'N1,N0,vec')

        print(lp.generate_code_v2(knl).device_code())
        queue = cl.CommandQueue(ctx)
        if answer is None:
            answer = np.zeros(16, dtype=np.int32)
            if store:
                answer[2:-2] = np.arange(0, 12, dtype=np.int32)
            else:
                answer[:-4] = np.arange(2, 14, dtype=np.int32)

        a = np.zeros((1, 4, 4), dtype=np.int32)
        b = np.arange(16, dtype=np.int32).reshape((1, 4, 4))
        result = knl(queue, a=a, b=b)[1][0]

        assert np.array_equal(result.flatten('C'), answer)
        if additional_check is not None:
            assert additional_check(knl)

    # test w/ compile time temporary constant
    create_and_test("<>c = 2\n" +
                    "a[j, i] = b[j, i + c]",
                    additional_check=lambda knl: 'vload' in lp.generate_code_v2(
                        knl).device_code())
    create_and_test("a[j, i] = b[j, i + 2]")
    create_and_test("a[j, i] = b[j, i + 2] + a[j, i]")
    create_and_test("a[j, i] = a[j, i] + b[j, i + 2]")

    # test vector stores
    create_and_test("<>c = 2\n" +
                    "a[j, i + c] = b[j, i]",
                    additional_check=lambda knl: 'vstore' in lp.generate_code_v2(
                        knl).device_code(),
                    store=True)
    create_and_test("a[j, i + 2] = b[j, i]", store=True)
    create_and_test("a[j, i + 2] = b[j, i] + a[j, i + 2]", store=True)
    create_and_test("a[j, i + 2] = a[j, i + 2] + b[j, i]", store=True)

    # test small vector shuffle
    shuffled = np.arange(16, dtype=np.int32)[(np.arange(16) + 2) % 4 +
                                              4 * (np.arange(16) // 4)]
    shuffled[12:] = 0
    create_and_test("a[j, i] = b[j, (i + 2) % 4 + 4 * (i // 4)]", shuffled)
    create_and_test("a[j, (i + 2) % 4 + 4 * (i // 4)] = b[j, i]", shuffled)

    # test atomics
    from loopy import LoopyError
    from loopy.codegen import Unvectorizable

    # Build the expected values outside the ``raises`` block so that only
    # the kernel construction/execution is covered by it.
    temp = np.arange(12, dtype=np.int32)
    answer = np.zeros(4, dtype=np.int32)
    for i in range(4):
        answer[i] = np.sum(temp[(i + 2) % 4::4])

    with pytest.raises((LoopyError, Unvectorizable)):
        create_and_test("a[j, (i + 2) % 4] = a[j, (i + 2) % 4] + b[j, i] {atomic}",
                        answer, True)


def test_explicit_simd_temporary_promotion(ctx_factory):
    from loopy.kernel.data import temp_var_scope as scopes

    ctx = ctx_factory()
    queue = cl.CommandQueue(ctx)

    # fun with vector temporaries

    # first, test parsing
    knl = lp.make_kernel(
        '{[i,j]: 0 <= i,j < 12}',
        """
        <> t = 1
        <int32> t1 = 1
        <int32:s> t2 = 1
        <:s> t3 = 1
Nick Curtis's avatar
Nick Curtis committed
        <:v> tv = 1
        <int32> tv1 = 1
        <int32:v> tv2 = 1
        <:v> tv3 = 1
Nick Curtis's avatar
Nick Curtis committed
    def make_kernel(insn, ans=None, preamble=None, extra_inames=None):
        skeleton = """
        %(preamble)s
        for j
            for i
                %(insn)s
                if test
                    a[i, j] = 1
Nick Curtis's avatar
Nick Curtis committed
        end
        """
        inames = ['i, j']
        if extra_inames is not None:
            inames += list(extra_inames)
        knl = lp.make_kernel(
            '{[%(inames)s]: 0 <= %(inames)s < 12}' % {'inames': ', '.join(inames)},
            skeleton % dict(insn=insn, preamble='' if not preamble else preamble),
Nick Curtis's avatar
Nick Curtis committed
            [lp.GlobalArg('a', shape=(12, 12)),
             lp.TemporaryVariable('mask', shape=(12,), initializer=np.array(
                                  np.arange(12) >= 6, dtype=np.int), read_only=True,
                                  scope=scopes.GLOBAL)])

        knl = lp.split_iname(knl, 'j', 4, inner_tag='vec')
        knl = lp.split_array_axis(knl, 'a', 1, 4)
        knl = lp.tag_array_axes(knl, 'a', 'N1,N0,vec')
        knl = lp.preprocess_kernel(knl)
Nick Curtis's avatar
Nick Curtis committed
        if ans is not None:
            assert np.array_equal(knl(queue, a=np.zeros((12, 3, 4), dtype=np.int32))[
                1][0], ans)

        return knl
Nick Curtis's avatar
Nick Curtis committed
    # case 1) -- incorrect promotion of temporaries to vector dtypes
Nick Curtis's avatar
Nick Curtis committed
    make_kernel('<> test = mask[i]', ans)

    # next test the writer heuristic

Nick Curtis's avatar
Nick Curtis committed
    # case 2) assignment from a vector iname
Nick Curtis's avatar
Nick Curtis committed
    knl = make_kernel('<> test = mask[j]')
    assert knl.temporary_variables['test'].shape == (4,)
Nick Curtis's avatar
Nick Curtis committed
    # case 3) recursive dependency
    knl = make_kernel("""
        <> test = mask[j]
        <> test2 = test
        """)
    assert knl.temporary_variables['test2'].shape == (4,)

Nick Curtis's avatar
Nick Curtis committed
    # case 4) test that a conflict in user-specified vector types results in error
Nick Curtis's avatar
Nick Curtis committed

Nick Curtis's avatar
Nick Curtis committed
    # 4a) initial scalar assignment w/ later vector access
Nick Curtis's avatar
Nick Curtis committed
    preamble = """
    for k
        <:s> test = 1
    end
    """

    from loopy import LoopyError
    with pytest.raises(LoopyError):
        make_kernel('test = mask[j]', preamble=preamble, extra_inames='k')

Nick Curtis's avatar
Nick Curtis committed
    # 4b) initial vector assignment w/ later scalar access -- OK

    preamble = """
    for k
        <:v> test = 1
    end
    """

    from loopy import LoopyError
    # treat warning as error to make sure the logic detecting user specified
    # vectorization is good
    import warnings
Nick Curtis's avatar
Nick Curtis committed
    try:
        warnings.filterwarnings(
            'error', r"Instruction '[^\W]+': touched variable that \(for ILP\)")
        make_kernel('test = mask[i]', preamble=preamble, extra_inames='k')
    except Exception:
        raise
    finally:
        warnings.resetwarnings()
Nick Curtis's avatar
Nick Curtis committed

def test_explicit_simd_selects(ctx_factory):
    ctx = ctx_factory()

Nick Curtis's avatar
Nick Curtis committed
    def create_and_test(insn, condition, answer, exception=None, a=None, b=None,
        a = np.zeros((3, 4), dtype=np.int32) if a is None else a
        data = [lp.GlobalArg('a', shape=(12,), dtype=a.dtype)]
        kwargs = dict(a=a)
        if b is not None:
Nick Curtis's avatar
Nick Curtis committed
            data += [lp.GlobalArg('b', shape=(12,), dtype=b.dtype)]
            kwargs['b'] = b
        if c is not None:
            data += [lp.GlobalArg('c', shape=(12,), dtype=b.dtype)]
            kwargs['c'] = c
        names = [d.name for d in data]

        knl = lp.make_kernel(['{[i]: 0 <= i < 12}'],
            """