Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os
import numpy as np
import pyopencl as cl
from grudge import sym, bind, DGDiscretizationWithBoundaries
from grudge.shortcuts import set_up_rk4
def _int_root(value, dim):
    """Return the largest integer ``r`` such that ``r**dim <= value``.

    ``int(value ** (1./dim))`` truncates floating-point results such as
    ``64 ** (1./3) == 3.9999999999999996`` down to 3. This helper starts from
    the rounded floating-point estimate and corrects it with exact integer
    comparisons, so perfect powers map to their exact root while other values
    are still rounded down.

    :arg value: a non-negative number.
    :arg dim: a positive integer exponent.
    :raises ValueError: if *value* is negative.
    """
    if value < 0:
        raise ValueError("value must be non-negative, got %r" % (value,))

    root = int(round(value ** (1. / dim)))
    # Correct the float estimate in either direction using exact arithmetic.
    while root ** dim > value:
        root -= 1
    while (root + 1) ** dim <= value:
        root += 1
    return root


def _format_profile_data(rank, data):
    """Render one rank's execution profile as a human-readable report.

    :arg rank: MPI rank shown in the header line.
    :arg data: profile dictionary; the keys ``insn_eval_time``,
        ``future_eval_time``, ``busy_wait_time`` and ``total_time`` are read.
        Missing keys count as 0, and percentages are reported as 0 when no
        total time was recorded (e.g. the operator was never evaluated), so an
        empty profile does not raise ``KeyError``/``ZeroDivisionError``.
    :returns: the multi-line report string.
    """
    total_time = data.get("total_time", 0.)

    def percent(key):
        if not total_time:
            return 0.
        return data.get(key, 0.) / total_time * 100

    return ("execute() for rank %d:\n"
            "\tInstruction Evaluation: %f%%\n"
            "\tFuture Evaluation: %f%%\n"
            "\tBusy Wait: %f%%\n"
            "\tTotal: %f seconds" % (
                rank,
                percent("insn_eval_time"),
                percent("future_eval_time"),
                percent("busy_wait_time"),
                total_time))


