diff --git a/grudge/op.py b/grudge/op.py index b6d5e63596677a2c8d0e49a718dc3b863cd6f9a5..8acf3c83685526e16f9919dda6fb6034e78345e3 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -456,36 +456,25 @@ def _apply_inverse_mass_operator(dcoll, dd_out, dd_in, vec): actx = vec.array_context inv_area_elements = 1./area_element(actx, dcoll, dd=dd_in) if use_wadg: - # FIXME: This feels just wrong; can't we be smarter here? - # NOTE: Could just write a loopy program directly - return DOFArray( - actx, - tuple(actx.einsum("ik,km,em,mj,ej->ei", - reference_inverse_mass_matrix( - actx, - element_group=grp - ), - reference_mass_matrix( - actx, - out_element_group=grp, - in_element_group=grp - ), - iae_i, - reference_inverse_mass_matrix( - actx, - element_group=grp - ), - vec_i, - # Ew. Fix this `mass_inv_mat_x` business - arg_names=( - "mass_inv_mat_1", "mass_mat", - "jac_det_inv", "mass_inv_mat_2", "vec" - ), + # FIXME: Think of how to compose existing functions here... + # NOTE: Rewritten for readability/debuggability + result = discr.empty(actx, dtype=vec.entry_dtype) + grps = discr.groups + data = [] + for grp, jac_inv, x in zip(grps, inv_area_elements, vec): + ref_mass = reference_mass_matrix(actx, + out_element_group=grp, + in_element_group=grp) + ref_mass_inv = reference_inverse_mass_matrix(actx, + element_group=grp) + data.append( + # Based on https://arxiv.org/pdf/1608.03836.pdf + # true_Minv ~ ref_Minv * ref_M * (1/jac_det) * ref_Minv + actx.einsum("ik,km,em,mj,ej->ei", + ref_mass_inv, ref_mass, jac_inv, ref_mass_inv, x, tagged=(MassOperatorTag(),)) - - for grp, iae_i, vec_i in zip(discr.groups, inv_area_elements, vec) ) - ) + return DOFArray(actx, data=tuple(data)) else: return DOFArray( actx,