Skip to content
Snippets Groups Projects
Commit 708dfdf0 authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Eager wave examples: Use sane broadcasting

parent c46b59c2
Branches
Tags
2 merge requests!76Overintegration for eager examples,!71Array context
...@@ -43,26 +43,26 @@ from mpi4py import MPI ...@@ -43,26 +43,26 @@ from mpi4py import MPI
# {{{ wave equation bits # {{{ wave equation bits
def scalar(arg):
return make_obj_array([arg])
def wave_flux(discr, c, w_tpair): def wave_flux(discr, c, w_tpair):
u = w_tpair[0] u = w_tpair[0]
v = w_tpair[1:] v = w_tpair[1:]
normal = thaw(u.int.array_context, discr.normal(w_tpair.dd)) normal = thaw(u.int.array_context, discr.normal(w_tpair.dd))
def normal_times(scalar):
# workaround for object array behavior
return make_obj_array([ni*scalar for ni in normal])
flux_weak = flat_obj_array( flux_weak = flat_obj_array(
np.dot(v.avg, normal), np.dot(v.avg, normal),
normal_times(u.avg), normal*scalar(u.avg),
) )
# upwind # upwind
v_jump = np.dot(normal, v.int-v.ext) v_jump = np.dot(normal, v.int-v.ext)
flux_weak -= flat_obj_array( flux_weak -= flat_obj_array(
0.5*(u.int-u.ext), 0.5*(u.int-u.ext),
0.5*normal_times(v_jump), 0.5*normal*scalar(v_jump),
) )
return discr.project(w_tpair.dd, "all_faces", c*flux_weak) return discr.project(w_tpair.dd, "all_faces", c*flux_weak)
......
...@@ -41,26 +41,26 @@ from grudge.symbolic.primitives import TracePair ...@@ -41,26 +41,26 @@ from grudge.symbolic.primitives import TracePair
# {{{ wave equation bits # {{{ wave equation bits
def scalar(arg):
return make_obj_array([arg])
def wave_flux(discr, c, w_tpair): def wave_flux(discr, c, w_tpair):
u = w_tpair[0] u = w_tpair[0]
v = w_tpair[1:] v = w_tpair[1:]
normal = thaw(u.int.array_context, discr.normal(w_tpair.dd)) normal = thaw(u.int.array_context, discr.normal(w_tpair.dd))
def normal_times(scalar):
# workaround for object array behavior
return make_obj_array([ni*scalar for ni in normal])
flux_weak = flat_obj_array( flux_weak = flat_obj_array(
np.dot(v.avg, normal), np.dot(v.avg, normal),
normal_times(u.avg), normal*scalar(u.avg),
) )
# upwind # upwind
v_jump = np.dot(normal, v.int-v.ext) v_jump = np.dot(normal, v.int-v.ext)
flux_weak -= flat_obj_array( flux_weak -= flat_obj_array(
0.5*(u.int-u.ext), 0.5*(u.int-u.ext),
0.5*normal_times(v_jump), 0.5*normal*scalar(v_jump),
) )
return discr.project(w_tpair.dd, "all_faces", c*flux_weak) return discr.project(w_tpair.dd, "all_faces", c*flux_weak)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment