def test_mem_access_counter_global_temps():
    """Check that a temporary placed in global memory is counted as global.

    Builds a kernel with two temporaries, ``a`` and ``b``, switches ``b`` to
    the ``"global"`` temporary scope, and asserts that the memory access map
    filtered by ``mtype=["global"]`` sums to ``n*m`` for the chosen
    parameter values.
    """

    # NOTE(review): the loop bounds (i < n, j < m, k < ell) are inferred from
    # the parameter list, the assumptions string and the n*m expectation
    # below -- confirm against the upstream test.
    knl = lp.make_kernel(
            "[n,m,ell] -> {[i,j,k]: 0<=i<n and 0<=j<m and 0<=k<ell}",
            """
            <>a[i, j, k] = 3.1
            <>b[i, j] = 3.2
            """,
            assumptions="n,m,ell >= 1")
    knl = lp.add_and_infer_dtypes(knl, {"a,b": np.float32})

    # Change temporary b address space
    knl = lp.privatize_temporaries_with_inames(knl, "i,j", "b")
    knl = lp.set_temporary_scope(knl, "b", "global")

    mem_map = lp.get_mem_access_map(knl, count_redundant_work=True,
                                    subgroup_size="guess")
    n = 512
    m = 256
    ell = 128
    params = {"n": n, "m": m, "ell": ell}

    # Count global accesses
    global_accesses = mem_map.filter_by(
        mtype=["global"]).sum().eval_with_dict(params)

    # Only b was moved to the global scope; presumably it is stored once per
    # (i, j) point, hence n*m accesses.
    assert global_accesses == n*m