From ce273c766d98720fef435ffbdbb90ba1d3784753 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Thu, 19 Nov 2020 15:51:36 -0600 Subject: [PATCH] remove make_obj_array([...]) for scalar DOF arrays --- examples/wave/wave-eager-mpi.py | 10 +++------- examples/wave/wave-eager-var-velocity.py | 14 +++++--------- examples/wave/wave-eager.py | 10 +++------- 3 files changed, 11 insertions(+), 23 deletions(-) diff --git a/examples/wave/wave-eager-mpi.py b/examples/wave/wave-eager-mpi.py index 5971b2ab..b6ca9ae3 100644 --- a/examples/wave/wave-eager-mpi.py +++ b/examples/wave/wave-eager-mpi.py @@ -25,7 +25,7 @@ import numpy as np import numpy.linalg as la # noqa import pyopencl as cl -from pytools.obj_array import flat_obj_array, make_obj_array +from pytools.obj_array import flat_obj_array from meshmode.array_context import PyOpenCLArrayContext from meshmode.dof_array import thaw @@ -41,10 +41,6 @@ from mpi4py import MPI # {{{ wave equation bits -def scalar(arg): - return make_obj_array([arg]) - - def wave_flux(discr, c, w_tpair): u = w_tpair[0] v = w_tpair[1:] @@ -53,14 +49,14 @@ def wave_flux(discr, c, w_tpair): flux_weak = flat_obj_array( np.dot(v.avg, normal), - normal*scalar(u.avg), + normal*u.avg, ) # upwind v_jump = np.dot(normal, v.ext-v.int) flux_weak += flat_obj_array( 0.5*(u.ext-u.int), - 0.5*normal*scalar(v_jump), + 0.5*normal*v_jump, ) return discr.project(w_tpair.dd, "all_faces", c*flux_weak) diff --git a/examples/wave/wave-eager-var-velocity.py b/examples/wave/wave-eager-var-velocity.py index 1ea155cb..194a1d65 100644 --- a/examples/wave/wave-eager-var-velocity.py +++ b/examples/wave/wave-eager-var-velocity.py @@ -25,7 +25,7 @@ import numpy as np import numpy.linalg as la # noqa import pyopencl as cl -from pytools.obj_array import flat_obj_array, make_obj_array +from pytools.obj_array import flat_obj_array from meshmode.array_context import PyOpenCLArrayContext from meshmode.dof_array import thaw @@ -39,10 +39,6 @@ from grudge.symbolic.primitives import TracePair, QTAG_NONE, DOFDesc # {{{ wave equation bits -def scalar(arg): - return make_obj_array([arg]) - - def wave_flux(discr, c, w_tpair): dd = w_tpair.dd dd_quad = dd.with_qtag("vel_prod") @@ -54,13 +50,13 @@ def wave_flux(discr, c, w_tpair): flux_weak = flat_obj_array( np.dot(v.avg, normal), - normal*scalar(u.avg), + normal*u.avg, ) # upwind flux_weak += flat_obj_array( 0.5*(u.ext-u.int), - 0.5*normal*scalar(np.dot(normal, v.ext-v.int)), + 0.5*normal*np.dot(normal, v.ext-v.int), ) # FIXME this flux is only correct for continuous c @@ -68,7 +64,7 @@ def wave_flux(discr, c, w_tpair): c_quad = discr.project("vol", dd_quad, c) flux_quad = discr.project(dd, dd_quad, flux_weak) - return discr.project(dd_quad, dd_allfaces_quad, scalar(c_quad)*flux_quad) + return discr.project(dd_quad, dd_allfaces_quad, c_quad*flux_quad) def wave_operator(discr, c, w): @@ -91,7 +87,7 @@ def wave_operator(discr, c, w): return ( discr.inverse_mass( flat_obj_array( - -discr.weak_div(dd_quad, scalar(c_quad)*v_quad), + -discr.weak_div(dd_quad, c_quad*v_quad), -discr.weak_grad(dd_quad, c_quad*u_quad) ) + # noqa: W504 diff --git a/examples/wave/wave-eager.py b/examples/wave/wave-eager.py index 4ac1a646..a8c44382 100644 --- a/examples/wave/wave-eager.py +++ b/examples/wave/wave-eager.py @@ -25,7 +25,7 @@ import numpy as np import numpy.linalg as la # noqa import pyopencl as cl -from pytools.obj_array import flat_obj_array, make_obj_array +from pytools.obj_array import flat_obj_array from meshmode.array_context import PyOpenCLArrayContext from meshmode.dof_array import thaw @@ -39,10 +39,6 @@ from grudge.symbolic.primitives import TracePair # {{{ wave equation bits -def scalar(arg): - return make_obj_array([arg]) - - def wave_flux(discr, c, w_tpair): u = w_tpair[0] v = w_tpair[1:] @@ -51,13 +47,13 @@ def wave_flux(discr, c, w_tpair): flux_weak = flat_obj_array( np.dot(v.avg, normal), - normal*scalar(u.avg), + normal*u.avg, ) # upwind flux_weak += flat_obj_array( 0.5*(u.ext-u.int), - 0.5*normal*scalar(np.dot(normal, v.ext-v.int)), + 0.5*normal*np.dot(normal, v.ext-v.int), ) return discr.project(w_tpair.dd, "all_faces", c*flux_weak) -- GitLab