diff --git a/examples/dagrt-fusion.py b/examples/dagrt-fusion.py index c12d63605e291b7e040ceca9dd076f8c3e373cf5..e0f35b02f96303875ca40521663268ea0f491a5d 100755 --- a/examples/dagrt-fusion.py +++ b/examples/dagrt-fusion.py @@ -42,7 +42,7 @@ import pytest import dagrt.language as lang import pymbolic.primitives as p import grudge.symbolic.mappers as gmap -import grudge.symbolic.operators as op +import grudge.symbolic.operators as sym_op from grudge.execution import ExecutionMapper from grudge.function_registry import base_function_registry from pymbolic.mapper import Mapper @@ -648,10 +648,10 @@ class ExecutionMapperWithMemOpCounting(ExecutionMapperWrapper): expr.op, ( # TODO: Not comprehensive. - op.InterpolationOperator, - op.RefFaceMassOperator, - op.RefInverseMassOperator, - op.OppositeInteriorFaceSwap)): + sym_op.InterpolationOperator, + sym_op.RefFaceMassOperator, + sym_op.RefInverseMassOperator, + sym_op.OppositeInteriorFaceSwap)): val = self.map_profiled_essentially_elementwise_linear( expr.op, expr.field, profile_data) @@ -761,7 +761,7 @@ def test_assignment_memory_model(ctx_factory): cl_ctx = ctx_factory() queue = cl.CommandQueue(cl_ctx) - _, discr = get_strong_wave_op_with_discr(cl_ctx, dims=2, order=3) + _, discr = get_strong_wave_op_with_discr_direct(cl_ctx, dims=2, order=3) # Assignment instruction bound_op = bind( @@ -790,7 +790,8 @@ def test_stepper_mem_ops(ctx_factory, use_fusion): dims = 2 - op, discr = get_strong_wave_op_with_discr(cl_ctx, dims=dims, order=3) + sym_operator, discr = get_strong_wave_op_with_discr_direct( + cl_ctx, dims=dims, order=3) t_start = 0 dt = 0.04 @@ -802,7 +803,7 @@ def test_stepper_mem_ops(ctx_factory, use_fusion): if not use_fusion: bound_op = bind( - discr, op.sym_operator(), + discr, sym_operator, exec_mapper_factory=ExecutionMapperWithMemOpCounting) stepper = RK4TimeStepper( @@ -812,7 +813,7 @@ def test_stepper_mem_ops(ctx_factory, use_fusion): else: stepper = FusedRK4TimeStepper( - queue, discr, "w", op.sym_operator(), 1 + discr.dim, + queue, discr, "w", sym_operator, 1 + discr.dim, get_strong_wave_component, exec_mapper_factory=ExecutionMapperWithMemOpCounting) @@ -960,7 +961,8 @@ def test_stepper_timing(ctx_factory, use_fusion): dims = 3 - op, discr = get_strong_wave_op_with_discr(cl_ctx, dims=dims, order=3) + sym_operator, discr = get_strong_wave_op_with_discr_direct( + cl_ctx, dims=dims, order=3) t_start = 0 dt = 0.04 @@ -972,7 +974,7 @@ def test_stepper_timing(ctx_factory, use_fusion): if not use_fusion: bound_op = bind( - discr, op.sym_operator(), + discr, sym_operator, exec_mapper_factory=ExecutionMapperWithTiming) stepper = RK4TimeStepper( @@ -982,7 +984,7 @@ def test_stepper_timing(ctx_factory, use_fusion): else: stepper = FusedRK4TimeStepper( - queue, discr, "w", op.sym_operator(), 1 + discr.dim, + queue, discr, "w", sym_operator, 1 + discr.dim, get_strong_wave_component, exec_mapper_factory=ExecutionMapperWithTiming) @@ -1011,11 +1013,12 @@ def test_stepper_timing(ctx_factory, use_fusion): def get_example_stepper(queue, dims=2, order=3, use_fusion=True, exec_mapper_factory=ExecutionMapper, return_ic=False): - op, discr = get_strong_wave_op_with_discr(queue.context, dims=dims, order=3) + sym_operator, discr = get_strong_wave_op_with_discr_direct( + queue.context, dims=dims, order=3) if not use_fusion: bound_op = bind( - discr, op.sym_operator(), + discr, sym_operator, exec_mapper_factory=exec_mapper_factory) stepper = RK4TimeStepper( @@ -1025,7 +1028,7 @@ def get_example_stepper(queue, dims=2, order=3, use_fusion=True, else: stepper = FusedRK4TimeStepper( - queue, discr, "w", op.sym_operator(), 1 + discr.dim, + queue, discr, "w", sym_operator, 1 + discr.dim, get_strong_wave_component, exec_mapper_factory=exec_mapper_factory) @@ -1080,14 +1083,16 @@ def problem_stats(order=3): cl_ctx = cl.create_some_context() outf = open_output_file("grudge-problem-stats.txt") - _, dg_discr_2d = get_strong_wave_op_with_discr(cl_ctx, dims=2, order=order) + _, dg_discr_2d = get_strong_wave_op_with_discr_direct( + cl_ctx, dims=2, order=order) print("Number of 2D elements:", dg_discr_2d.mesh.nelements, file=outf) vol_discr_2d = dg_discr_2d.discr_from_dd("vol") dofs_2d = {group.nunit_nodes for group in vol_discr_2d.groups} from pytools import one print("Number of DOFs per 2D element:", one(dofs_2d), file=outf) - _, dg_discr_3d = get_strong_wave_op_with_discr(cl_ctx, dims=3, order=order) + _, dg_discr_3d = get_strong_wave_op_with_discr_direct( + cl_ctx, dims=3, order=order) print("Number of 3D elements:", dg_discr_3d.mesh.nelements, file=outf) vol_discr_3d = dg_discr_3d.discr_from_dd("vol") dofs_3d = {group.nunit_nodes for group in vol_discr_3d.groups} @@ -1304,3 +1309,5 @@ def main(): if __name__ == "__main__": main() + +# vim: foldmethod=marker