diff --git a/loopy/transform/subst.py b/loopy/transform/subst.py index 79ceff9fdf1e2c4b3b544e8ae85f8194b36ec444..a681afe06520483c83530c241e39229412e88f03 100644 --- a/loopy/transform/subst.py +++ b/loopy/transform/subst.py @@ -1,6 +1,4 @@ -from __future__ import division -from __future__ import absolute_import -import six +from __future__ import division, absolute_import __copyright__ = "Copyright (C) 2012 Andreas Kloeckner" @@ -24,6 +22,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +import six from loopy.symbolic import ( get_dependencies, SubstitutionMapper, @@ -141,6 +140,7 @@ def extract_subst(kernel, subst_name, template, parameters=()): dfmapper = CallbackMapper(gather_exprs, WalkMapper()) for insn in kernel.instructions: + dfmapper(insn.assignees) dfmapper(insn.expression) for sr in six.itervalues(kernel.substitutions): @@ -178,8 +178,7 @@ def extract_subst(kernel, subst_name, template, parameters=()): new_insns = [] for insn in kernel.instructions: - new_expr = cbmapper(insn.expression) - new_insns.append(insn.copy(expression=new_expr)) + new_insns.append(insn.with_transformed_expressions(cbmapper)) from loopy.kernel.data import SubstitutionRule new_substs = { diff --git a/test/test_loopy.py b/test/test_loopy.py index f2658ca75fe726a809ea7e30b1263e51b93c1716..d0398f216a7f85798bc5f125e353578e74765b9f 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -2678,6 +2678,30 @@ def test_preamble_with_separate_temporaries(ctx_factory): queue, data=data.flatten('C'))[1][0], data[offsets[:-1] + 1]) +def test_add_prefetch_works_in_lhs_index(): + knl = lp.make_kernel( + "{ [n,k,l,k1,l1,k2,l2]: " + "start<=n a1_tmp[k,l] = a1[a1_map[n, k],l] + a1_tmp[k1,l1] = a1_tmp[k1,l1] + 1 + a1_out[a1_map[n,k2], l2] = a1_tmp[k2,l2] + end + """, + [ + lp.GlobalArg("a1,a1_out", None, "ndofs,2"), + lp.GlobalArg("a1_map", None, "nelements,3"), + "..." + ]) + + knl = lp.add_prefetch(knl, "a1_map", "k") + + from loopy.symbolic import get_dependencies + for insn in knl.instructions: + assert "a1_map" not in get_dependencies(insn.assignees) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])