diff --git a/test/test_loopy.py b/test/test_loopy.py index 3ac857478bf4ac1d4cd6868f22896ca63de34f04..48cb6980ab79bcfc640a19551d9a7708b6a2b20c 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -2087,6 +2087,47 @@ def test_integer_reduction(ctx_factory): assert function(out) +def test_complicated_argmin_reduction(ctx_factory): + cl_ctx = ctx_factory() + knl = lp.make_kernel( + "{[ictr,itgt,idim]: " + "0<=itgt<ntargets " + "and 0<=ictr<ncenters " + "and 0<=idim<ambient_dim}", + + """ + for itgt + for ictr + <> dist_sq = sum(idim, + (tgt[idim,itgt] - center[idim,ictr])**2) + <> in_disk = dist_sq < (radius[ictr]*1.05)**2 + <> matches = ( + (in_disk + and qbx_forced_limit == 0) + or (in_disk + and qbx_forced_limit != 0 + and qbx_forced_limit * center_side[ictr] > 0) + ) + + <> post_dist_sq = if(matches, dist_sq, HUGE) + end + <> min_dist_sq, <> min_ictr = argmin(ictr, ictr, post_dist_sq) + + tgt_to_qbx_center[itgt] = if(min_dist_sq < HUGE, min_ictr, -1) + end + """) + + knl = lp.fix_parameters(knl, ambient_dim=2) + knl = lp.add_and_infer_dtypes(knl, { + "tgt,center,radius,HUGE": np.float32, + "center_side,qbx_forced_limit": np.int32, + }) + + lp.auto_test_vs_ref(knl, cl_ctx, knl, parameters={ + "HUGE": 1e20, "ncenters": 200, "ntargets": 300, + "qbx_forced_limit": 1}) + + def test_nosync_option_parsing(): knl = lp.make_kernel( "{[i]: 0 <= i < 10}", @@ -2335,47 +2376,6 @@ def test_kernel_var_name_generator(): assert vng("b") != "b" -def test_complex_argmin(ctx_factory): - cl_ctx = ctx_factory() - knl = lp.make_kernel( - "{[ictr,itgt,idim]: " - "0<=itgt<ntargets " - "and 0<=ictr<ncenters " - "and 0<=idim<ambient_dim}", - - """ - for itgt - for ictr - <> dist_sq = sum(idim, - (tgt[idim,itgt] - center[idim,ictr])**2) - <> in_disk = dist_sq < (radius[ictr]*1.05)**2 - <> matches = ( - (in_disk - and qbx_forced_limit == 0) - or (in_disk - and qbx_forced_limit != 0 - and qbx_forced_limit * center_side[ictr] > 0) - ) - - <> post_dist_sq = if(matches, dist_sq, HUGE) - end - <> min_dist_sq, <> min_ictr = argmin(ictr, ictr, post_dist_sq) - - tgt_to_qbx_center[itgt] = if(min_dist_sq < HUGE, min_ictr, -1) - end - """) - - knl = lp.fix_parameters(knl, ambient_dim=2) - knl = lp.add_and_infer_dtypes(knl, { - "tgt,center,radius,HUGE": np.float32, - "center_side,qbx_forced_limit": np.int32, - }) - - lp.auto_test_vs_ref(knl, cl_ctx, knl, parameters={ - "HUGE": 1e20, "ncenters": 200, "ntargets": 300, - "qbx_forced_limit": 1}) - - if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])