From 708dfdf06b45e29a74298629f901383ee6894c6a Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Sat, 27 Jun 2020 23:10:06 -0500 Subject: [PATCH] Eager wave examples: Use sane broadcasting --- examples/wave/wave-eager-mpi.py | 12 ++++++------ examples/wave/wave-eager.py | 12 ++++++------ 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/examples/wave/wave-eager-mpi.py b/examples/wave/wave-eager-mpi.py index d61c9164..4b408bb5 100644 --- a/examples/wave/wave-eager-mpi.py +++ b/examples/wave/wave-eager-mpi.py @@ -43,26 +43,26 @@ from mpi4py import MPI # {{{ wave equation bits +def scalar(arg): + return make_obj_array([arg]) + + def wave_flux(discr, c, w_tpair): u = w_tpair[0] v = w_tpair[1:] normal = thaw(u.int.array_context, discr.normal(w_tpair.dd)) - def normal_times(scalar): - # workaround for object array behavior - return make_obj_array([ni*scalar for ni in normal]) - flux_weak = flat_obj_array( np.dot(v.avg, normal), - normal_times(u.avg), + normal*scalar(u.avg), ) # upwind v_jump = np.dot(normal, v.int-v.ext) flux_weak -= flat_obj_array( 0.5*(u.int-u.ext), - 0.5*normal_times(v_jump), + 0.5*normal*scalar(v_jump), ) return discr.project(w_tpair.dd, "all_faces", c*flux_weak) diff --git a/examples/wave/wave-eager.py b/examples/wave/wave-eager.py index 2b455509..e7821431 100644 --- a/examples/wave/wave-eager.py +++ b/examples/wave/wave-eager.py @@ -41,26 +41,26 @@ from grudge.symbolic.primitives import TracePair # {{{ wave equation bits +def scalar(arg): + return make_obj_array([arg]) + + def wave_flux(discr, c, w_tpair): u = w_tpair[0] v = w_tpair[1:] normal = thaw(u.int.array_context, discr.normal(w_tpair.dd)) - def normal_times(scalar): - # workaround for object array behavior - return make_obj_array([ni*scalar for ni in normal]) - flux_weak = flat_obj_array( np.dot(v.avg, normal), - normal_times(u.avg), + normal*scalar(u.avg), ) # upwind v_jump = np.dot(normal, v.int-v.ext) flux_weak -= flat_obj_array( 0.5*(u.int-u.ext), - 0.5*normal_times(v_jump), + 0.5*normal*scalar(v_jump), ) return discr.project(w_tpair.dd, "all_faces", c*flux_weak) -- GitLab