diff --git a/grudge/op.py b/grudge/op.py index 25a9d913933244d004c97d4ff24a7bf012bd3f84..92de47b9e86b1b98ab0c28656314ded6dbf6189d 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -49,6 +49,7 @@ THE SOFTWARE. from numbers import Number from pytools import memoize_on_first_arg, keyed_memoize_in +from pytools.tag import Tag import numpy as np # noqa from pytools.obj_array import obj_array_vectorize, make_obj_array @@ -58,13 +59,26 @@ from grudge import sym, bind import grudge.dof_desc as dof_desc from meshmode.mesh import BTAG_ALL, BTAG_NONE, BTAG_PARTITION # noqa -from meshmode.dof_array import freeze, flatten, unflatten +from meshmode.dof_array import freeze, flatten, unflatten, DOFArray +from meshmode.array_context import FirstAxisIsElementsTag import loopy as lp from grudge.symbolic.primitives import TracePair +# {{{ tags + +class HasElementwiseMatvecTag(FirstAxisIsElementsTag): + pass + + +class MassOperatorTag(HasElementwiseMatvecTag): + pass + +# }}} + + # def interp(dcoll, src, tgt, vec): # from warnings import warn # warn("using 'interp' is deprecated, use 'project' instead.", @@ -353,20 +367,24 @@ def reference_mass_matrix(actx, out_element_group, in_element_group): def _apply_mass_operator(dcoll, dd_out, dd_in, vec): - in_discr = dcoll.discr_from_dd(dd_in) out_discr = dcoll.discr_from_dd(dd_out) actx = vec.array_context - return [ - actx.np.einsum("ij,j,ej->ei", - reference_mass_matrix(actx, - out_element_group=out_grp, - in_element_group=in_grp), - # FIXME: These are not area elements! - in_discr.zeros(actx), - vec_i) - for in_grp, out_grp, vec_i in zip(in_discr.groups, out_discr.groups, vec)] + area_elements = in_discr.zeros(actx) # FIXME *cough* + return DOFArray(actx, + tuple( + actx.einsum("ij,ej,ej->ei", + reference_mass_matrix(actx, + out_element_group=out_grp, + in_element_group=in_grp), + ae_i, vec_i, + arg_names=("mass_mat", "jac_det", "vec"), + tagged=(MassOperatorTag(),)) + + for in_grp, out_grp, ae_i, vec_i in zip( + in_discr.groups, out_discr.groups, + area_elements, vec))) def mass_operator(dcoll, *args): diff --git a/test/test_new_world_grudge.py b/test/test_new_world_grudge.py index a1fff8ddc226fbabd13c0a97a6d3869f8cf9864b..148dafbf73d50bf0f6d0161aa2db6876ef6007ec 100644 --- a/test/test_new_world_grudge.py +++ b/test/test_new_world_grudge.py @@ -99,10 +99,10 @@ def test_mass_mat_trig(actx_factory, ambient_dim, quad_tag): #mass_op = bind(discr, dof_desc.MassOperator(dd_quad, dof_desc.DD_VOLUME)(sym_f)) - num_integral_1 = np.dot(ones_volm, - actx.to_numpy(flatten( - op.mass_operator(dcoll, dd_quad, f_quad)))) - 1/0 + mop = op.mass_operator(dcoll, dd_quad, f_quad) + num_integral_1 = np.dot( + actx.to_numpy(flatten(ones_volm)), + actx.to_numpy(flatten( mop))) err_1 = abs(num_integral_1 - true_integral) assert err_1 < 1e-9, err_1