diff --git a/grudge/op.py b/grudge/op.py index 4aac98815a89eaf63dc138d2279d8f0017e4b575..ed2a29b3fa3611388000dbe1a37fd4cccfc5015d 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -92,11 +92,78 @@ from grudge.trace_pair import ( # noqa: F401 ) +# {{{ common derivative "kernels" + +def _single_axis_derivative_kernel( + actx, out_discr, in_discr, get_diff_mat, inv_jac_mat, xyz_axis, vec, + *, metric_in_matvec): + # This gets used from both the strong and the weak derivative. These differ + # in three ways: + # - which differentiation matrix gets used, + # - whether inv_jac_mat is pre-multiplied by a factor that includes the + # area element, and + # - whether the chain rule terms ("inv_jac_mat") sit outside (strong) + # or inside (weak) the matrix-vector product that carries out the + # derivative, cf. "metric_in_matvec". + return DOFArray( + actx, + data=tuple( + # r for rst axis + actx.einsum("rej,rij,ej->ei" if metric_in_matvec else "rei,rij,ej->ei", + ijm_i[xyz_axis], + get_diff_mat( + actx, + out_element_group=out_grp, + in_element_group=in_grp + ), + vec_i, + arg_names=("inv_jac_t", "ref_stiffT_mat", "vec", ), + tagged=(FirstAxisIsElementsTag(),)) + + for out_grp, in_grp, vec_i, ijm_i in zip( + out_discr.groups, in_discr.groups, vec, + inv_jac_mat))) + + +def _gradient_kernel(actx, out_discr, in_discr, get_diff_mat, inv_jac_mat, vec, + *, metric_in_matvec): + # See _single_axis_derivative_kernel for comments on the usage scenarios + # (both strong and weak derivative) and their differences. + per_group_grads = [ + # r for rst axis + # x for xyz axis + actx.einsum("xrej,rij,ej->xei" if metric_in_matvec else "xrei,rij,ej->xei", + ijm_i, + get_diff_mat( + actx, + out_element_group=out_grp, + in_element_group=in_grp + ), + vec_i, + arg_names=("inv_jac_t", "ref_stiffT_mat", "vec"), + tagged=(FirstAxisIsElementsTag(),)) + for out_grp, in_grp, vec_i, ijm_i in zip( + out_discr.groups, in_discr.groups, vec, + inv_jac_mat)] + + return make_obj_array([ + DOFArray( + actx, data=tuple([pgg_i[xyz_axis] for pgg_i in per_group_grads])) + for xyz_axis in range(out_discr.ambient_dim)]) + +# }}} + + # {{{ Derivative operators -def reference_derivative_matrices(actx: ArrayContext, element_group): +def _reference_derivative_matrices(actx: ArrayContext, + out_element_group, in_element_group): + # We're accepting in_element_group for interface consistency with + # _reference_stiffness_transpose_matrix. + assert out_element_group is in_element_group + @keyed_memoize_in( - actx, reference_derivative_matrices, + actx, _reference_derivative_matrices, lambda grp: grp.discretization_key()) def get_ref_derivative_mats(grp): from meshmode.discretization.poly_element import diff_matrices @@ -107,29 +174,7 @@ def reference_derivative_matrices(actx: ArrayContext, element_group): ) ) ) - return get_ref_derivative_mats(element_group) - - -def _compute_local_gradient(dcoll: DiscretizationCollection, vec, xyz_axis): - from grudge.geometry import inverse_surface_metric_derivative_mat - - discr = dcoll.discr_from_dd(dof_desc.DD_VOLUME) - actx = vec.array_context - - inverse_jac_mat = inverse_surface_metric_derivative_mat(actx, dcoll) - return DOFArray( - actx, - data=tuple( - actx.einsum("dei,dij,ej->ei", - ijm_i[xyz_axis], - reference_derivative_matrices(actx, grp), - vec_i, - arg_names=("inv_jac_t", "ref_diff_mat", "vec"), - tagged=(FirstAxisIsElementsTag(),)) - - for grp, vec_i, ijm_i in zip(discr.groups, vec, inverse_jac_mat) - ) - ) + return get_ref_derivative_mats(out_element_group) def local_grad( @@ -157,8 +202,15 @@ def local_grad( else: return np.stack(grad, axis=0) - return make_obj_array([_compute_local_gradient(dcoll, vec, xyz_axis) - for xyz_axis in range(dcoll.dim)]) + from grudge.geometry import inverse_surface_metric_derivative_mat + + discr = dcoll.discr_from_dd(dof_desc.DD_VOLUME) + actx = vec.array_context + + inverse_jac_mat = inverse_surface_metric_derivative_mat(actx, dcoll) + return _gradient_kernel(actx, discr, discr, + _reference_derivative_matrices, inverse_jac_mat, vec, + metric_in_matvec=False) def local_d_dx(dcoll: DiscretizationCollection, xyz_axis, vec): @@ -174,7 +226,16 @@ def local_d_dx(dcoll: DiscretizationCollection, xyz_axis, vec): :arg vec: a :class:`~meshmode.dof_array.DOFArray`. :returns: a :class:`~meshmode.dof_array.DOFArray`\ s. """ - return _compute_local_gradient(dcoll, vec, xyz_axis) + discr = dcoll.discr_from_dd(dof_desc.DD_VOLUME) + actx = vec.array_context + + from grudge.geometry import inverse_surface_metric_derivative_mat + inverse_jac_mat = inverse_surface_metric_derivative_mat(actx, dcoll) + + return _single_axis_derivative_kernel( + actx, discr, discr, + _reference_derivative_matrices, inverse_jac_mat, xyz_axis, vec, + metric_in_matvec=False) def _div_helper(dcoll: DiscretizationCollection, diff_func, vecs): @@ -224,10 +285,10 @@ def local_div(dcoll: DiscretizationCollection, vecs): # {{{ Weak derivative operators -def reference_stiffness_transpose_matrix( +def _reference_stiffness_transpose_matrix( actx: ArrayContext, out_element_group, in_element_group): @keyed_memoize_in( - actx, reference_stiffness_transpose_matrix, + actx, _reference_stiffness_transpose_matrix, lambda out_grp, in_grp: (out_grp.discretization_key(), in_grp.discretization_key())) def get_ref_stiffness_transpose_mat(out_grp, in_grp): @@ -315,27 +376,9 @@ def weak_local_grad(dcoll: DiscretizationCollection, *args, nested=False): inverse_jac_mat = inverse_surface_metric_derivative_mat(actx, dcoll, dd=dd_in, times_area_element=True) - per_group_grads = [ - # r for rst axis - # x for xyz axis - actx.einsum("rij,ej,xrej->xei", - reference_stiffness_transpose_matrix( - actx, - out_element_group=out_grp, - in_element_group=in_grp - ), - vec_i, - ijm_i, - arg_names=("ref_stiffT_mat", "vec", "inv_jac_t"), - tagged=(FirstAxisIsElementsTag(),)) - for out_grp, in_grp, vec_i, ijm_i in zip( - out_discr.groups, in_discr.groups, vec, - inverse_jac_mat)] - - return make_obj_array([ - DOFArray( - actx, data=tuple([pgg_i[xyz_axis] for pgg_i in per_group_grads])) - for xyz_axis in range(dcoll.ambient_dim)]) + return _gradient_kernel(actx, out_discr, in_discr, + _reference_stiffness_transpose_matrix, inverse_jac_mat, vec, + metric_in_matvec=True) def weak_local_d_dx(dcoll: DiscretizationCollection, *args): @@ -381,23 +424,10 @@ def weak_local_d_dx(dcoll: DiscretizationCollection, *args): inverse_jac_mat = inverse_surface_metric_derivative_mat(actx, dcoll, dd=dd_in, times_area_element=True) - return DOFArray( - actx, - data=tuple( - actx.einsum("dij,ej,dej->ei", - reference_stiffness_transpose_matrix( - actx, - out_element_group=out_grp, - in_element_group=in_grp - ), - vec_i, - ijm_i[xyz_axis], - arg_names=("ref_stiffT_mat", "vec", "inv_jac_t"), - tagged=(FirstAxisIsElementsTag(),)) - - for out_grp, in_grp, vec_i, ijm_i in zip( - out_discr.groups, in_discr.groups, vec, - inverse_jac_mat))) + return _single_axis_derivative_kernel( + actx, out_discr, in_discr, _reference_stiffness_transpose_matrix, + inverse_jac_mat, xyz_axis, vec, + metric_in_matvec=True) def weak_local_div(dcoll: DiscretizationCollection, *args):