diff --git a/examples/dagrt_fusion/fusion-study.py b/examples/dagrt_fusion/fusion-study.py index 4f301c836c47cec69406a69e600a9d55fc3d864e..8dea5d27888aa03d75bbe1a3e7452028724be41d 100755 --- a/examples/dagrt_fusion/fusion-study.py +++ b/examples/dagrt_fusion/fusion-study.py @@ -168,11 +168,11 @@ def transcribe_phase(dag, field_var_name, field_components, phase_name, raise NotImplementedError("lhs of statement %s is not a variable: %s" % (stmt.id, stmt.lhs)) ctx[stmt.lhs.name] = sym.cse( - DagrtToGrudgeRewriter(ctx)(stmt.rhs), - ( - stmt.lhs.name - .replace("<", "") - .replace(">", ""))) + DagrtToGrudgeRewriter(ctx)(stmt.rhs), + ( + stmt.lhs.name + .replace("<", "") + .replace(">", ""))) elif isinstance(stmt, lang.AssignFunctionCall): if stmt.function_id != rhs_name: @@ -200,8 +200,11 @@ def transcribe_phase(dag, field_var_name, field_components, phase_name, elif isinstance(stmt, lang.YieldState): d2g = DagrtToGrudgeRewriter(ctx) yielded_states.append( - (stmt.time_id, d2g(stmt.time), stmt.component_id, - d2g(stmt.expression))) + ( + stmt.time_id, + d2g(stmt.time), + stmt.component_id, + d2g(stmt.expression))) else: raise NotImplementedError("statement %s is of unsupported type ''%s'" @@ -627,6 +630,32 @@ class ExecutionMapperWithMemOpCounting(ExecutionMapper): # {{{ mem op counter check +def test_assignment_memory_model(ctx_factory): + cl_ctx = ctx_factory() + queue = cl.CommandQueue(cl_ctx) + + _, discr = get_strong_wave_op_with_discr(cl_ctx, dims=2, order=3) + + # Assignment instruction + bound_op = bind( + discr, + sym.Variable("input0", sym.DD_VOLUME) + + sym.Variable("input1", sym.DD_VOLUME), + exec_mapper_factory=ExecutionMapperWithMemOpCounting) + + input0 = discr.zeros(queue) + input1 = discr.zeros(queue) + + result, profile_data = bound_op( + queue, + profile_data={}, + input0=input0, + input1=input1) + + assert profile_data["bytes_read"] == input0.nbytes + input1.nbytes + assert profile_data["bytes_written"] == result.nbytes + + @pytest.mark.parametrize("use_fusion", (True, False)) def test_stepper_mem_ops(ctx_factory, use_fusion): cl_ctx = ctx_factory() @@ -904,14 +933,14 @@ def problem_stats(order=3): _, dg_discr_2d = get_strong_wave_op_with_discr(cl_ctx, dims=2, order=order) print("Number of 2D elements:", dg_discr_2d.mesh.nelements) vol_discr_2d = dg_discr_2d.discr_from_dd("vol") - dofs_2d = set(group.nunit_nodes for group in vol_discr_2d.groups) + dofs_2d = {group.nunit_nodes for group in vol_discr_2d.groups} from pytools import one print("Number of DOFs per 2D element:", one(dofs_2d)) _, dg_discr_3d = get_strong_wave_op_with_discr(cl_ctx, dims=3, order=order) print("Number of 3D elements:", dg_discr_3d.mesh.nelements) vol_discr_3d = dg_discr_3d.discr_from_dd("vol") - dofs_3d = set(group.nunit_nodes for group in vol_discr_3d.groups) + dofs_3d = {group.nunit_nodes for group in vol_discr_3d.groups} from pytools import one print("Number of DOFs per 3D element:", one(dofs_3d))