diff --git a/loopy/buffer_writes.py b/loopy/buffer_writes.py index 7083e3c269aea7d663c1c38605ecd8b581324b6d..7616dcba59d57997786538d9a67525bdfe07cac1 100644 --- a/loopy/buffer_writes.py +++ b/loopy/buffer_writes.py @@ -26,7 +26,8 @@ THE SOFTWARE. from loopy.array_buffer import (ArrayToBufferMap, NoOpArrayToBufferMap, AccessDescriptor) -from loopy.symbolic import ExpandingIdentityMapper, SubstitutionMapper +from loopy.symbolic import (get_dependencies, ExpandingIdentityMapper, + SubstitutionMapper) from pymbolic.mapper.substitutor import make_subst_func from pymbolic import var @@ -112,8 +113,7 @@ class ArrayAccessReplacer(ExpandingIdentityMapper): def buffer_write(kernel, var_name, buffer_inames, init_expression=None, store_expression=None, within=None, default_tag="l.auto", - temporary_is_local=None, fetch_bounding_box=False, - within_inames=()): + temporary_is_local=None, fetch_bounding_box=False): """ :arg init_expression: Either *None* (indicating the prior value of the buffered array should be read) or an expression optionally involving the @@ -134,10 +134,8 @@ def buffer_write(kernel, var_name, buffer_inames, init_expression=None, store_expression = parse(store_expression) if isinstance(buffer_inames, str): - buffer_inames = buffer_inames.split(",") - - if isinstance(within_inames, str): - within_inames = within_inames.split(",") + buffer_inames = [s.strip() + for s in buffer_inames.split(",") if s.strip()] for iname in buffer_inames: if iname not in kernel.all_inames(): @@ -170,6 +168,7 @@ def buffer_write(kernel, var_name, buffer_inames, init_expression=None, # }}} var_name_gen = kernel.get_var_name_generator() + within_inames = set() access_descriptors = [] for insn in kernel.instructions: @@ -178,6 +177,9 @@ def buffer_write(kernel, var_name, buffer_inames, init_expression=None, for assignee, index in insn.assignees_and_indices(): if assignee == var_name: + within_inames.update( + (get_dependencies(index) & kernel.all_inames()) + - buffer_inames_set) access_descriptors.append( AccessDescriptor( identifier=insn.id, @@ -211,7 +213,7 @@ def buffer_write(kernel, var_name, buffer_inames, init_expression=None, # {{{ find domain to be changed from loopy.kernel.tools import DomainChanger - domch = DomainChanger(kernel, buffer_inames_set) + domch = DomainChanger(kernel, buffer_inames_set | within_inames) if domch.leaf_domain_index is not None: # If the sweep inames are at home in parent domains, then we'll add diff --git a/test/test_fortran.py b/test/test_fortran.py index c68e963a60585634437d8a80b5eb1cbb1e37e2c2..eeac0b23abe75dffc27375bfc4b66b9708c7e63c 100644 --- a/test/test_fortran.py +++ b/test/test_fortran.py @@ -273,7 +273,12 @@ def test_tagged(ctx_factory): assert sum(1 for insn in lp.find_instructions(knl, "*$input")) == 2 -def test_matmul(ctx_factory): +@pytest.mark.parametrize("buffer_inames", [ + "", + "i_inner", + "i_inner,j_inner", + ]) +def test_matmul(ctx_factory, buffer_inames): fortran_src = """ subroutine dgemm(m,n,l,a,b,c) implicit none @@ -311,11 +316,8 @@ def test_matmul(ctx_factory): knl = lp.precompute(knl, "a_acc", "k_inner,i_inner") knl = lp.precompute(knl, "b_acc", "j_inner,k_inner") - # FIXME: also test - # knl = lp.buffer_write(knl, "c", (), init_expression="0", - # store_expression="base+buffer") - knl = lp.buffer_write(knl, "c", "i_inner,j_inner", init_expression="0", - store_expression="base+buffer", within_inames="i_outer,j_outer") + knl = lp.buffer_write(knl, "c", buffer_inames=buffer_inames, + init_expression="0", store_expression="base+buffer") #ctx = ctx_factory() #lp.auto_test_vs_ref(ref_knl, ctx, knl, parameters=dict(n=5, m=7, l=10))