From 6b9f8297000b23db2b1832cb0f74a7a131e2b30e Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Wed, 22 Sep 2021 11:25:35 -0500 Subject: [PATCH] avoid call_loopy by taking a combination of sum/max/min and broadcast_to --- grudge/reductions.py | 65 ++++++++++++++++++++++++++------------------ 1 file changed, 38 insertions(+), 27 deletions(-) diff --git a/grudge/reductions.py b/grudge/reductions.py index ec301576..08a87335 100644 --- a/grudge/reductions.py +++ b/grudge/reductions.py @@ -295,35 +295,46 @@ def _apply_elementwise_reduction( actx = vec.array_context - @memoize_in(actx, (_apply_elementwise_reduction, - "elementwise_%s_prg" % op_name)) - def elementwise_prg(): - # FIXME: This computes the reduction value redundantly for each - # output DOF. - t_unit = make_loopy_program( - [ - "{[iel]: 0 <= iel < nelements}", - "{[idof, jdof]: 0 <= idof, jdof < ndofs}" - ], - """ - result[iel, idof] = %s(jdof, operand[iel, jdof]) - """ % op_name, - name="grudge_elementwise_%s_knl" % op_name + if actx.supports_nonscalar_broadcasting: + return DOFArray( + actx, + data=tuple( + actx.np.broadcast_to((getattr(actx.np, op_name)(vec_i, axis=1) + .reshape(-1, 1)), + vec_i.shape) + for vec_i in vec + ) ) - import loopy as lp - from meshmode.transform_metadata import ( - ConcurrentElementInameTag, ConcurrentDOFInameTag) - return lp.tag_inames(t_unit, { - "iel": ConcurrentElementInameTag(), - "idof": ConcurrentDOFInameTag()}) - - return DOFArray( - actx, - data=tuple( - actx.call_loopy(elementwise_prg(), operand=vec_i)["result"] - for vec_i in vec + else: + @memoize_in(actx, (_apply_elementwise_reduction, + "elementwise_%s_prg" % op_name)) + def elementwise_prg(): + # FIXME: This computes the reduction value redundantly for each + # output DOF. + t_unit = make_loopy_program( + [ + "{[iel]: 0 <= iel < nelements}", + "{[idof, jdof]: 0 <= idof, jdof < ndofs}" + ], + """ + result[iel, idof] = %s(jdof, operand[iel, jdof]) + """ % op_name, + name="grudge_elementwise_%s_knl" % op_name + ) + import loopy as lp + from meshmode.transform_metadata import ( + ConcurrentElementInameTag, ConcurrentDOFInameTag) + return lp.tag_inames(t_unit, { + "iel": ConcurrentElementInameTag(), + "idof": ConcurrentDOFInameTag()}) + + return DOFArray( + actx, + data=tuple( + actx.call_loopy(elementwise_prg(), operand=vec_i)["result"] + for vec_i in vec + ) ) - ) def elementwise_sum(dcoll: DiscretizationCollection, *args) -> DOFArray: -- GitLab