diff --git a/loopy/kernel/tools.py b/loopy/kernel/tools.py index 9ae3e18b93f1a855c28827926cfd1b42f5fc5f9f..5a100d7120892e6cfef80e4ee1cd1f1d6cd3bc5b 100644 --- a/loopy/kernel/tools.py +++ b/loopy/kernel/tools.py @@ -135,15 +135,15 @@ def find_all_insn_inames(kernel): insn_id_to_inames = {} insn_assignee_inames = {} + all_read_deps = {} + all_write_deps = {} + from loopy.transform.subst import expand_subst kernel = expand_subst(kernel) - from loopy.preprocess import add_default_dependencies - kernel = add_default_dependencies(kernel) - for insn in kernel.instructions: - read_deps = insn.read_dependency_names() - write_deps = insn.write_dependency_names() + all_read_deps[insn.id] = read_deps = insn.read_dependency_names() + all_write_deps[insn.id] = write_deps = insn.write_dependency_names() deps = read_deps | write_deps if insn.forced_iname_deps_is_final: @@ -168,6 +168,8 @@ def find_all_insn_inames(kernel): write_deps & kernel.all_inames() | insn.stop_iname_dep_propagation) + written_vars = kernel.get_written_variables() + # fixed point iteration until all iname dep sets have converged # Why is fixed point iteration necessary here? Consider the following @@ -194,38 +196,35 @@ def find_all_insn_inames(kernel): # of iname deps of all writers, and add those to insn's # dependencies. - implicit_inames = None - inames_old_start = inames_old = insn_id_to_inames[insn.id] + for tv_name in (all_read_deps[insn.id] & written_vars): + implicit_inames = None - for dep_id in insn.depends_on: - writer_implicit_inames = ( - insn_id_to_inames[dep_id] - - insn_assignee_inames[dep_id]) - if implicit_inames is None: - implicit_inames = writer_implicit_inames - else: - implicit_inames = (implicit_inames - & writer_implicit_inames) + for writer_id in writer_map[tv_name]: + writer_implicit_inames = ( + insn_id_to_inames[writer_id] + - insn_assignee_inames[writer_id]) + if implicit_inames is None: + implicit_inames = writer_implicit_inames + else: + implicit_inames = (implicit_inames + & writer_implicit_inames) - if implicit_inames is not None: + inames_old = insn_id_to_inames[insn.id] inames_new = (inames_old | implicit_inames) \ - insn.reduction_inames() - else: - inames_new = inames_old - - insn_id_to_inames[insn.id] = inames_new + insn_id_to_inames[insn.id] = inames_new - if inames_new != inames_old: - did_something = True - logger.debug("%s: find_all_insn_inames: %s -> %s (dep-based)" % ( - kernel.name, insn.id, ", ".join(sorted(inames_new)))) + if inames_new != inames_old: + did_something = True + logger.debug("%s: find_all_insn_inames: %s -> %s (dep-based)" % ( + kernel.name, insn.id, ", ".join(sorted(inames_new)))) # }}} # {{{ domain-based propagation inames_old = insn_id_to_inames[insn.id] - inames_new = set(inames_old) + inames_new = set(insn_id_to_inames[insn.id]) for iname in inames_old: home_domain = kernel.domains[kernel.get_home_domain_index(iname)] @@ -243,14 +242,10 @@ def find_all_insn_inames(kernel): if par in kernel.temporary_variables: for writer_id in writer_map.get(par, []): - inames_new.update( - insn_id_to_inames[writer_id] - - insn.reduction_inames()) + inames_new.update(insn_id_to_inames[writer_id]) if inames_new != inames_old: did_something = True - assert inames_new != inames_old_start - insn_id_to_inames[insn.id] = frozenset(inames_new) logger.debug("%s: find_all_insn_inames: %s -> %s (domain-based)" % ( kernel.name, insn.id, ", ".join(sorted(inames_new)))) diff --git a/test/test_linalg.py b/test/test_linalg.py index a28d885c1af4c8f595bf65274a8f794ec5d34bdd..e3cae8557a2266bb7b27adb4d6148ed2cda32063 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -628,23 +628,6 @@ def test_small_batched_matvec(ctx_factory): parameters=dict(K=K)) -def test_outer_product(ctx_factory): - logging.basicConfig(level=logging.DEBUG) - #ctx = ctx_factory() - - knl = lp.make_kernel( - "{[i,j]: 0<=i,j<n}}", - """ - z = a[i]*b[j] {id=init_z} - z = z*2 {id=mult_z,dep=init_z} - c[i,j] = z {dep=mult_z} - """) - - print(knl) - - assert knl.insn_inames("mult_z") == frozenset(["i", "j"]) - - if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])