diff --git a/examples/dagrt_fusion/fusion-study.py b/examples/dagrt_fusion/fusion-study.py index 1fad3375c08ef3547ecb4b372c98b16bb96a49f1..4f301c836c47cec69406a69e600a9d55fc3d864e 100755 --- a/examples/dagrt_fusion/fusion-study.py +++ b/examples/dagrt_fusion/fusion-study.py @@ -898,6 +898,24 @@ def latex_table(table_format, header, rows): return "\n".join(result) +def problem_stats(order=3): + cl_ctx = cl.create_some_context() + + _, dg_discr_2d = get_strong_wave_op_with_discr(cl_ctx, dims=2, order=order) + print("Number of 2D elements:", dg_discr_2d.mesh.nelements) + vol_discr_2d = dg_discr_2d.discr_from_dd("vol") + dofs_2d = set(group.nunit_nodes for group in vol_discr_2d.groups) + from pytools import one + print("Number of DOFs per 2D element:", one(dofs_2d)) + + _, dg_discr_3d = get_strong_wave_op_with_discr(cl_ctx, dims=3, order=order) + print("Number of 3D elements:", dg_discr_3d.mesh.nelements) + vol_discr_3d = dg_discr_3d.discr_from_dd("vol") + dofs_3d = set(group.nunit_nodes for group in vol_discr_3d.groups) + from pytools import one + print("Number of DOFs per 3D element:", one(dofs_3d)) + + def statement_counts_table(): cl_ctx = cl.create_some_context() queue = cl.CommandQueue(cl_ctx) @@ -1085,6 +1103,7 @@ def scalar_assignment_effect_of_fusion_mem_ops_table(): if __name__ == "__main__": + problem_stats() statement_counts_table() scalar_assignment_percent_of_total_mem_ops_table() scalar_assignment_effect_of_fusion_mem_ops_table()