From 11b1833150b9c35cc917c96108f63a69eee661ef Mon Sep 17 00:00:00 2001 From: Matt Wala <wala1@illinois.edu> Date: Wed, 8 May 2019 19:15:10 -0500 Subject: [PATCH] Add a test for memory modeling --- examples/dagrt_fusion/fusion-study.py | 47 ++++++++++++++++++++++----- 1 file changed, 38 insertions(+), 9 deletions(-) diff --git a/examples/dagrt_fusion/fusion-study.py b/examples/dagrt_fusion/fusion-study.py index 4f301c8..8dea5d2 100755 --- a/examples/dagrt_fusion/fusion-study.py +++ b/examples/dagrt_fusion/fusion-study.py @@ -168,11 +168,11 @@ def transcribe_phase(dag, field_var_name, field_components, phase_name, raise NotImplementedError("lhs of statement %s is not a variable: %s" % (stmt.id, stmt.lhs)) ctx[stmt.lhs.name] = sym.cse( - DagrtToGrudgeRewriter(ctx)(stmt.rhs), - ( - stmt.lhs.name - .replace("<", "") - .replace(">", ""))) + DagrtToGrudgeRewriter(ctx)(stmt.rhs), + ( + stmt.lhs.name + .replace("<", "") + .replace(">", ""))) elif isinstance(stmt, lang.AssignFunctionCall): if stmt.function_id != rhs_name: @@ -200,8 +200,11 @@ def transcribe_phase(dag, field_var_name, field_components, phase_name, elif isinstance(stmt, lang.YieldState): d2g = DagrtToGrudgeRewriter(ctx) yielded_states.append( - (stmt.time_id, d2g(stmt.time), stmt.component_id, - d2g(stmt.expression))) + ( + stmt.time_id, + d2g(stmt.time), + stmt.component_id, + d2g(stmt.expression))) else: raise NotImplementedError("statement %s is of unsupported type ''%s'" @@ -627,6 +630,32 @@ class ExecutionMapperWithMemOpCounting(ExecutionMapper): # {{{ mem op counter check +def test_assignment_memory_model(ctx_factory): + cl_ctx = ctx_factory() + queue = cl.CommandQueue(cl_ctx) + + _, discr = get_strong_wave_op_with_discr(cl_ctx, dims=2, order=3) + + # Assignment instruction + bound_op = bind( + discr, + sym.Variable("input0", sym.DD_VOLUME) + + sym.Variable("input1", sym.DD_VOLUME), + exec_mapper_factory=ExecutionMapperWithMemOpCounting) + + input0 = discr.zeros(queue) + input1 = discr.zeros(queue) + + result, profile_data = bound_op( + queue, + profile_data={}, + input0=input0, + input1=input1) + + assert profile_data["bytes_read"] == input0.nbytes + input1.nbytes + assert profile_data["bytes_written"] == result.nbytes + + @pytest.mark.parametrize("use_fusion", (True, False)) def test_stepper_mem_ops(ctx_factory, use_fusion): cl_ctx = ctx_factory() @@ -904,14 +933,14 @@ def problem_stats(order=3): _, dg_discr_2d = get_strong_wave_op_with_discr(cl_ctx, dims=2, order=order) print("Number of 2D elements:", dg_discr_2d.mesh.nelements) vol_discr_2d = dg_discr_2d.discr_from_dd("vol") - dofs_2d = set(group.nunit_nodes for group in vol_discr_2d.groups) + dofs_2d = {group.nunit_nodes for group in vol_discr_2d.groups} from pytools import one print("Number of DOFs per 2D element:", one(dofs_2d)) _, dg_discr_3d = get_strong_wave_op_with_discr(cl_ctx, dims=3, order=order) print("Number of 3D elements:", dg_discr_3d.mesh.nelements) vol_discr_3d = dg_discr_3d.discr_from_dd("vol") - dofs_3d = set(group.nunit_nodes for group in vol_discr_3d.groups) + dofs_3d = {group.nunit_nodes for group in vol_discr_3d.groups} from pytools import one print("Number of DOFs per 3D element:", one(dofs_3d)) -- GitLab