diff --git a/grudge/op.py b/grudge/op.py index 3ca897895418022fd23ed60b55d86078254c08c3..5176a710eb8ff52a1874f06cac97ace9ba438b33 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -269,38 +269,6 @@ def reference_stiffness_transpose_matrix( in_element_group) -def _apply_stiffness_transpose_operator( - dcoll: DiscretizationCollection, dd_out, dd_in, vec, xyz_axis): - from grudge.geometry import \ - inverse_surface_metric_derivative_mat, area_element - - in_discr = dcoll.discr_from_dd(dd_in) - out_discr = dcoll.discr_from_dd(dd_out) - - actx = vec.array_context - area_elements = area_element(actx, dcoll, dd=dd_in) - inverse_jac_mat = inverse_surface_metric_derivative_mat(actx, dcoll, dd=dd_in) - - return DOFArray( - actx, - data=tuple( - actx.einsum("dij,ej,ej,dej->ei", - reference_stiffness_transpose_matrix( - actx, - out_element_group=out_grp, - in_element_group=in_grp - ), - ae_i, - vec_i, - ijm_i[xyz_axis], - arg_names=("ref_stiffT_mat", "jac", "vec", "inv_jac_t"), - tagged=(FirstAxisIsElementsTag(),)) - - for out_grp, in_grp, vec_i, ae_i, ijm_i in zip( - out_discr.groups, in_discr.groups, vec, area_elements, - inverse_jac_mat))) - - def weak_local_grad(dcoll: DiscretizationCollection, *args, nested=False): r"""Return the element-local weak gradient of the volume function represented by *vec*. @@ -324,26 +292,52 @@ def weak_local_grad(dcoll: DiscretizationCollection, *args, nested=False): """ if len(args) == 1: vec, = args - dd = dof_desc.DOFDesc("vol", dof_desc.DISCR_TAG_BASE) + dd_in = dof_desc.DOFDesc("vol", dof_desc.DISCR_TAG_BASE) elif len(args) == 2: - dd, vec = args + dd_in, vec = args else: raise TypeError("invalid number of arguments") if isinstance(vec, np.ndarray): grad = obj_array_vectorize( - lambda el: weak_local_grad(dcoll, dd, el, nested=nested), vec) + lambda el: weak_local_grad(dcoll, dd_in, el, nested=nested), vec) if nested: return grad else: return np.stack(grad, axis=0) - return make_obj_array( - [_apply_stiffness_transpose_operator(dcoll, - dof_desc.DD_VOLUME, - dd, vec, xyz_axis) - for xyz_axis in range(dcoll.dim)] - ) + from grudge.geometry import \ + inverse_surface_metric_derivative_mat, area_element + + in_discr = dcoll.discr_from_dd(dd_in) + out_discr = dcoll.discr_from_dd(dof_desc.DD_VOLUME) + + actx = vec.array_context + area_elements = area_element(actx, dcoll, dd=dd_in) + inverse_jac_mat = inverse_surface_metric_derivative_mat(actx, dcoll, dd=dd_in) + + per_group_grads = [ + # r for rst axis + # x for xyz axis + actx.einsum("rij,ej,ej,xrej->xei", + reference_stiffness_transpose_matrix( + actx, + out_element_group=out_grp, + in_element_group=in_grp + ), + ae_i, + vec_i, + ijm_i, + arg_names=("ref_stiffT_mat", "jac", "vec", "inv_jac_t"), + tagged=(FirstAxisIsElementsTag(),)) + for out_grp, in_grp, vec_i, ae_i, ijm_i in zip( + out_discr.groups, in_discr.groups, vec, area_elements, + inverse_jac_mat)] + + return make_obj_array([ + DOFArray( + actx, data=tuple([pgg_i[xyz_axis] for pgg_i in per_group_grads])) + for xyz_axis in range(dcoll.ambient_dim)]) def weak_local_d_dx(dcoll: DiscretizationCollection, *args): @@ -374,15 +368,40 @@ def weak_local_d_dx(dcoll: DiscretizationCollection, *args): """ if len(args) == 2: xyz_axis, vec = args - dd = dof_desc.DOFDesc("vol", dof_desc.DISCR_TAG_BASE) + dd_in = dof_desc.DOFDesc("vol", dof_desc.DISCR_TAG_BASE) elif len(args) == 3: - dd, xyz_axis, vec = args + dd_in, xyz_axis, vec = args else: raise TypeError("invalid number of arguments") - return _apply_stiffness_transpose_operator(dcoll, - dof_desc.DD_VOLUME, - dd, vec, xyz_axis) + from grudge.geometry import \ + inverse_surface_metric_derivative_mat, area_element + + in_discr = dcoll.discr_from_dd(dd_in) + out_discr = dcoll.discr_from_dd(dof_desc.DD_VOLUME) + + actx = vec.array_context + area_elements = area_element(actx, dcoll, dd=dd_in) + inverse_jac_mat = inverse_surface_metric_derivative_mat(actx, dcoll, dd=dd_in) + + return DOFArray( + actx, + data=tuple( + actx.einsum("dij,ej,ej,dej->ei", + reference_stiffness_transpose_matrix( + actx, + out_element_group=out_grp, + in_element_group=in_grp + ), + ae_i, + vec_i, + ijm_i[xyz_axis], + arg_names=("ref_stiffT_mat", "jac", "vec", "inv_jac_t"), + tagged=(FirstAxisIsElementsTag(),)) + + for out_grp, in_grp, vec_i, ae_i, ijm_i in zip( + out_discr.groups, in_discr.groups, vec, area_elements, + inverse_jac_mat))) def weak_local_div(dcoll: DiscretizationCollection, *args):