From f74954bc60443fc8f54dba92c316f9c808ec22f7 Mon Sep 17 00:00:00 2001 From: jdsteve2 Date: Thu, 22 Jul 2021 12:44:46 -0500 Subject: [PATCH] test counting of temps with global address space --- test/test_statistics.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/test/test_statistics.py b/test/test_statistics.py index f4949a9f1..95887f982 100644 --- a/test/test_statistics.py +++ b/test/test_statistics.py @@ -934,6 +934,35 @@ def test_mem_access_counter_consec(): assert f32consec == n*m*ell +def test_mem_access_counter_global_temps(): + + knl = lp.make_kernel( + "[n,m,ell] -> {[i,j,k]: 0<=ia[i, j, k] = 3.1 + <>b[i, j] = 3.2 + """, + assumptions="n,m,ell >= 1") + knl = lp.add_and_infer_dtypes(knl, {"a,b": np.float32}) + + # Change temporary b address space + knl = lp.privatize_temporaries_with_inames(knl, "i,j", "b") + knl = lp.set_temporary_scope(knl, "b", "global") + + mem_map = lp.get_mem_access_map(knl, count_redundant_work=True, + subgroup_size="guess") + n = 512 + m = 256 + ell = 128 + params = {"n": n, "m": m, "ell": ell} + + # Count global accesses + global_accesses = mem_map.filter_by( + mtype=["global"]).sum().eval_with_dict(params) + + assert global_accesses == n*m + + def test_count_granularity_val_checks(): try: -- GitLab