from __future__ import division, absolute_import, print_function __copyright__ = "Copyright (C) 2012 Andreas Kloeckner" __license__ = """ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ import six from six.moves import range import sys import numpy as np import loopy as lp import pyopencl as cl import pyopencl.clmath # noqa import pyopencl.clrandom # noqa import pytest import logging logger = logging.getLogger(__name__) try: import faulthandler except ImportError: pass else: faulthandler.enable() from pyopencl.tools import pytest_generate_tests_for_pyopencl \ as pytest_generate_tests __all__ = [ "pytest_generate_tests", "cl" # 'cl.create_some_context' ] from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 # noqa def test_globals_decl_once_with_multi_subprogram(ctx_factory): ctx = ctx_factory() queue = cl.CommandQueue(ctx) np.random.seed(17) a = np.random.randn(16) cnst = np.random.randn(16) knl = lp.make_kernel( "{[i, ii]: 0<=i, ii id:h and tag:two > id:g and tag:two") print(knl) sr_keys = list(knl.substitutions.keys()) for letter, how_many in [ ("f", 1), ("g", 1), ("h", 2) ]: substs_with_letter = sum(1 for k in sr_keys if k.startswith(letter)) assert substs_with_letter == how_many def test_type_inference_no_artificial_doubles(ctx_factory): ctx = ctx_factory() knl = lp.make_kernel( "{[i]: 0<=i bb = a[i] - b[i] c[i] = bb """, [ lp.GlobalArg("a", np.float32, shape=("n",)), lp.GlobalArg("b", np.float32, shape=("n",)), lp.GlobalArg("c", np.float32, shape=("n",)), lp.ValueArg("n", np.int32), ], assumptions="n>=1") knl = lp.preprocess_kernel(knl, ctx.devices[0]) for k in lp.generate_loop_schedules(knl): code = lp.generate_code(k) assert "double" not in code def test_type_inference_with_type_dependencies(): knl = lp.make_kernel( "{[i]: i=0}", """ <>a = 99 a = a + 1 <>b = 0 <>c = 1 b = b + c + 1.0 c = b + c <>d = b + 2 + 1j """, "...") knl = lp.infer_unknown_types(knl) from loopy.types import to_loopy_type assert knl.temporary_variables["a"].dtype == to_loopy_type(np.int32) assert knl.temporary_variables["b"].dtype == to_loopy_type(np.float32) assert knl.temporary_variables["c"].dtype == to_loopy_type(np.float32) assert knl.temporary_variables["d"].dtype == to_loopy_type(np.complex128) def test_sized_and_complex_literals(ctx_factory): ctx = ctx_factory() knl = lp.make_kernel( "{[i]: 0<=i aa = 5jf <> bb = 5j a[i] = imag(aa) b[i] = imag(bb) c[i] = 5f """, [ lp.GlobalArg("a", np.float32, shape=("n",)), lp.GlobalArg("b", np.float32, shape=("n",)), lp.GlobalArg("c", np.float32, shape=("n",)), lp.ValueArg("n", np.int32), ], assumptions="n>=1") lp.auto_test_vs_ref(knl, ctx, knl, parameters=dict(n=5)) def test_simple_side_effect(ctx_factory): ctx = ctx_factory() knl = lp.make_kernel( "{[i,j]: 0<=i,j<100}", """ a[i] = a[i] + 1 """, [lp.GlobalArg("a", np.float32, shape=(100,))] ) knl = lp.preprocess_kernel(knl, ctx.devices[0]) kernel_gen = lp.generate_loop_schedules(knl) for gen_knl in kernel_gen: print(gen_knl) compiled = lp.CompiledKernel(ctx, gen_knl) print(compiled.get_code()) def test_owed_barriers(ctx_factory): ctx = ctx_factory() knl = lp.make_kernel( "{[i]: 0<=i<100}", [ " z[i] = a[i]" ], [lp.GlobalArg("a", np.float32, shape=(100,))] ) knl = lp.tag_inames(knl, dict(i="l.0")) knl = lp.preprocess_kernel(knl, ctx.devices[0]) kernel_gen = lp.generate_loop_schedules(knl) for gen_knl in kernel_gen: compiled = lp.CompiledKernel(ctx, gen_knl) print(compiled.get_code()) def test_wg_too_small(ctx_factory): ctx = ctx_factory() knl = lp.make_kernel( "{[i]: 0<=i<100}", [ " z[i] = a[i] {id=copy}" ], [lp.GlobalArg("a", np.float32, shape=(100,))], local_sizes={0: 16}) knl = lp.tag_inames(knl, dict(i="l.0")) knl = lp.preprocess_kernel(knl, ctx.devices[0]) kernel_gen = lp.generate_loop_schedules(knl) import pytest for gen_knl in kernel_gen: with pytest.raises(RuntimeError): lp.CompiledKernel(ctx, gen_knl).get_code() def test_multi_cse(ctx_factory): ctx = ctx_factory() knl = lp.make_kernel( "{[i]: 0<=i<100}", [ " z[i] = a[i] + a[i]**2" ], [lp.GlobalArg("a", np.float32, shape=(100,))], local_sizes={0: 16}) knl = lp.split_iname(knl, "i", 16, inner_tag="l.0") knl = lp.add_prefetch(knl, "a", []) knl = lp.preprocess_kernel(knl, ctx.devices[0]) kernel_gen = lp.generate_loop_schedules(knl) for gen_knl in kernel_gen: compiled = lp.CompiledKernel(ctx, gen_knl) print(compiled.get_code()) # {{{ code generator fuzzing def make_random_value(): from random import randrange, uniform v = randrange(3) if v == 0: while True: z = randrange(-1000, 1000) if z: return z elif v == 1: return uniform(-10, 10) else: cval = uniform(-10, 10) + 1j*uniform(-10, 10) if randrange(0, 2) == 0: return np.complex128(cval) else: return np.complex128(cval) def make_random_expression(var_values, size): from random import randrange import pymbolic.primitives as p v = randrange(1500) size[0] += 1 if v < 500 and size[0] < 40: term_count = randrange(2, 5) if randrange(2) < 1: cls = p.Sum else: cls = p.Product return cls(tuple( make_random_expression(var_values, size) for i in range(term_count))) elif v < 750: return make_random_value() elif v < 1000: var_name = "var_%d" % len(var_values) assert var_name not in var_values var_values[var_name] = make_random_value() return p.Variable(var_name) elif v < 1250: # Cannot use '-' because that destroys numpy constants. return p.Sum(( make_random_expression(var_values, size), - make_random_expression(var_values, size))) elif v < 1500: # Cannot use '/' because that destroys numpy constants. return p.Quotient( make_random_expression(var_values, size), make_random_expression(var_values, size)) def generate_random_fuzz_examples(count): for i in range(count): size = [0] var_values = {} expr = make_random_expression(var_values, size) yield expr, var_values def test_fuzz_code_generator(ctx_factory): ctx = ctx_factory() queue = cl.CommandQueue(ctx) if ctx.devices[0].platform.vendor.startswith("Advanced Micro"): pytest.skip("crashes on AMD 15.12") #from expr_fuzz import get_fuzz_examples #for expr, var_values in get_fuzz_examples(): for expr, var_values in generate_random_fuzz_examples(50): from pymbolic import evaluate try: true_value = evaluate(expr, var_values) except ZeroDivisionError: continue def get_dtype(x): if isinstance(x, (complex, np.complexfloating)): return np.complex128 else: return np.float64 knl = lp.make_kernel("{ : }", [lp.Assignment("value", expr)], [lp.GlobalArg("value", np.complex128, shape=())] + [ lp.ValueArg(name, get_dtype(val)) for name, val in six.iteritems(var_values) ]) ck = lp.CompiledKernel(ctx, knl) evt, (lp_value,) = ck(queue, out_host=True, **var_values) err = abs(true_value-lp_value)/abs(true_value) if abs(err) > 1e-10: print(80*"-") print("WRONG: rel error=%g" % err) print("true=%r" % true_value) print("loopy=%r" % lp_value) print(80*"-") print(ck.get_code()) print(80*"-") print(var_values) print(80*"-") print(repr(expr)) print(80*"-") print(expr) print(80*"-") 1/0 # }}} def test_bare_data_dependency(ctx_factory): dtype = np.dtype(np.float32) ctx = ctx_factory() queue = cl.CommandQueue(ctx) knl = lp.make_kernel( [ "[znirp] -> {[i]: 0<=i znirp = n", "a[i] = 1", ], [ lp.GlobalArg("a", dtype, shape=("n"), order="C"), lp.ValueArg("n", np.int32), ]) cknl = lp.CompiledKernel(ctx, knl) n = 20000 evt, (a,) = cknl(queue, n=n, out_host=True) assert a.shape == (n,) assert (a == 1).all() # {{{ test race detection @pytest.mark.skipif("sys.version_info < (2,6)") def test_ilp_write_race_detection_global(ctx_factory): ctx = ctx_factory() knl = lp.make_kernel( "[n] -> {[i,j]: 0<=i,j a[i] = 5+i+j", ], []) knl = lp.tag_inames(knl, dict(i="l.0", j="ilp")) knl = lp.preprocess_kernel(knl, ctx.devices[0]) for k in lp.generate_loop_schedules(knl): assert k.temporary_variables["a"].shape == (16, 17) def test_ilp_write_race_avoidance_private(ctx_factory): ctx = ctx_factory() knl = lp.make_kernel( "{[j]: 0<=j<16 }", [ "<> a = 5+j", ], []) knl = lp.tag_inames(knl, dict(j="ilp")) knl = lp.preprocess_kernel(knl, ctx.devices[0]) for k in lp.generate_loop_schedules(knl): assert k.temporary_variables["a"].shape == (16,) # }}} def test_write_parameter(ctx_factory): dtype = np.float32 ctx = ctx_factory() knl = lp.make_kernel( "{[i,j]: 0<=i,j gid = i/256 start = gid*256 for j a[start + j] = a[start + j] + j end end """, seq_dependencies=True, name="uniform_l", target=PyOpenCLTarget(), assumptions="m<=%d and m>=1 and n mod %d = 0" % (bsize[0], bsize[0])) knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32)) kernel_info = CompiledKernel(ctx, knl).kernel_info(frozenset()) # noqa # }}} def test_nonlinear_index(ctx_factory): ctx = ctx_factory() knl = lp.make_kernel( "{[i,j]: 0<=i,j src_ibox = source_boxes[isrc_box] <> isrc_start = box_source_starts[src_ibox] <> isrc_end = isrc_start+box_source_counts_nonchild[src_ibox] <> strength = strengths[isrc] {id=set_strength} """, [ lp.GlobalArg("box_source_starts,box_source_counts_nonchild", None, shape=None), lp.GlobalArg("strengths", None, shape="nsources"), "..."]) print(knl) assert "isrc_box" in knl.insn_inames("set_strength") print(lp.CompiledKernel(ctx, knl).get_highlighted_code( dict( source_boxes=np.int32, box_source_starts=np.int32, box_source_counts_nonchild=np.int32, strengths=np.float64, nsources=np.int32, ))) def test_inames_deps_from_write_subscript(ctx_factory): knl = lp.make_kernel( "{[i,j]: 0<=i,j src_ibox = source_boxes[i] something = 5 a[src_ibox] = sum(j, something) {id=myred} """, [ lp.GlobalArg("box_source_starts,box_source_counts_nonchild,a", None, shape=None), "..."]) print(knl) assert "i" in knl.insn_inames("myred") def test_modulo_indexing(ctx_factory): ctx = ctx_factory() knl = lp.make_kernel( "{[i,j]: 0<=i my_a = a[i,j] {id=read_a} <> a_less_than_zero = my_a < 0 {dep=read_a,inames=i:j} my_a = 2*my_a {id=twice_a,dep=read_a,if=a_less_than_zero} my_a = my_a+1 {id=aplus,dep=twice_a,if=a_less_than_zero} out[i,j] = 2*my_a {dep=aplus} """, [ lp.GlobalArg("a", np.float32, shape=lp.auto), lp.GlobalArg("out", np.float32, shape=lp.auto), "..." ]) ref_knl = knl lp.auto_test_vs_ref(ref_knl, ctx, knl, parameters=dict( n=200 )) def test_ilp_loop_bound(ctx_factory): # The salient bit of this test is that a joint bound on (outer, inner) # from a split occurs in a setting where the inner loop has been ilp'ed. # In 'normal' parallel loops, the inner index is available for conditionals # throughout. In ILP'd loops, not so much. ctx = ctx_factory() knl = lp.make_kernel( "{ [i,j,k]: 0<=i,j,k temp[i, 0] = 17 temp[i, 1] = 15 """) knl = lp.tag_inames(knl, dict(i="l.0")) knl = lp.preprocess_kernel(knl) for k in lp.generate_loop_schedules(knl): code, _ = lp.generate_code(k) print(code) def test_make_copy_kernel(ctx_factory): ctx = ctx_factory() queue = cl.CommandQueue(ctx) intermediate_format = "f,f,sep" a1 = np.random.randn(1024, 4, 3) cknl1 = lp.make_copy_kernel(intermediate_format) cknl1 = lp.fix_parameters(cknl1, n2=3) cknl1 = lp.set_options(cknl1, write_cl=True) evt, a2 = cknl1(queue, input=a1) cknl2 = lp.make_copy_kernel("c,c,c", intermediate_format) cknl2 = lp.fix_parameters(cknl2, n2=3) evt, a3 = cknl2(queue, input=a2) assert (a1 == a3).all() def test_auto_test_can_detect_problems(ctx_factory): ctx = ctx_factory() ref_knl = lp.make_kernel( "{[i,j]: 0<=i,j upper = 0 {id=init_upper} <> lower = 0 {id=init_lower} temp = 0 {id=init, atomic} for i upper = upper + i * a[i] {id=sum0,dep=init_upper} lower = lower - b[i] {id=sum1,dep=init_lower} end temp = temp + lower {id=temp_sum, dep=sum*:init, atomic,\ nosync=init} ... lbarrier {id=lb2, dep=temp_sum} out[j] = upper / temp {id=final, dep=lb2, atomic,\ nosync=init:temp_sum} end """, [ lp.GlobalArg("out", dtype, shape=lp.auto, for_atomic=True), lp.GlobalArg("a", dtype, shape=lp.auto), lp.GlobalArg("b", dtype, shape=lp.auto), lp.TemporaryVariable('temp', dtype, for_atomic=True, scope=AddressSpace.LOCAL), "..." ], silenced_warnings=["write_race(init)", "write_race(temp_sum)"]) knl = lp.fix_parameters(knl, n=n) knl = lp.split_iname(knl, "j", vec_width, inner_tag="l.0") _, out = knl(queue, a=np.arange(n, dtype=dtype), b=np.arange(n, dtype=dtype)) assert np.allclose(out, np.full_like(out, ((1 - 2 * n) / 3.0))) @pytest.mark.parametrize("dtype", [np.int32, np.int64, np.float32, np.float64]) def test_atomic_init(dtype): vec_width = 4 knl = lp.make_kernel( "{ [i,j]: 0<=i<100 }", """ out[i%4] = 0 {id=init, atomic=init} """, [ lp.GlobalArg("out", dtype, shape=lp.auto, for_atomic=True), "..." ], silenced_warnings=["write_race(init)"]) knl = lp.split_iname(knl, 'i', vec_width, inner_tag='l.0') print(knl) print(lp.generate_code_v2(knl).device_code()) def test_within_inames_and_reduction(): # See https://github.com/inducer/loopy/issues/24 # This is (purposefully) somewhat un-idiomatic, to replicate the conditions # under which the above bug was found. If assignees were phi[i], then the # iname propagation heuristic would not assume that dependent instructions # need to run inside of 'i', and hence the forced_iname_* bits below would not # be needed. i1 = lp.CInstruction("i", "doSomethingToGetPhi();", assignees="phi") from pymbolic.primitives import Subscript, Variable i2 = lp.Assignment("a", lp.Reduction("sum", "j", Subscript(Variable("phi"), Variable("j"))), within_inames=frozenset(), within_inames_is_final=True) k = lp.make_kernel("{[i,j] : 0<=i,jt = i ... gbarrier out[i] = t end """, seq_dependencies=True) if hw_loop: knl = lp.tag_inames(knl, dict(i="g.0")) save_and_reload_temporaries_test(queue, knl, np.arange(8), debug) def test_save_of_private_array(ctx_factory, debug=False): ctx = ctx_factory() queue = cl.CommandQueue(ctx) knl = lp.make_kernel( "{ [i]: 0<=i<8 }", """ for i <>t[i] = i ... gbarrier out[i] = t[i] end """, seq_dependencies=True) knl = lp.set_temporary_scope(knl, "t", "private") save_and_reload_temporaries_test(queue, knl, np.arange(8), debug) def test_save_of_private_array_in_hw_loop(ctx_factory, debug=False): ctx = ctx_factory() queue = cl.CommandQueue(ctx) knl = lp.make_kernel( "{ [i,j,k]: 0<=i,j,k<8 }", """ for i for j <>t[j] = j end ... gbarrier for k out[i,k] = t[k] end end """, seq_dependencies=True) knl = lp.tag_inames(knl, dict(i="g.0")) knl = lp.set_temporary_scope(knl, "t", "private") save_and_reload_temporaries_test( queue, knl, np.vstack((8 * (np.arange(8),))), debug) def test_save_of_private_multidim_array(ctx_factory, debug=False): ctx = ctx_factory() queue = cl.CommandQueue(ctx) knl = lp.make_kernel( "{ [i,j,k,l,m]: 0<=i,j,k,l,m<8 }", """ for i for j, k <>t[j,k] = k end ... gbarrier for l, m out[i,l,m] = t[l,m] end end """, seq_dependencies=True) knl = lp.set_temporary_scope(knl, "t", "private") result = np.array([np.vstack((8 * (np.arange(8),))) for i in range(8)]) save_and_reload_temporaries_test(queue, knl, result, debug) def test_save_of_private_multidim_array_in_hw_loop(ctx_factory, debug=False): ctx = ctx_factory() queue = cl.CommandQueue(ctx) knl = lp.make_kernel( "{ [i,j,k,l,m]: 0<=i,j,k,l,m<8 }", """ for i for j, k <>t[j,k] = k end ... gbarrier for l, m out[i,l,m] = t[l,m] end end """, seq_dependencies=True) knl = lp.set_temporary_scope(knl, "t", "private") knl = lp.tag_inames(knl, dict(i="g.0")) result = np.array([np.vstack((8 * (np.arange(8),))) for i in range(8)]) save_and_reload_temporaries_test(queue, knl, result, debug) @pytest.mark.parametrize("hw_loop", [True, False]) def test_save_of_multiple_private_temporaries(ctx_factory, hw_loop, debug=False): ctx = ctx_factory() queue = cl.CommandQueue(ctx) knl = lp.make_kernel( "{ [i,j,k]: 0<=i,j,k<10 }", """ for i for k <> t_arr[k] = k end <> t_scalar = 1 for j ... gbarrier out[j] = t_scalar ... gbarrier t_scalar = 10 end ... gbarrier <> flag = i == 9 out[i] = t_arr[i] {if=flag} end """, seq_dependencies=True) knl = lp.set_temporary_scope(knl, "t_arr", "private") if hw_loop: knl = lp.tag_inames(knl, dict(i="g.0")) result = np.array([1, 10, 10, 10, 10, 10, 10, 10, 10, 9]) save_and_reload_temporaries_test(queue, knl, result, debug) def test_save_of_local_array(ctx_factory, debug=False): ctx = ctx_factory() queue = cl.CommandQueue(ctx) knl = lp.make_kernel( "{ [i,j]: 0<=i,j<8 }", """ for i, j <>t[2*j] = j t[2*j+1] = j ... gbarrier out[i] = t[2*i] end """, seq_dependencies=True) knl = lp.set_temporary_scope(knl, "t", "local") knl = lp.tag_inames(knl, dict(i="g.0", j="l.0")) save_and_reload_temporaries_test(queue, knl, np.arange(8), debug) def test_save_of_local_array_with_explicit_local_barrier(ctx_factory, debug=False): ctx = ctx_factory() queue = cl.CommandQueue(ctx) knl = lp.make_kernel( "{ [i,j]: 0<=i,j<8 }", """ for i, j <>t[2*j] = j ... lbarrier t[2*j+1] = t[2*j] ... gbarrier out[i] = t[2*i] end """, seq_dependencies=True) knl = lp.set_temporary_scope(knl, "t", "local") knl = lp.tag_inames(knl, dict(i="g.0", j="l.0")) save_and_reload_temporaries_test(queue, knl, np.arange(8), debug) def test_save_local_multidim_array(ctx_factory, debug=False): ctx = ctx_factory() queue = cl.CommandQueue(ctx) knl = lp.make_kernel( "{ [i,j,k]: 0<=i<2 and 0<=k<3 and 0<=j<2}", """ for i, j, k ... gbarrier <> t_local[k,j] = 1 ... gbarrier out[k,i*2+j] = t_local[k,j] end """, seq_dependencies=True) knl = lp.set_temporary_scope(knl, "t_local", "local") knl = lp.tag_inames(knl, dict(j="l.0", i="g.0")) save_and_reload_temporaries_test(queue, knl, 1, debug) def test_save_with_base_storage(ctx_factory, debug=False): ctx = ctx_factory() queue = cl.CommandQueue(ctx) knl = lp.make_kernel( "{[i]: 0 <= i < 10}", """ <>a[i] = 0 <>b[i] = i ... gbarrier out[i] = a[i] """, "...", seq_dependencies=True) knl = lp.tag_inames(knl, dict(i="l.0")) knl = lp.set_temporary_scope(knl, "a", "local") knl = lp.set_temporary_scope(knl, "b", "local") knl = lp.alias_temporaries(knl, ["a", "b"], synchronize_for_exclusive_use=False) save_and_reload_temporaries_test(queue, knl, np.arange(10), debug) def test_save_ambiguous_storage_requirements(): knl = lp.make_kernel( "{[i,j]: 0 <= i < 10 and 0 <= j < 10}", """ <>a[j] = j ... gbarrier out[i,j] = a[j] """, seq_dependencies=True) knl = lp.tag_inames(knl, dict(i="g.0", j="l.0")) knl = lp.duplicate_inames(knl, "j", within="writes:out", tags={"j": "l.0"}) knl = lp.set_temporary_scope(knl, "a", "local") knl = lp.preprocess_kernel(knl) knl = lp.get_one_scheduled_kernel(knl) from loopy.diagnostic import LoopyError with pytest.raises(LoopyError): lp.save_and_reload_temporaries(knl) def test_save_across_inames_with_same_tag(ctx_factory, debug=False): ctx = ctx_factory() queue = cl.CommandQueue(ctx) knl = lp.make_kernel( "{[i]: 0 <= i < 10}", """ <>a[i] = i ... gbarrier out[i] = a[i] """, "...", seq_dependencies=True) knl = lp.tag_inames(knl, dict(i="l.0")) knl = lp.duplicate_inames(knl, "i", within="reads:a", tags={"i": "l.0"}) save_and_reload_temporaries_test(queue, knl, np.arange(10), debug) def test_missing_temporary_definition_detection(): knl = lp.make_kernel( "{ [i]: 0<=i<10 }", """ for i <> t = 1 ... gbarrier out[i] = t end """, seq_dependencies=True) from loopy.diagnostic import MissingDefinitionError with pytest.raises(MissingDefinitionError): lp.generate_code_v2(knl) def test_missing_definition_check_respects_aliases(): # Based on https://github.com/inducer/loopy/issues/69 knl = lp.make_kernel("{ [i] : 0<=i c[i] = a[i + 1] ... gbarrier out[i] = c[i] end """, seq_dependencies=True) knl = lp.add_and_infer_dtypes(knl, {"a": np.float32, "c": np.float32, "out": np.float32, "n": np.int32}) knl = lp.set_temporary_scope(knl, "c", "global") ref_knl = knl knl = lp.split_iname(knl, "i", 128, outer_tag="g.0", inner_tag="l.0") cgr = lp.generate_code_v2(knl) assert len(cgr.device_programs) == 2 #print(cgr.device_code()) #print(cgr.host_code()) lp.auto_test_vs_ref(ref_knl, ctx, knl, parameters=dict(n=5)) def test_assign_to_linear_subscript(ctx_factory): ctx = ctx_factory() queue = cl.CommandQueue(ctx) knl1 = lp.make_kernel( "{ [i]: 0<=i aa = 5jf <> bb = 5j a[i] = imag(aa) b[i] = imag(bb) c[i] = 5f end """, seq_dependencies=True) print(knl.stringify(with_dependencies=True)) lp.auto_test_vs_ref(knl, ctx, knl, parameters=dict(n=5)) def test_nop(ctx_factory): ctx = ctx_factory() knl = lp.make_kernel( "{[i,itrip]: 0<=i z[i] = z[i+1] + z[i] {id=wr_z} <> v[i] = 11 {id=wr_v} ... nop {dep=wr_z:wr_v,id=yoink} z[i] = z[i] - z[i+1] + v[i] {dep=yoink} end """) print(knl) knl = lp.fix_parameters(knl, n=15) knl = lp.add_and_infer_dtypes(knl, {"z": np.float64}) lp.auto_test_vs_ref(knl, ctx, knl, parameters=dict(ntrips=5)) def test_global_barrier(ctx_factory): ctx = ctx_factory() knl = lp.make_kernel( "{[i,itrip]: 0<=i z[i] = z[i+1] + z[i] {id=wr_z,dep=top} <> v[i] = 11 {id=wr_v,dep=top} ... gbarrier {dep=wr_z:wr_v,id=yoink} z[i] = z[i] - z[i+1] + v[i] {id=iupd, dep=wr_z} end ... gbarrier {dep=iupd,id=postloop} z[i] = z[i] - z[i+1] + v[i] {dep=postloop} end """) knl = lp.fix_parameters(knl, ntrips=3) knl = lp.add_and_infer_dtypes(knl, {"z": np.float64}) ref_knl = knl ref_knl = lp.set_temporary_scope(ref_knl, "z", "global") ref_knl = lp.set_temporary_scope(ref_knl, "v", "global") knl = lp.split_iname(knl, "i", 256, outer_tag="g.0", inner_tag="l.0") print(knl) knl = lp.preprocess_kernel(knl) assert knl.temporary_variables["z"].address_space == lp.AddressSpace.GLOBAL assert knl.temporary_variables["v"].address_space == lp.AddressSpace.GLOBAL print(knl) lp.auto_test_vs_ref(ref_knl, ctx, knl, parameters=dict(ntrips=5, n=10)) def test_missing_global_barrier(): knl = lp.make_kernel( "{[i,itrip]: 0<=i z[i] = z[i] - z[i+1] {id=iupd,dep=yoink} end # This is where the barrier should be z[i] = z[i] - z[i+1] + v[i] {dep=iupd} end """) knl = lp.set_temporary_scope(knl, "z", "global") knl = lp.split_iname(knl, "i", 256, outer_tag="g.0") knl = lp.preprocess_kernel(knl) from loopy.diagnostic import MissingBarrierError with pytest.raises(MissingBarrierError): lp.get_one_scheduled_kernel(knl) def test_index_cse(ctx_factory): knl = lp.make_kernel(["{[i,j,k,l,m]:0<=i,j,k,l,m Tcond = T[k] < 0.5 if Tcond cp[k] = 2 * T[k] + Tcond end end """) knl = lp.fix_parameters(knl, n=200) knl = lp.add_and_infer_dtypes(knl, {"T": np.float32}) ref_knl = knl knl = lp.split_iname(knl, 'k', 2, inner_tag='ilp') lp.auto_test_vs_ref(ref_knl, ctx, knl) def test_unr_and_conditionals(ctx_factory): ctx = ctx_factory() knl = lp.make_kernel('{[k]: 0<=k Tcond[k] = T[k] < 0.5 if Tcond[k] cp[k] = 2 * T[k] + Tcond[k] end end """) knl = lp.fix_parameters(knl, n=200) knl = lp.add_and_infer_dtypes(knl, {"T": np.float32}) ref_knl = knl knl = lp.split_iname(knl, 'k', 2, inner_tag='unr') lp.auto_test_vs_ref(ref_knl, ctx, knl) def test_constant_array_args(ctx_factory): ctx = ctx_factory() knl = lp.make_kernel('{[k]: 0<=k Tcond[k] = T[k] < 0.5 if Tcond[k] cp[k] = 2 * T[k] + Tcond[k] end end """, [lp.ConstantArg('T', shape=(200,), dtype=np.float32), '...']) knl = lp.fix_parameters(knl, n=200) lp.auto_test_vs_ref(knl, ctx, knl) @pytest.mark.parametrize("src_order", ["C"]) @pytest.mark.parametrize("tmp_order", ["C", "F"]) def test_temp_initializer(ctx_factory, src_order, tmp_order): a = np.random.randn(3, 3).copy(order=src_order) ctx = ctx_factory() queue = cl.CommandQueue(ctx) knl = lp.make_kernel( "{[i,j]: 0<=i,j { [j] : 2 * i - 2 < j <= 2 * i and 0 <= j <= 9 }"], """ for i for j out[j] = j end end """, silenced_warnings="write_race(insn)") knl = lp.split_iname(knl, "i", 5, inner_tag="l.0", outer_tag="g.0") evt, (out,) = knl(queue, out_host=True) assert (out == np.arange(10)).all() def test_tight_loop_bounds_codegen(): knl = lp.make_kernel( ["{ [i] : 0 <= i <= 5 }", "[i] -> { [j] : 2 * i - 2 <= j <= 2 * i and 0 <= j <= 9 }"], """ for i for j out[j] = j end end """, silenced_warnings="write_race(insn)", target=lp.OpenCLTarget()) knl = lp.split_iname(knl, "i", 5, inner_tag="l.0", outer_tag="g.0") cgr = lp.generate_code_v2(knl) #print(cgr.device_code()) for_loop = \ "for (int j = " \ "(gid(0) == 0 && lid(0) == 0 ? 0 : -2 + 2 * lid(0) + 10 * gid(0)); " \ "j <= (-1 + gid(0) == 0 && lid(0) == 0 ? 9 : 2 * lid(0)); ++j)" assert for_loop in cgr.device_code() def test_unscheduled_insn_detection(): knl = lp.make_kernel( "{ [i]: 0 <= i < 10 }", """ out[i] = i {id=insn1} """, "...") knl = lp.get_one_scheduled_kernel(lp.preprocess_kernel(knl)) insn1, = lp.find_instructions(knl, "id:insn1") knl.instructions.append(insn1.copy(id="insn2")) from loopy.diagnostic import UnscheduledInstructionError with pytest.raises(UnscheduledInstructionError): lp.generate_code(knl) def test_integer_reduction(ctx_factory): ctx = ctx_factory() queue = cl.CommandQueue(ctx) from loopy.types import to_loopy_type n = 200 for vtype in [np.int32, np.int64]: var_int = np.random.randint(1000, size=n).astype(vtype) var_lp = lp.TemporaryVariable('var', initializer=var_int, read_only=True, scope=lp.AddressSpace.PRIVATE, dtype=to_loopy_type(vtype), shape=lp.auto) from collections import namedtuple ReductionTest = namedtuple('ReductionTest', 'kind, check, args') reductions = [ ReductionTest('max', lambda x: x == np.max(var_int), args='var[k]'), ReductionTest('min', lambda x: x == np.min(var_int), args='var[k]'), ReductionTest('sum', lambda x: x == np.sum(var_int), args='var[k]'), ReductionTest('product', lambda x: x == np.prod(var_int), args='var[k]'), ReductionTest('argmax', lambda x: ( x[0] == np.max(var_int) and var_int[out[1]] == np.max(var_int)), args='var[k], k'), ReductionTest('argmin', lambda x: ( x[0] == np.min(var_int) and var_int[out[1]] == np.min(var_int)), args='var[k], k') ] for reduction, function, args in reductions: kstr = ("out" if 'arg' not in reduction else "out[0], out[1]") kstr += ' = {0}(k, {1})'.format(reduction, args) knl = lp.make_kernel('{[k]: 0<=k dist_sq = sum(idim, (tgt[idim,itgt] - center[idim,ictr])**2) <> in_disk = dist_sq < (radius[ictr]*1.05)**2 <> matches = ( (in_disk and qbx_forced_limit == 0) or (in_disk and qbx_forced_limit != 0 and qbx_forced_limit * center_side[ictr] > 0) ) <> post_dist_sq = if(matches, dist_sq, HUGE) end <> min_dist_sq, <> min_ictr = argmin(ictr, ictr, post_dist_sq) tgt_to_qbx_center[itgt] = if(min_dist_sq < HUGE, min_ictr, -1) end """) knl = lp.fix_parameters(knl, ambient_dim=2) knl = lp.add_and_infer_dtypes(knl, { "tgt,center,radius,HUGE": np.float32, "center_side,qbx_forced_limit": np.int32, }) lp.auto_test_vs_ref(knl, cl_ctx, knl, parameters={ "HUGE": 1e20, "ncenters": 200, "ntargets": 300, "qbx_forced_limit": 1}) def test_nosync_option_parsing(): knl = lp.make_kernel( "{[i]: 0 <= i < 10}", """ <>t = 1 {id=insn1,nosync=insn1} t = 2 {id=insn2,nosync=insn1:insn2} t = 3 {id=insn3,nosync=insn1@local:insn2@global:insn3@any} t = 4 {id=insn4,nosync_query=id:insn*@local} t = 5 {id=insn5,nosync_query=id:insn1} """, options=lp.Options(allow_terminal_colors=False)) kernel_str = str(knl) print(kernel_str) assert "id=insn1, no_sync_with=insn1@any" in kernel_str assert "id=insn2, no_sync_with=insn1@any:insn2@any" in kernel_str assert "id=insn3, no_sync_with=insn1@local:insn2@global:insn3@any" in kernel_str assert "id=insn4, no_sync_with=insn1@local:insn2@local:insn3@local:insn5@local" in kernel_str # noqa assert "id=insn5, no_sync_with=insn1@any" in kernel_str def barrier_between(knl, id1, id2, ignore_barriers_in_levels=()): from loopy.schedule import (RunInstruction, Barrier, EnterLoop, LeaveLoop, CallKernel, ReturnFromKernel) watch_for_barrier = False seen_barrier = False loop_level = 0 for sched_item in knl.schedule: if isinstance(sched_item, RunInstruction): if sched_item.insn_id == id1: watch_for_barrier = True elif sched_item.insn_id == id2: return watch_for_barrier and seen_barrier elif isinstance(sched_item, Barrier): if watch_for_barrier and loop_level not in ignore_barriers_in_levels: seen_barrier = True elif isinstance(sched_item, EnterLoop): loop_level += 1 elif isinstance(sched_item, LeaveLoop): loop_level -= 1 elif isinstance(sched_item, (CallKernel, ReturnFromKernel)): pass else: raise RuntimeError("schedule item type '%s' not understood" % type(sched_item).__name__) raise RuntimeError("id2 was not seen") def test_barrier_insertion_near_top_of_loop(): knl = lp.make_kernel( "{[i,j]: 0 <= i,j < 10 }", """ for i <>a[i] = i {id=ainit} for j <>t = a[(i + 1) % 10] {id=tcomp} <>b[i,j] = a[i] + t {id=bcomp1} b[i,j] = b[i,j] + 1 {id=bcomp2} end end """, seq_dependencies=True) knl = lp.tag_inames(knl, dict(i="l.0")) knl = lp.set_temporary_scope(knl, "a", "local") knl = lp.set_temporary_scope(knl, "b", "local") knl = lp.get_one_scheduled_kernel(lp.preprocess_kernel(knl)) print(knl) assert barrier_between(knl, "ainit", "tcomp") assert barrier_between(knl, "tcomp", "bcomp1") assert barrier_between(knl, "bcomp1", "bcomp2") def test_barrier_insertion_near_bottom_of_loop(): knl = lp.make_kernel( ["{[i]: 0 <= i < 10 }", "[jmax] -> {[j]: 0 <= j < jmax}"], """ for i <>a[i] = i {id=ainit} for j <>b[i,j] = a[i] + t {id=bcomp1} b[i,j] = b[i,j] + 1 {id=bcomp2} end a[i] = i + 1 {id=aupdate} end """, seq_dependencies=True) knl = lp.tag_inames(knl, dict(i="l.0")) knl = lp.set_temporary_scope(knl, "a", "local") knl = lp.set_temporary_scope(knl, "b", "local") knl = lp.get_one_scheduled_kernel(lp.preprocess_kernel(knl)) print(knl) assert barrier_between(knl, "bcomp1", "bcomp2") assert barrier_between(knl, "ainit", "aupdate", ignore_barriers_in_levels=[1]) def test_barrier_in_overridden_get_grid_size_expanded_kernel(): # make simple barrier'd kernel knl = lp.make_kernel('{[i]: 0 <= i < 10}', """ for i a[i] = i {id=a} ... lbarrier {id=barrier} b[i + 1] = a[i] {nosync=a} end """, [lp.TemporaryVariable("a", np.float32, shape=(10,), order='C', scope=lp.AddressSpace.LOCAL), lp.GlobalArg("b", np.float32, shape=(11,), order='C')], seq_dependencies=True) # split into kernel w/ vesize larger than iname domain vecsize = 16 knl = lp.split_iname(knl, 'i', vecsize, inner_tag='l.0') from testlib import GridOverride # artifically expand via overridden_get_grid_sizes_for_insn_ids knl = knl.copy(overridden_get_grid_sizes_for_insn_ids=GridOverride( knl.copy(), vecsize)) # make sure we can generate the code lp.generate_code_v2(knl) def test_multi_argument_reduction_type_inference(): from loopy.type_inference import TypeInferenceMapper from loopy.library.reduction import SegmentedSumReductionOperation from loopy.types import to_loopy_type op = SegmentedSumReductionOperation() knl = lp.make_kernel("{[i,j]: 0<=i<10 and 0<=j z[i] = z[i+1] + z[i] {id=wr_z,dep=top} <> v[i] = 11 {id=wr_v,dep=top} ... gbarrier {dep=wr_z:wr_v,id=yoink} z[i] = z[i] - z[i+1] + v[i] {id=iupd, dep=yoink} end ... nop {id=nop} ... gbarrier {dep=iupd,id=postloop} z[i] = z[i] - z[i+1] + v[i] {id=zzzv,dep=postloop} end """) assert lp.get_global_barrier_order(knl) == ("top", "yoink", "postloop") for insn, barrier in ( ("nop", None), ("top", None), ("wr_z", "top"), ("wr_v", "top"), ("yoink", "top"), ("postloop", "yoink"), ("zzzv", "postloop")): assert lp.find_most_recent_global_barrier(knl, insn) == barrier def test_global_barrier_error_if_unordered(): # FIXME: Should be illegal to declare this knl = lp.make_kernel("{[i]: 0 <= i < 10}", """ ... gbarrier ... gbarrier """) from loopy.diagnostic import LoopyError with pytest.raises(LoopyError): lp.get_global_barrier_order(knl) def test_struct_assignment(ctx_factory): ctx = ctx_factory() queue = cl.CommandQueue(ctx) bbhit = np.dtype([ ("tmin", np.float32), ("tmax", np.float32), ("bi", np.int32), ("hit", np.int32)]) bbhit, bbhit_c_decl = cl.tools.match_dtype_to_c_struct( ctx.devices[0], "bbhit", bbhit) bbhit = cl.tools.get_or_register_dtype('bbhit', bbhit) preamble = bbhit_c_decl knl = lp.make_kernel( "{ [i]: 0<=itmp1 = 0 end for j ... gbarrier <>tmp2 = i end """, "...", seq_dependencies=True) knl = lp.tag_inames(knl, dict(i="g.0")) with cl.CommandQueue(ctx) as queue: knl(queue) def test_kernel_var_name_generator(): knl = lp.make_kernel( "{[i]: 0 <= i <= 10}", """ <>a = 0 <>b_s0 = 0 """) vng = knl.get_var_name_generator() assert vng("a_s0") != "a_s0" assert vng("b") != "b" def test_fixed_parameters(ctx_factory): ctx = ctx_factory() queue = cl.CommandQueue(ctx) knl = lp.make_kernel( "[n] -> {[i]: 0 <= i < n}", """ <>tmp[i] = i {id=init} tmp[0] = 0 {dep=init} """, fixed_parameters=dict(n=1)) knl(queue) def test_parameter_inference(): knl = lp.make_kernel("{[i]: 0 <= i < n and i mod 2 = 0}", "") assert knl.all_params() == set(["n"]) def test_execution_backend_can_cache_dtypes(ctx_factory): # When the kernel is invoked, the execution backend uses it as a cache key # for the type inference and scheduling cache. This tests to make sure that # dtypes in the kernel can be cached, even though they may not have a # target. ctx = ctx_factory() queue = cl.CommandQueue(ctx) knl = lp.make_kernel("{[i]: 0 <= i < 10}", "<>tmp[i] = i") knl = lp.add_dtypes(knl, dict(tmp=int)) knl(queue) def test_wildcard_dep_matching(): knl = lp.make_kernel( "{[i]: 0 <= i < 10}", """ <>a = 0 {id=insn1} <>b = 0 {id=insn2,dep=insn?} <>c = 0 {id=insn3,dep=insn*} <>d = 0 {id=insn4,dep=insn[12]} <>e = 0 {id=insn5,dep=insn[!1]} """, "...") all_insns = set("insn%d" % i for i in range(1, 6)) assert knl.id_to_insn["insn1"].depends_on == set() assert knl.id_to_insn["insn2"].depends_on == all_insns - set(["insn2"]) assert knl.id_to_insn["insn3"].depends_on == all_insns - set(["insn3"]) assert knl.id_to_insn["insn4"].depends_on == set(["insn1", "insn2"]) assert knl.id_to_insn["insn5"].depends_on == all_insns - set(["insn1", "insn5"]) def test_preamble_with_separate_temporaries(ctx_factory): # create a function mangler # and finally create a test n = 10 # for each entry come up with a random number of data points num_data = np.asarray(np.random.randint(2, 10, size=n), dtype=np.int32) # turn into offsets offsets = np.asarray(np.hstack(([0], np.cumsum(num_data))), dtype=np.int32) # create lookup data lookup = np.empty(0) for i in num_data: lookup = np.hstack((lookup, np.arange(i))) lookup = np.asarray(lookup, dtype=np.int32) # and create data array data = np.random.rand(np.product(num_data)) # make kernel kernel = lp.make_kernel('{[i]: 0 <= i < n}', """ for i <>ind = indirect(offsets[i], offsets[i + 1], 1) out[i] = data[ind] end """, [lp.GlobalArg('out', shape=('n',)), lp.TemporaryVariable( 'offsets', shape=(offsets.size,), initializer=offsets, scope=lp.AddressSpace.GLOBAL, read_only=True), lp.GlobalArg('data', shape=(data.size,), dtype=np.float64)], ) # fixt params, and add manglers / preamble from testlib import ( SeparateTemporariesPreambleTestMangler, SeparateTemporariesPreambleTestPreambleGenerator, ) func_info = dict( func_name='indirect', func_arg_dtypes=(np.int32, np.int32, np.int32), func_result_dtypes=(np.int32,), arr=lookup ) kernel = lp.fix_parameters(kernel, **{'n': n}) kernel = lp.register_preamble_generators( kernel, [SeparateTemporariesPreambleTestPreambleGenerator(**func_info)]) kernel = lp.register_function_manglers( kernel, [SeparateTemporariesPreambleTestMangler(**func_info)]) print(lp.generate_code(kernel)[0]) # and call (functionality unimportant, more that it compiles) ctx = ctx_factory() queue = cl.CommandQueue(ctx) # check that it actually performs the lookup correctly assert np.allclose(kernel( queue, data=data.flatten('C'))[1][0], data[offsets[:-1] + 1]) def test_arg_inference_for_predicates(): knl = lp.make_kernel("{[i]: 0 <= i < 10}", """ if incr[i] a = a + 1 end """) assert "incr" in knl.arg_dict assert knl.arg_dict["incr"].shape == (10,) def test_relaxed_stride_checks(ctx_factory): # Check that loopy is compatible with numpy's relaxed stride rules. ctx = ctx_factory() knl = lp.make_kernel("{[i,j]: 0 <= i <= n and 0 <= j <= m}", """ a[i] = sum(j, A[i,j] * b[j]) """) with cl.CommandQueue(ctx) as queue: mat = np.zeros((1, 10), order="F") b = np.zeros(10) evt, (a,) = knl(queue, A=mat, b=b) assert a == 0 def test_add_prefetch_works_in_lhs_index(): knl = lp.make_kernel( "{ [n,k,l,k1,l1,k2,l2]: " "start<=n a1_tmp[k,l] = a1[a1_map[n, k],l] a1_tmp[k1,l1] = a1_tmp[k1,l1] + 1 a1_out[a1_map[n,k2], l2] = a1_tmp[k2,l2] end """, [ lp.GlobalArg("a1,a1_out", None, "ndofs,2"), lp.GlobalArg("a1_map", None, "nelements,3"), "..." ]) knl = lp.add_prefetch(knl, "a1_map", "k", default_tag="l.auto") from loopy.symbolic import get_dependencies for insn in knl.instructions: assert "a1_map" not in get_dependencies(insn.assignees) def test_explicit_simd_shuffles(ctx_factory): ctx = ctx_factory() def create_and_test(insn, answer=None, atomic=False, additional_check=None, store=False): knl = lp.make_kernel(['{[i]: 0 <= i < 12}', '{[j]: 0 <= j < 1}'], insn, [lp.GlobalArg('a', shape=(1, 14,), dtype=np.int32, for_atomic=atomic), lp.GlobalArg('b', shape=(1, 14,), dtype=np.int32, for_atomic=atomic)]) knl = lp.split_iname(knl, 'i', 4, inner_tag='vec') knl = lp.tag_inames(knl, [('j', 'g.0')]) knl = lp.split_array_axis(knl, ['a', 'b'], 1, 4) knl = lp.tag_array_axes(knl, ['a', 'b'], 'N1,N0,vec') print(lp.generate_code_v2(knl).device_code()) queue = cl.CommandQueue(ctx) if answer is None: answer = np.zeros(16, dtype=np.int32) if store: answer[2:-2] = np.arange(0, 12, dtype=np.int32) else: answer[:-4] = np.arange(2, 14, dtype=np.int32) a = np.zeros((1, 4, 4), dtype=np.int32) b = np.arange(16, dtype=np.int32).reshape((1, 4, 4)) result = knl(queue, a=a, b=b)[1][0] assert np.array_equal(result.flatten('C'), answer) if additional_check is not None: assert additional_check(knl) # test w/ compile time temporary constant create_and_test("<>c = 2\n" + "a[j, i] = b[j, i + c]", additional_check=lambda knl: 'vload' in lp.generate_code_v2( knl).device_code()) create_and_test("a[j, i] = b[j, i + 2]") create_and_test("a[j, i] = b[j, i + 2] + a[j, i]") create_and_test("a[j, i] = a[j, i] + b[j, i + 2]") # test vector stores create_and_test("<>c = 2\n" + "a[j, i + c] = b[j, i]", additional_check=lambda knl: 'vstore' in lp.generate_code_v2( knl).device_code(), store=True) create_and_test("a[j, i + 2] = b[j, i]", store=True) create_and_test("a[j, i + 2] = b[j, i] + a[j, i + 2]", store=True) create_and_test("a[j, i + 2] = a[j, i + 2] + b[j, i]", store=True) # test small vector shuffle shuffled = np.arange(16, dtype=np.int32)[(np.arange(16) + 2) % 4 + 4 * (np.arange(16) // 4)] shuffled[12:] = 0 create_and_test("a[j, i] = b[j, (i + 2) % 4 + 4 * (i // 4)]", shuffled) create_and_test("a[j, (i + 2) % 4 + 4 * (i // 4)] = b[j, i]", shuffled) # test atomics from loopy import LoopyError from loopy.codegen import Unvectorizable with pytest.raises((LoopyError, Unvectorizable)): temp = np.arange(12, dtype=np.int32) answer = np.zeros(4, dtype=np.int32) for i in range(4): answer[i] = np.sum(temp[(i + 2) % 4::4]) create_and_test("a[j, (i + 2) % 4] = a[j, (i + 2) % 4] + b[j, i] {atomic}", answer, True) def test_explicit_simd_unr_iname(ctx_factory): """ tests as scatter load to a specific lane of a vector array via an unrolled iname """ ctx = ctx_factory() queue = cl.CommandQueue(ctx) insns = """ for j_outer, lane, i a[j_outer, i, lane] = b[j_outer + lane, i] end """ knl = lp.make_kernel( ['{[j_outer]: 0 <= j_outer < 4}', '{[i]: 0 <= i < 4}', '{[lane]: 0 <= lane < 4}'], insns, [lp.GlobalArg('a', shape=(4, 4, 4)), lp.GlobalArg('b', shape=(8, 4))]) knl = lp.tag_array_axes(knl, 'a', 'N1,N0,vec') knl = lp.tag_inames(knl, {'lane': 'unr'}) knl = lp.prioritize_loops(knl, 'j_outer, i, lane') a = np.zeros((4, 4, 4)) b = np.arange(8 * 4).reshape((8, 4)) a = knl(queue, a=a, b=b)[1][0] # create answer ans = np.tile(np.arange(4, dtype=np.float64), 16).reshape((4, 4, 4)) ans *= 4 ans += 4 * np.arange(4)[:, np.newaxis, np.newaxis] + np.arange(4)[:, np.newaxis] assert np.array_equal(a, ans) def test_explicit_simd_temporary_promotion(ctx_factory): from loopy.kernel.data import temp_var_scope as scopes ctx = ctx_factory() queue = cl.CommandQueue(ctx) # fun with vector temporaries def make_kernel(insn, ans=None, preamble=None, extra_inames=None, skeleton=None, dtype=None): skeleton = """ %(preamble)s for j for i %(insn)s if test a[i, j] = 1 end end end """ if skeleton is None else skeleton dtype = dtype if dtype is not None else ( ans.dtype if ans is not None else np.int32) inames = ['i, j'] if extra_inames is not None: inames += list(extra_inames) knl = lp.make_kernel( '{[%(inames)s]: 0 <= %(inames)s < 12}' % {'inames': ', '.join(inames)}, skeleton % dict(insn=insn, preamble='' if not preamble else preamble), [lp.GlobalArg('a', shape=(12, 12), dtype=dtype), lp.TemporaryVariable('mask', shape=(12,), initializer=np.array( np.arange(12) >= 6, dtype=dtype), read_only=True, scope=scopes.GLOBAL)]) knl = lp.split_iname(knl, 'j', 4, inner_tag='vec') knl = lp.split_array_axis(knl, 'a', 1, 4) knl = lp.tag_array_axes(knl, 'a', 'N1,N0,vec') knl = lp.preprocess_kernel(knl) if ans is not None: assert np.array_equal(knl(queue, a=np.zeros((12, 3, 4), dtype=dtype))[ 1][0], ans) return knl ans = np.zeros((12, 3, 4)) ans[6:, :, :] = 1 # case 1) -- incorrect promotion of temporaries to vector dtypes make_kernel('<> test = mask[i]', ans) # next test the writer heuristic # case 2) assignment from a vector iname knl = make_kernel('<> test = mask[j]') assert knl.temporary_variables['test'].shape == (4,) # case 3) recursive dependency knl = make_kernel(""" <> test = mask[j] <> test2 = test """) assert knl.temporary_variables['test2'].shape == (4,) # case 4) # modified case from pyjac -- what makes this case special is that # Kc is never directly assigned to in an instruction that directly references # the vector iname, j_inner. Instead, it is a good test of the recursive # vector temporary promotion, as it is written to by B_sum, which _is_ directly # written to from an instruction (bset1) that references j_inner skeleton = """ for j %(preamble)s for i %(insn)s if i > 6 <> P_val = 100 {id=pset0, nosync=pset1} else P_val = 0.01 {id=pset1, nosync=pset0} end <> B_sum = 0 {id=bset0} for k B_sum = B_sum + k * a[i, j] {id=bset1, dep=*:bset0} end # here, we are testing that Kc is properly promoted to a vector dtype <> P_sum = P_val * i {id=pset2, dep=pset0:pset1} B_sum = exp(B_sum) {id=bset2, dep=bset0:bset1} <> Kc = P_sum * B_sum {id=kset, dep=bset*:pset2} a[i, j] = Kc {dep=*:kset, nosync=pset0:pset1} end end """ knl = make_kernel('', dtype=np.float32, skeleton=skeleton, extra_inames='k') from loopy.kernel.array import VectorArrayDimTag assert any(isinstance(x, VectorArrayDimTag) for x in knl.temporary_variables['Kc'].dim_tags) def test_explicit_simd_selects(ctx_factory): ctx = ctx_factory() def create_and_test(insn, condition, answer, exception=None, a=None, b=None, extra_insns=None, c=None, v=None, check=None, debug=False): a = np.zeros((3, 4), dtype=np.int32) if a is None else a data = [lp.GlobalArg('a', shape=(12,), dtype=a.dtype)] kwargs = dict(a=a) if b is not None: data += [lp.GlobalArg('b', shape=(12,), dtype=b.dtype)] kwargs['b'] = b if c is not None: data += [lp.GlobalArg('c', shape=(12,), dtype=b.dtype)] kwargs['c'] = c names = [d.name for d in data] # add after defining names to avoid trying to split value arg if v is not None: data += [lp.ValueArg('v', dtype=np.int32)] kwargs['v'] = v knl = lp.make_kernel(['{[i]: 0 <= i < 12}'], """ for i %(extra)s if %(condition)s %(insn)s end end """ % dict(condition=condition, insn=insn, extra=extra_insns if extra_insns else ''), data ) knl = lp.split_iname(knl, 'i', 4, inner_tag='vec') knl = lp.split_array_axis(knl, names, 0, 4) knl = lp.tag_array_axes(knl, names, 'N0,vec') if v is not None: knl = lp.set_options(knl, write_wrapper=True) queue = cl.CommandQueue(ctx) if check is not None: assert check(knl) elif exception is not None: with pytest.raises(exception): knl(queue, **kwargs) else: if not isinstance(answer, tuple): answer = (answer,) if debug: print(lp.generate_code_v2(knl).device_code()) result = knl(queue, **kwargs)[1] for r, a in zip(result, answer): assert np.array_equal(r.flatten('C'), a) ans = np.zeros(12, dtype=np.int32) ans[7:] = 1 # 1) test a conditional on a vector iname create_and_test('a[i] = 1', 'i > 6', ans) # 2) condition on a vector array create_and_test('a[i] = 1', 'b[i] > 6', ans, b=np.arange( 12, dtype=np.int32).reshape((3, 4))) # 3) condition on a vector temporary create_and_test('a[i] = 1', 'c', ans, extra_insns='<> c = (i < 7) - 1') # 4) condition on an assigned vector array, this should work as assignment to a # vector can be safely unrolled create_and_test('a[i] = 1', '(b[i] > 6)', ans, b=np.zeros((3, 4), dtype=np.int32), extra_insns='b[i] = i') # 5) a block of simple assignments, this should be seemlessly translated to # multiple vector if statements c_ans = np.ones(12, dtype=np.int32) c_ans[7:] = 0 create_and_test('a[i] = 1\nc[i] = 0', '(b[i] > 6)', (ans, c_ans), b=np.arange( 12, dtype=np.int32).reshape((3, 4)), c=np.ones((3, 4), dtype=np.int32)) # 6) test a negated conditional ans_negated = np.invert(ans) + 2 create_and_test('a[i] = 1', 'not (b[i] > 6)', ans_negated, b=np.arange( 12, dtype=np.int32).reshape((3, 4))) # 7) test conditional on differing dtype ans_negated = np.invert(ans) + 2 create_and_test('a[i] = 1', 'not (b[i] > 6)', ans_negated, b=np.arange( 12, dtype=np.int64).reshape((3, 4))) # 8) test conditional on differing dtype (float->int) and (int->float) ans_negated = np.invert(ans) + 2 create_and_test('a[i] = 1', 'not (b[i] > 6)', ans_negated, b=np.arange( 12, dtype=np.float64).reshape((3, 4))) create_and_test('a[i] = 1', 'not (b[i] > 6)', ans_negated, b=np.arange( 12, dtype=np.int64).reshape((3, 4)), a=np.zeros((3, 4), dtype=np.float32)) # 9) test conditional on valuearg, the "test" here is that we can actually # generate the code create_and_test('a[i] = 1', 'v', np.ones_like(ans), v=1) @pytest.mark.parametrize(('lhs_dtype', 'rhs_dtype'), [ (np.int32, np.int64), (np.float32, np.float64)]) def test_explicit_vector_dtype_conversion(ctx_factory, lhs_dtype, rhs_dtype): ctx = ctx_factory() # test that dtype conversion happens correctly between differing vector-dtypes def __make_kernel(insn, has_conversion=True, uses_temp=True): vw = 4 a_lp = lp.GlobalArg('a', shape=(12,), dtype=rhs_dtype) temp_lp = lp.TemporaryVariable('temp', dtype=lhs_dtype) knl = lp.make_kernel(['{[i]: 0 <= i < 12}'], """ for i {insn} end """.format(insn=insn), [a_lp, temp_lp], target=lp.PyOpenCLTarget(ctx.devices[0]), silenced_warnings=['temp_to_write(temp)'] if not uses_temp else []) knl = lp.split_iname(knl, 'i', vw, inner_tag='vec') knl = lp.split_array_axis(knl, 'a', 0, 4) knl = lp.tag_array_axes(knl, 'a', 'N0,vec') queue = cl.CommandQueue(ctx) # check that the kernel compiles correctly knl(queue, a=np.zeros((12,), dtype=rhs_dtype).reshape((3, 4))) # check that we have or don't have a conversion assert ('convert_' in lp.generate_code_v2(knl).device_code()) == \ has_conversion # test simple dtype conversion __make_kernel("temp = a[i]") # test literal assignment __make_kernel("a[i] = 1", False, False) # test that a non-vector temporary doesn't trigger conversion # # this should generate the code (e.g.,): # __kernel void __attribute__ ((reqd_work_group_size(1, 1, 1))) # loopy_kernel(__global long4 *__restrict__ a) # { # int temp; # for (int i_outer = 0; i_outer <= 2; ++i_outer) # { # temp = 1; # a[i_outer] = temp; # } # } # # that is, temp should _not_ be assigned to "a" w/ convert_long4 __make_kernel(""" temp = 1 a[i] = temp """, has_conversion=False) # test that the inverse _does_ result in a convers __make_kernel(""" temp = a[i] {id=1, dep=*} a[i] = temp {id=2, dep=1} """) @pytest.mark.parametrize('dtype', [np.int32, np.int64, np.float32, np.float64]) @pytest.mark.parametrize('vec_width', [2, 3, 4, 8, 16]) def test_explicit_simd_vector_iname_in_conditional(ctx_factory, dtype, vec_width): ctx = ctx_factory() size = vec_width * 4 def create_and_test(insn, answer, shape=(1, size), debug=False, vectors=['a', 'b']): num_conditions = shape[0] knl = lp.make_kernel(['{{[i]: 0 <= i < {}}}'.format(size), '{{[j]: 0 <= j < {}}}'.format(num_conditions)], insn, [lp.GlobalArg('a', shape=shape, dtype=dtype), lp.GlobalArg('b', shape=shape, dtype=dtype)]) knl = lp.split_iname(knl, 'i', 4, inner_tag='vec') knl = lp.tag_inames(knl, [('j', 'g.0')]) knl = lp.split_array_axis(knl, ['a', 'b'], 1, 4) knl = lp.tag_array_axes(knl, vectors, 'N1,N0,vec') # ensure we can generate code code = lp.generate_code_v2(knl).device_code() if debug: print(code) # and check answer queue = cl.CommandQueue(ctx) num_vectors = int(shape[1] / 4) a = np.zeros((num_conditions, num_vectors, 4), dtype=dtype) b = np.arange(num_conditions * num_vectors * 4, dtype=dtype).reshape( (num_conditions, num_vectors, 4)) result = knl(queue, a=a, b=b)[1][0] assert np.array_equal(result.flatten('C'), answer) ans = np.arange(size, dtype=np.int32) ans[:7] = 0 create_and_test(""" if i >= 7 a[j, i] = b[j, i] end """, ans) # a case that will result in a unvectorized evaluation # this tests that we are properly able to unwind any vectorized conditional that # has been applied, and then reapply the correct scalar conditional in # unvectorize ans = np.arange(12 * size, dtype=np.int32) ans[:7] = 0 create_and_test(""" if j * 12 + i >= 7 a[j, i] = b[j, i] end """, ans, shape=(12, size), vectors=['b']) def test_vectorizability(): # check new vectorizability conditions from loopy.kernel.array import VectorArrayDimTag from loopy.kernel.data import VectorizeTag, filter_iname_tags_by_type def create_and_test(insn, exception=None, a=None, b=None): a = np.zeros((3, 4), dtype=np.int32) if a is None else a data = [lp.GlobalArg('a', shape=(12,), dtype=a.dtype)] kwargs = dict(a=a) if b is not None: data += [lp.GlobalArg('b', shape=(12,), dtype=b.dtype)] kwargs['b'] = b names = [d.name for d in data] knl = lp.make_kernel(['{[i]: 0 <= i < 12}'], """ for i %(insn)s end """ % dict(insn=insn), data ) knl = lp.split_iname(knl, 'i', 4, inner_tag='vec') knl = lp.split_array_axis(knl, names, 0, 4) knl = lp.tag_array_axes(knl, names, 'N0,vec') knl = lp.preprocess_kernel(knl) lp.generate_code_v2(knl).device_code() assert knl.instructions[0].within_inames & set(['i_inner']) assert isinstance(knl.args[0].dim_tags[-1], VectorArrayDimTag) assert isinstance(knl.args[0].dim_tags[-1], VectorArrayDimTag) assert filter_iname_tags_by_type(knl.iname_to_tags['i_inner'], VectorizeTag) def run(op_list=[], unary_operators=[], func_list=[], unary_funcs=[], rvals=['1', 'a[i]']): for op in op_list: template = 'a[i] = a[i] %(op)s %(rval)s' \ if op not in unary_operators else 'a[i] = %(op)s a[i]' for rval in rvals: create_and_test(template % dict(op=op, rval=rval)) for func in func_list: template = 'a[i] = %(func)s(a[i], %(rval)s)' \ if func not in unary_funcs else 'a[i] = %(func)s(a[i])' for rval in rvals: create_and_test(template % dict(func=func, rval=rval)) # 1) comparisons run(['>', '>=', '<', '<=', '==', '!=']) # 2) logical operators run(['and', 'or', 'not'], ['not']) # 3) bitwise operators # bitwise xor '^' not not implemented in codegen run(['~', '|', '&'], ['~']) # 4) functions -- a random selection of the enabled math functions in opencl run(func_list=['acos', 'exp10', 'atan2', 'round'], unary_funcs=['round', 'acos', 'exp10']) # 5) remainders and floor division (use 4 instead of 1 to avoid pymbolic # optimizing out the a[i] % 1) run(['%', '//'], rvals=['a[i]', '4']) def test_check_for_variable_access_ordering(): knl = lp.make_kernel( "{[i]: 0<=i nu = i - 4 if nu > 0 <> P_val = a[i, j] {id=pset0} else P_val = 0.1 * a[i, j] {id=pset1} end <> B_sum = 0 for k B_sum = B_sum + k * P_val {id=bset, dep=pset*} end # here, we are testing that Kc is properly promoted to a vector dtype <> Kc = P_val * B_sum {id=kset, dep=bset} a[i, j] = Kc {dep=kset} end end """, [lp.GlobalArg('a', shape=(12, 12), dtype=np.int32)]) knl = lp.split_iname(knl, 'j', 4, inner_tag='vec') knl = lp.split_array_axis(knl, 'a', 1, 4) knl = lp.tag_array_axes(knl, 'a', 'N1,N0,vec') knl = lp.preprocess_kernel(knl) from loopy.diagnostic import DependencyCycleFound with pytest.raises(DependencyCycleFound): print(lp.generate_code(knl)[0]) if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) else: from pytest import main main([__file__]) # vim: foldmethod=marker