diff --git a/loopy/kernel/tools.py b/loopy/kernel/tools.py index 5a100d7120892e6cfef80e4ee1cd1f1d6cd3bc5b..9ae3e18b93f1a855c28827926cfd1b42f5fc5f9f 100644 --- a/loopy/kernel/tools.py +++ b/loopy/kernel/tools.py @@ -135,15 +135,15 @@ def find_all_insn_inames(kernel): insn_id_to_inames = {} insn_assignee_inames = {} - all_read_deps = {} - all_write_deps = {} - from loopy.transform.subst import expand_subst kernel = expand_subst(kernel) + from loopy.preprocess import add_default_dependencies + kernel = add_default_dependencies(kernel) + for insn in kernel.instructions: - all_read_deps[insn.id] = read_deps = insn.read_dependency_names() - all_write_deps[insn.id] = write_deps = insn.write_dependency_names() + read_deps = insn.read_dependency_names() + write_deps = insn.write_dependency_names() deps = read_deps | write_deps if insn.forced_iname_deps_is_final: @@ -168,8 +168,6 @@ def find_all_insn_inames(kernel): write_deps & kernel.all_inames() | insn.stop_iname_dep_propagation) - written_vars = kernel.get_written_variables() - # fixed point iteration until all iname dep sets have converged # Why is fixed point iteration necessary here? Consider the following @@ -196,35 +194,38 @@ def find_all_insn_inames(kernel): # of iname deps of all writers, and add those to insn's # dependencies. - for tv_name in (all_read_deps[insn.id] & written_vars): - implicit_inames = None + implicit_inames = None + inames_old_start = inames_old = insn_id_to_inames[insn.id] - for writer_id in writer_map[tv_name]: - writer_implicit_inames = ( - insn_id_to_inames[writer_id] - - insn_assignee_inames[writer_id]) - if implicit_inames is None: - implicit_inames = writer_implicit_inames - else: - implicit_inames = (implicit_inames - & writer_implicit_inames) + for dep_id in insn.depends_on: + writer_implicit_inames = ( + insn_id_to_inames[dep_id] + - insn_assignee_inames[dep_id]) + if implicit_inames is None: + implicit_inames = writer_implicit_inames + else: + implicit_inames = (implicit_inames + & writer_implicit_inames) - inames_old = insn_id_to_inames[insn.id] + if implicit_inames is not None: inames_new = (inames_old | implicit_inames) \ - insn.reduction_inames() - insn_id_to_inames[insn.id] = inames_new + else: + inames_new = inames_old + + insn_id_to_inames[insn.id] = inames_new - if inames_new != inames_old: - did_something = True - logger.debug("%s: find_all_insn_inames: %s -> %s (dep-based)" % ( - kernel.name, insn.id, ", ".join(sorted(inames_new)))) + if inames_new != inames_old: + did_something = True + logger.debug("%s: find_all_insn_inames: %s -> %s (dep-based)" % ( + kernel.name, insn.id, ", ".join(sorted(inames_new)))) # }}} # {{{ domain-based propagation inames_old = insn_id_to_inames[insn.id] - inames_new = set(insn_id_to_inames[insn.id]) + inames_new = set(inames_old) for iname in inames_old: home_domain = kernel.domains[kernel.get_home_domain_index(iname)] @@ -242,10 +243,14 @@ def find_all_insn_inames(kernel): if par in kernel.temporary_variables: for writer_id in writer_map.get(par, []): - inames_new.update(insn_id_to_inames[writer_id]) + inames_new.update( + insn_id_to_inames[writer_id] + - insn.reduction_inames()) if inames_new != inames_old: did_something = True + assert inames_new != inames_old_start + insn_id_to_inames[insn.id] = frozenset(inames_new) logger.debug("%s: find_all_insn_inames: %s -> %s (domain-based)" % ( kernel.name, insn.id, ", ".join(sorted(inames_new)))) diff --git a/test/test_linalg.py b/test/test_linalg.py index e3cae8557a2266bb7b27adb4d6148ed2cda32063..a28d885c1af4c8f595bf65274a8f794ec5d34bdd 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -628,6 +628,23 @@ def test_small_batched_matvec(ctx_factory): parameters=dict(K=K)) +def test_outer_product(ctx_factory): + logging.basicConfig(level=logging.DEBUG) + #ctx = ctx_factory() + + knl = lp.make_kernel( + "{[i,j]: 0<=i,j 1: exec(sys.argv[1])