Skip to content
Snippets Groups Projects
Commit 9dc012d7 authored by Kaushik Kulkarni's avatar Kaushik Kulkarni Committed by Andreas Klöckner
Browse files

rewrite face_mass as einsum

parent 3da41908
No related branches found
No related tags found
No related merge requests found
Pipeline #205733 passed
......@@ -50,15 +50,12 @@ THE SOFTWARE.
"""
from arraycontext import (
ArrayContext,
make_loopy_program
)
from arraycontext import ArrayContext
from meshmode.transform_metadata import FirstAxisIsElementsTag
from grudge.discretization import DiscretizationCollection
from pytools import memoize_in, keyed_memoize_in
from pytools import keyed_memoize_in
from pytools.obj_array import obj_array_vectorize, make_obj_array
from meshmode.dof_array import DOFArray
......@@ -763,52 +760,28 @@ def _apply_face_mass_operator(dcoll: DiscretizationCollection, dd, vec):
dtype = vec.entry_dtype
actx = vec.array_context
@memoize_in(actx, (_apply_face_mass_operator, "face_mass_knl"))
def prg():
t_unit = make_loopy_program(
[
"{[iel]: 0 <= iel < nelements}",
"{[f]: 0 <= f < nfaces}",
"{[idof]: 0 <= idof < nvol_nodes}",
"{[jdof]: 0 <= jdof < nface_nodes}"
],
"""
result[iel, idof] = sum(f, sum(jdof, mat[idof, f, jdof]
* jac_surf[f, iel, jdof]
* vec[f, iel, jdof]))
""",
name="face_mass"
)
import loopy as lp
from meshmode.transform_metadata import (
ConcurrentElementInameTag, ConcurrentDOFInameTag)
return lp.tag_inames(t_unit, {
"iel": ConcurrentElementInameTag(),
"idof": ConcurrentDOFInameTag()})
assert len(face_discr.groups) == len(volm_discr.groups)
surf_area_elements = area_element(actx, dcoll, dd=dd)
return DOFArray(
actx,
data=tuple(
actx.call_loopy(prg(),
mat=reference_face_mass_matrix(
actx.einsum("ifj,fej,fej->ei",
reference_face_mass_matrix(
actx,
face_element_group=afgrp,
vol_element_group=vgrp,
dtype=dtype
),
jac_surf=surf_ae_i.reshape(
dtype=dtype),
surf_ae_i.reshape(
vgrp.mesh_el_group.nfaces,
vgrp.nelements,
afgrp.nunit_dofs
),
vec=vec_i.reshape(
-1),
vec_i.reshape(
vgrp.mesh_el_group.nfaces,
vgrp.nelements,
afgrp.nunit_dofs
))["result"]
afgrp.nunit_dofs),
arg_names=("ref_face_mass_mat", "jac_surf", "vec"),
tagged=(FirstAxisIsElementsTag(),))
for vgrp, afgrp, vec_i, surf_ae_i in zip(volm_discr.groups,
face_discr.groups,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment