From 6eb6e9d3df1b24f17b56f6adebb91072056069fe Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Thu, 22 Apr 2021 12:34:21 -0500 Subject: [PATCH] Drafty einsum-based mass operator --- grudge/op.py | 25 +++++++++++++++++++---- test/test_new_world_grudge.py | 37 ++++++++++++++++++++--------------- 2 files changed, 42 insertions(+), 20 deletions(-) diff --git a/grudge/op.py b/grudge/op.py index c1270df0..66ac1183 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -352,12 +352,24 @@ def reference_mass_matrix(actx, out_element_group, in_element_group): ) -def _apply_mass_operator(dcoll, dd, vec): - pass +def _apply_mass_operator(dcoll, dd_out, dd_in, vec): + in_discr = dcoll.discr_from_dd(dd_in) + out_discr = dcoll.discr_from_dd(dd_out) + + actx = vec.array_context + return [ + actx.np.einsum("ij,j,ej->ei", + reference_mass_matrix(actx, + out_element_group=out_grp, + in_element_group=in_grp), + # FIXME: These are not area elements! + in_discr.zeros(actx), + vec_i) + for in_grp, out_grp, vec_i in zip(in_discr.groups, out_discr.groups, vec)] -def mass_operator(dcoll, *args): +def mass_operator(dcoll, *args): if len(args) == 1: vec, = args dd = sym.DOFDesc("vol", sym.QTAG_NONE) @@ -371,7 +383,12 @@ def mass_operator(dcoll, *args): lambda el: mass_operator(dcoll, dd, el), vec ) - return _apply_mass_operator(dcoll, dd, vec) + dd_in = dd + del dd + from grudge.dof_desc import QTAG_NONE + dd_out = dd_in.with_qtag(QTAG_NONE) + + return _apply_mass_operator(dcoll, dd_out, dd_in, vec) @memoize_on_first_arg diff --git a/test/test_new_world_grudge.py b/test/test_new_world_grudge.py index 82bfb3e2..a1fff8dd 100644 --- a/test/test_new_world_grudge.py +++ b/test/test_new_world_grudge.py @@ -33,8 +33,10 @@ from meshmode.array_context import ( # noqa import meshmode.mesh.generation as mgen from pytools.obj_array import flat_obj_array, make_obj_array +import grudge.op as op from grudge import DiscretizationCollection +import grudge.dof_desc as dof_desc import pytest @@ -46,7 +48,7 @@ logger = logging.getLogger(__name__) # {{{ mass operator trig integration @pytest.mark.parametrize("ambient_dim", [1, 2, 3]) -@pytest.mark.parametrize("quad_tag", [sym.QTAG_NONE, "OVSMP"]) +@pytest.mark.parametrize("quad_tag", [dof_desc.QTAG_NONE, "OVSMP"]) def test_mass_mat_trig(actx_factory, ambient_dim, quad_tag): """Check the integral of some trig functions on an interval using the mass matrix. @@ -63,9 +65,9 @@ def test_mass_mat_trig(actx_factory, ambient_dim, quad_tag): from meshmode.discretization.poly_element import \ QuadratureSimplexGroupFactory - dd_quad = sym.DOFDesc(sym.DTAG_VOLUME_ALL, quad_tag) + dd_quad = dof_desc.DOFDesc(dof_desc.DTAG_VOLUME_ALL, quad_tag) - if quad_tag is sym.QTAG_NONE: + if quad_tag is dof_desc.QTAG_NONE: quad_tag_to_group_factory = {} else: quad_tag_to_group_factory = { @@ -85,22 +87,22 @@ def test_mass_mat_trig(actx_factory, ambient_dim, quad_tag): def f(x): return actx.np.sin(x[0])**2 - def ones(x): - return actx.np.ones(x.shape) - - volm_disc = dcoll.disc_from_dd(sym.DD_VOLUME) + volm_disc = dcoll.discr_from_dd(dof_desc.DD_VOLUME) x_volm = thaw(actx, volm_disc.nodes()) - f_volm = f(x_vol) - ones_volm = ones(x_vol) + f_volm = f(x_volm) + ones_volm = volm_disc.zeros(actx) + 1 - quad_disc = dcoll.disc_from_dd(dd_quad) + quad_disc = dcoll.discr_from_dd(dd_quad) x_quad = thaw(actx, quad_disc.nodes()) f_quad = f(x_quad) - ones_quad = ones(x_quad) + ones_quad = quad_disc.zeros(actx) + 1 - mass_op = bind(discr, sym.MassOperator(dd_quad, sym.DD_VOLUME)(sym_f)) + #mass_op = bind(discr, dof_desc.MassOperator(dd_quad, dof_desc.DD_VOLUME)(sym_f)) - num_integral_1 = np.dot(ones_volm, actx.to_numpy(flatten(mass_op(f=f_quad)))) + num_integral_1 = np.dot(ones_volm, + actx.to_numpy(flatten( + op.mass_operator(dcoll, dd_quad, f_quad)))) + 1/0 err_1 = abs(num_integral_1 - true_integral) assert err_1 < 1e-9, err_1 @@ -108,12 +110,15 @@ def test_mass_mat_trig(actx_factory, ambient_dim, quad_tag): err_2 = abs(num_integral_2 - true_integral) assert err_2 < 1.0e-9, err_2 - if quad_tag is sym.QTAG_NONE: + if quad_tag is dof_desc.QTAG_NONE: # NOTE: `integral` always makes a square mass matrix and # `QuadratureSimplexGroupFactory` does not have a `mass_matrix` method. num_integral_3 = bind(discr, - sym.integral(sym_f, dd=dd_quad))(f=f_quad) + dof_desc.integral(sym_f, dd=dd_quad))(f=f_quad) err_3 = abs(num_integral_3 - true_integral) assert err_3 < 5.0e-10, err_3 -# }}} \ No newline at end of file +# }}} + + +# vim: foldmethod=marker -- GitLab