def simple_wave_entrypoint(dim=2, num_elems=256, order=4, num_steps=30,
                           log_filename="grudge.dat", dt=0.04):
    """Run a distributed strong-form wave simulation and log its timings.

    The rank for which ``MPIMeshDistributor.is_mananger_rank()`` is true
    builds a regular rectangular mesh on ``[-0.5, 0.5]**dim``, partitions it
    with pymetis into one part per MPI rank, and sends the parts. Every other
    rank receives its part. Each rank then advances ``StrongWaveOperator``
    with RK4 until ``t_end = dt * num_steps``, records timing quantities
    through logpyle, and prints its execution profile.

    :arg dim: spatial dimension. The source center has three components, so
        ``dim <= 3``.
    :arg num_elems: size parameter. Its integer ``dim``-th root (rounded down)
        is passed as the ``n`` argument of ``generate_regular_rect_mesh``
        along every axis.
    :arg order: polynomial order of the DG discretization.
    :arg num_steps: number of time steps of size *dt*.
    :arg log_filename: file the ``LogManager`` opens in ``"w"`` mode.
    :arg dt: time step size. Previously hard-coded; defaults to the old value.
    """
    cl_ctx = cl.create_some_context()
    queue = cl.CommandQueue(cl_ctx)

    from mpi4py import MPI
    comm = MPI.COMM_WORLD
    num_parts = comm.Get_size()

    n = _int_root(num_elems, dim)

    from meshmode.distributed import MPIMeshDistributor
    mesh_dist = MPIMeshDistributor(comm)

    # NOTE(review): "mananger" is kept as spelled to match the distributor's
    # method name; confirm against the installed meshmode before renaming.
    if mesh_dist.is_mananger_rank():
        from meshmode.mesh.generation import generate_regular_rect_mesh
        mesh = generate_regular_rect_mesh(a=(-0.5,)*dim,
                                          b=(0.5,)*dim,
                                          n=(n,)*dim)

        # Partition the element adjacency graph into one part per rank.
        from pymetis import part_graph
        _, p = part_graph(num_parts,
                          xadj=mesh.nodal_adjacency.neighbors_starts.tolist(),
                          adjncy=mesh.nodal_adjacency.neighbors.tolist())
        part_per_element = np.array(p)

        local_mesh = mesh_dist.send_mesh_parts(mesh, part_per_element,
                                               num_parts)
    else:
        local_mesh = mesh_dist.receive_mesh_part()

    vol_discr = DGDiscretizationWithBoundaries(cl_ctx, local_mesh, order=order,
                                               mpi_communicator=comm)

    # Source term: sin(omega*t) * exp(-|x - center|^2 / width^2).
    source_center = np.array([0.1, 0.22, 0.33])[:local_mesh.dim]
    source_width = 0.05
    source_omega = 3

    sym_x = sym.nodes(local_mesh.dim)
    sym_source_center_dist = sym_x - source_center
    sym_t = sym.ScalarVariable("t")

    from grudge.models.wave import StrongWaveOperator
    from meshmode.mesh import BTAG_ALL, BTAG_NONE
    op = StrongWaveOperator(-0.1, vol_discr.dim,
                            source_f=(
                                sym.sin(source_omega*sym_t)
                                * sym.exp(
                                    -np.dot(sym_source_center_dist,
                                            sym_source_center_dist)
                                    / source_width**2)),
                            dirichlet_tag=BTAG_NONE,
                            neumann_tag=BTAG_NONE,
                            radiation_tag=BTAG_ALL,
                            flux_type="upwind")

    from pytools.obj_array import join_fields
    fields = join_fields(vol_discr.zeros(queue),
                         [vol_discr.zeros(queue)
                          for _ in range(vol_discr.dim)])

    from logpyle import (LogManager, add_general_quantities, add_run_info,
                         IntervalTimer, EventCounter)

    # NOTE: LogManager hangs when using a file on a shared directory.
    logmgr = LogManager(log_filename, "w", comm)
    try:
        add_run_info(logmgr)
        add_general_quantities(logmgr)

        log_quantities = {
            "rank_data_swap_timer": IntervalTimer(
                "rank_data_swap_timer",
                "Time spent evaluating RankDataSwapAssign"),
            "rank_data_swap_counter": EventCounter(
                "rank_data_swap_counter",
                "Number of RankDataSwapAssign instructions evaluated"),
            "exec_timer": IntervalTimer(
                "exec_timer",
                "Total time spent executing instructions"),
            "insn_eval_timer": IntervalTimer(
                "insn_eval_timer",
                "Time spend evaluating instructions"),
            "future_eval_timer": IntervalTimer(
                "future_eval_timer",
                "Time spent evaluating futures"),
            "busy_wait_timer": IntervalTimer(
                "busy_wait_timer",
                "Time wasted doing busy wait"),
        }
        for quantity in log_quantities.values():
            logmgr.add_quantity(quantity)

        bound_op = bind(vol_discr, op.sym_operator())

        def rhs(t, w):
            # The profile dict returned by one evaluation is passed back in on
            # the next (stored as a function attribute), so the final value
            # reflects all evaluations -- presumably cumulative; confirm
            # against the bound operator's executor.
            val, rhs.profile_data = bound_op(queue,
                                             profile_data=rhs.profile_data,
                                             log_quantities=log_quantities,
                                             t=t, w=w)
            return val

        rhs.profile_data = {}

        dt_stepper = set_up_rk4("w", dt, fields, rhs)

        # Bracket each computed state with tick_after/tick_before so each
        # logged tick covers one completed step.
        logmgr.tick_before()
        for event in dt_stepper.run(t_end=dt * num_steps):
            if isinstance(event, dt_stepper.StateComputed):
                logmgr.tick_after()
                logmgr.tick_before()
        logmgr.tick_after()

        print(_format_profile_data(comm.Get_rank(), rhs.profile_data))
    finally:
        # Close the log even if the run fails part-way.
        logmgr.close()
if __name__ == "__main__":
    import sys

    # Validate explicitly instead of with assert, which is stripped when
    # Python runs with -O (a missing argument would then be an IndexError).
    if "RUN_WITHIN_MPI" not in os.environ:
        sys.exit("Must run within mpi")
    if len(sys.argv) != 5:
        sys.exit("Usage: %s %s num_elems order num_steps logfile"
                 % (sys.executable, sys.argv[0]))

    simple_wave_entrypoint(num_elems=int(sys.argv[1]),
                           order=int(sys.argv[2]),
                           num_steps=int(sys.argv[3]),
                           log_filename=sys.argv[4])