diff --git a/grudge/op.py b/grudge/op.py index c3054f306044f85b3c74c83eda262b0a945a8ab4..13c644507889ba140f8e010a44bb2e8db6c26604 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -77,15 +77,11 @@ from meshmode.dof_array import freeze, flatten, unflatten from grudge.symbolic.primitives import TracePair -# {{{ tags +# {{{ Kernel tags class HasElementwiseMatvecTag(FirstAxisIsElementsTag): pass - -class MassOperatorTag(HasElementwiseMatvecTag): - pass - # }}} @@ -545,16 +541,16 @@ def _apply_mass_operator(dcoll, dd_out, dd_in, vec): return DOFArray( actx, tuple( - actx.einsum("ej,ij,ej->ei", - ae_i, + actx.einsum("ij,ej,ej->ei", reference_mass_matrix( actx, out_element_group=out_grp, in_element_group=in_grp ), + ae_i, vec_i, - arg_names=("jac", "mass_mat", "vec"), - tagged=(MassOperatorTag(),)) + arg_names=("mass_mat", "jac", "vec"), + tagged=(HasElementwiseMatvecTag(),)) for in_grp, out_grp, ae_i, vec_i in zip( in_discr.groups, out_discr.groups, area_elements, vec) @@ -633,22 +629,22 @@ def _apply_inverse_mass_operator(dcoll, dd_out, dd_in, vec): # true_Minv ~ ref_Minv * ref_M * (1/jac_det) * ref_Minv actx.einsum("ik,km,em,mj,ej->ei", ref_mass_inv, ref_mass, jac_inv, ref_mass_inv, x, - tagged=(MassOperatorTag(),)) + tagged=(HasElementwiseMatvecTag(),)) ) return DOFArray(actx, data=tuple(data)) else: return DOFArray( actx, tuple( - actx.einsum("ej,ij,ej->ei", - iae_i, + actx.einsum("ij,ej,ej->ei", reference_inverse_mass_matrix( actx, element_group=grp ), + iae_i, vec_i, - arg_names=("jac_det_inv", "mass_inv_mat", "vec"), - tagged=(MassOperatorTag(),)) + arg_names=("mass_inv_mat", "jac_det_inv", "vec"), + tagged=(HasElementwiseMatvecTag(),)) for grp, iae_i, vec_i in zip(discr.groups, inv_area_elements, vec)