diff --git a/examples/wave/wave-eager-mpi.py b/examples/wave/wave-eager-mpi.py index d61c91641ee3c01d49a8707c82e10aa2acf0a4a4..4b408bb599c5002a32325f5befd9ab66ab02cee9 100644 --- a/examples/wave/wave-eager-mpi.py +++ b/examples/wave/wave-eager-mpi.py @@ -43,26 +43,26 @@ from mpi4py import MPI # {{{ wave equation bits +def scalar(arg): + return make_obj_array([arg]) + + def wave_flux(discr, c, w_tpair): u = w_tpair[0] v = w_tpair[1:] normal = thaw(u.int.array_context, discr.normal(w_tpair.dd)) - def normal_times(scalar): - # workaround for object array behavior - return make_obj_array([ni*scalar for ni in normal]) - flux_weak = flat_obj_array( np.dot(v.avg, normal), - normal_times(u.avg), + normal*scalar(u.avg), ) # upwind v_jump = np.dot(normal, v.int-v.ext) flux_weak -= flat_obj_array( 0.5*(u.int-u.ext), - 0.5*normal_times(v_jump), + 0.5*normal*scalar(v_jump), ) return discr.project(w_tpair.dd, "all_faces", c*flux_weak) diff --git a/examples/wave/wave-eager.py b/examples/wave/wave-eager.py index 2b455509fa66f9d13e50012b60b31226661b3343..e78214313bfb8ea0ab9e547bbc37d6ebc0aaaa5f 100644 --- a/examples/wave/wave-eager.py +++ b/examples/wave/wave-eager.py @@ -41,26 +41,26 @@ from grudge.symbolic.primitives import TracePair # {{{ wave equation bits +def scalar(arg): + return make_obj_array([arg]) + + def wave_flux(discr, c, w_tpair): u = w_tpair[0] v = w_tpair[1:] normal = thaw(u.int.array_context, discr.normal(w_tpair.dd)) - def normal_times(scalar): - # workaround for object array behavior - return make_obj_array([ni*scalar for ni in normal]) - flux_weak = flat_obj_array( np.dot(v.avg, normal), - normal_times(u.avg), + normal*scalar(u.avg), ) # upwind v_jump = np.dot(normal, v.int-v.ext) flux_weak -= flat_obj_array( 0.5*(u.int-u.ext), - 0.5*normal_times(v_jump), + 0.5*normal*scalar(v_jump), ) return discr.project(w_tpair.dd, "all_faces", c*flux_weak)