From ea06b523fc7dcb64db8a809249073e2d25ece01d Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sun, 12 Mar 2023 10:30:39 -0500 Subject: [PATCH] test pytato.pad --- test/test_codegen.py | 87 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) diff --git a/test/test_codegen.py b/test/test_codegen.py index 7dc1b6a..f0e4469 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -1834,6 +1834,93 @@ def test_placeholders_do_not_diverge_after_removing_impl_stored(ctx_factory): np.testing.assert_allclose(out["out2"], x_np) +def _get_masking_array_for_test_pad(array, pad_widths): + from pytato.pad import _normalize_pad_width + pad_widths = _normalize_pad_width(array, pad_widths) + + def _get_mask_array_idx(*idxs): + return np.where( + sum([((idx < pad_width[0]) + | (idx >= (pad_width[0]+axis_len)) + ).astype(np.int32) + for idx, axis_len, pad_width in zip(idxs, + array.shape, + pad_widths)]) > 1, + 0*idxs[0], + 0*idxs[0] + 1) + + return np.fromfunction( + _get_mask_array_idx, + shape=tuple(dim + pad_width[0] + pad_width[1] + for dim, pad_width in zip(array.shape, pad_widths)), + ) + + +def test_pad(ctx_factory): + + ctx = ctx_factory() + cq = cl.CommandQueue(ctx) + rng = np.random.default_rng(0) + + for _ in range(10): + ndim = rng.integers(1, 8) + ary_shape = rng.integers(2, 7, ndim) + ary_np = rng.random(tuple(ary_shape), dtype=np.float32) + ary = pt.make_data_wrapper(ary_np) + + # test constant pad length - I + pad_width = rng.integers(0, 3) + np_out = np.pad(ary_np, pad_width) + out = pt.pad(ary, pad_width) + + _, (pt_out,) = pt.generate_loopy(out)(cq) + mask_array = _get_masking_array_for_test_pad(ary_np, pad_width) + + np.testing.assert_allclose(np_out * mask_array, pt_out * mask_array) + + # testing constant pad length - II + pad_width = tuple(rng.integers(0, 3, 2)) + np_out = np.pad(ary_np, pad_width) + out = pt.pad(ary, pad_width) + + _, (pt_out,) = pt.generate_loopy(out)(cq) + mask_array = _get_masking_array_for_test_pad(ary_np, pad_width) + + np.testing.assert_allclose(np_out * mask_array, pt_out * mask_array) + + # test unequal padding - I + pad_width = [tuple(pad) for pad in rng.integers(0, 3, (ndim, 2))] + np_out = np.pad(ary_np, pad_width, constant_values=32) + out = pt.pad(ary, pad_width, constant_values=32) + + _, (pt_out,) = pt.generate_loopy(out)(cq) + mask_array = _get_masking_array_for_test_pad(ary_np, pad_width) + + np.testing.assert_allclose(np_out * mask_array, pt_out * mask_array) + + # test unequal padding - II + pad_width = [tuple(pad) for pad in rng.integers(0, 3, (ndim, 2))] + np_out = np.pad(ary_np, pad_width, constant_values=(32, 42)) + out = pt.pad(ary, pad_width, constant_values=(32, 42)) + + _, (pt_out,) = pt.generate_loopy(out)(cq) + mask_array = _get_masking_array_for_test_pad(ary_np, pad_width) + + np.testing.assert_allclose(np_out * mask_array, pt_out * mask_array) + + # test unequal padding - III + pad_width = [tuple(pad) for pad in rng.integers(0, 3, (ndim, 2))] + constant_values = [tuple(val) for val in rng.random((ndim, 2))] + + np_out = np.pad(ary_np, pad_width, constant_values=constant_values) + out = pt.pad(ary, pad_width, constant_values=constant_values) + + _, (pt_out,) = pt.generate_loopy(out)(cq) + mask_array = _get_masking_array_for_test_pad(ary_np, pad_width) + + np.testing.assert_allclose(np_out * mask_array, pt_out * mask_array) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) -- GitLab