From 808eccf46985a3180c5059ceea20e89d7c225bda Mon Sep 17 00:00:00 2001
From: Matt Wala <wala1@illinois.edu>
Date: Fri, 7 Apr 2017 13:14:46 -0500
Subject: [PATCH] Implement loopy.add_nosync() as a transformation.

---
 loopy/__init__.py              |  4 +-
 loopy/transform/instruction.py | 73 ++++++++++++++++++++++++++++++++++
 test/test_transform.py         | 36 +++++++++++++++++
 3 files changed, 112 insertions(+), 1 deletion(-)

diff --git a/loopy/__init__.py b/loopy/__init__.py
index 6cbb3362e..53dd9c8ee 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -75,7 +75,8 @@ from loopy.transform.instruction import (
         set_instruction_priority, add_dependency,
         remove_instructions,
         replace_instruction_ids,
-        tag_instructions)
+        tag_instructions,
+        add_nosync)
 
 from loopy.transform.data import (
         add_prefetch, change_arg_to_image,
@@ -189,6 +190,7 @@ __all__ = [
         "remove_instructions",
         "replace_instruction_ids",
         "tag_instructions",
+        "add_nosync",
 
         "extract_subst", "expand_subst", "assignment_to_subst",
         "find_rules_matching", "find_one_rule_matching",
diff --git a/loopy/transform/instruction.py b/loopy/transform/instruction.py
index 6d7a676b0..410274f90 100644
--- a/loopy/transform/instruction.py
+++ b/loopy/transform/instruction.py
@@ -217,4 +217,77 @@ def tag_instructions(kernel, new_tag, within=None):
 # }}}
 
 
+# {{{ add nosync
+
+def add_nosync(kernel, scope, source, sink, bidirectional=False, force=False):
+    """Add a *no_sync_with* directive between *source* and *sink*.
+    *no_sync_with* is only added if a (syntactic) dependency edge
+    is present or if the instruction pair is in a conflicting group
+    (this does not check for memory dependencies).
+
+    :arg kernel:
+    :arg source: Either a single instruction id, or any instruction id
+        match understood by :func:`loopy.match.parse_match`.
+    :arg sink: Either a single instruction id, or any instruction id
+        match understood by :func:`loopy.match.parse_match`.
+    :arg scope: A string which is a valid *no_sync_with* scope.
+    :arg bidirectional: A :class:`bool`. If *True*, add a *no_sync_with*
+        to both the source and sink instructions, otherwise the directive
+        is only added to the sink instructions.
+    :arg force: A :class:`bool`. If *True*, will add a *no_sync_with*
+        even without the presence of a syntactic dependency edge/
+        conflicting instruction group.
+
+    :return: The updated kernel
+    """
+
+    if isinstance(source, str) and source in kernel.id_to_insn:
+        sources = frozenset([source])
+    else:
+        sources = frozenset(
+                source.id for source in find_instructions(kernel, source))
+
+    if isinstance(sink, str) and sink in kernel.id_to_insn:
+        sinks = frozenset([sink])
+    else:
+        sinks = frozenset(
+                sink.id for sink in find_instructions(kernel, sink))
+
+    def insns_in_conflicting_groups(insn1_id, insn2_id):
+        insn1 = kernel.id_to_insn[insn1_id]
+        insn2 = kernel.id_to_insn[insn2_id]
+        return (
+                bool(insn1.groups & insn2.conflicts_with_groups)
+                or
+                bool(insn2.groups & insn1.conflicts_with_groups))
+
+    from collections import defaultdict
+    nosync_to_add = defaultdict(set)
+
+    for sink in sinks:
+        for source in sources:
+
+            needs_nosync = force or (
+                    source in kernel.recursive_insn_dep_map()[sink]
+                    or insns_in_conflicting_groups(source, sink))
+
+            if not needs_nosync:
+                continue
+
+            nosync_to_add[sink].add((source, scope))
+            if bidirectional:
+                nosync_to_add[source].add((sink, scope))
+
+    new_instructions = list(kernel.instructions)
+
+    for i, insn in enumerate(new_instructions):
+        if insn.id in nosync_to_add:
+            new_instructions[i] = insn.copy(no_sync_with=insn.no_sync_with
+                    | frozenset(nosync_to_add[insn.id]))
+
+    return kernel.copy(instructions=new_instructions)
+
+# }}}
+
+
 # vim: foldmethod=marker
diff --git a/test/test_transform.py b/test/test_transform.py
index ac5a26f6a..b5fcdf04c 100644
--- a/test/test_transform.py
+++ b/test/test_transform.py
@@ -402,6 +402,42 @@ def test_precompute_with_preexisting_inames_fail():
                 precompute_inames="ii,jj")
 
 
+def test_add_nosync():
+    orig_knl = lp.make_kernel("{[i]: 0<=i<10}",
+        """
+        <>tmp[i] = 10 {id=insn1}
+        <>tmp2[i] = 10 {id=insn2}
+
+        <>tmp3[2*i] = 0 {id=insn3}
+        <>tmp4 = 1 + tmp3[2*i] {id=insn4}
+
+        <>tmp5[i] = 0 {id=insn5,groups=g1}
+        tmp5[i] = 1 {id=insn6,conflicts=g1}
+        """)
+
+    orig_knl = lp.set_temporary_scope(orig_knl, "tmp3", "local")
+    orig_knl = lp.set_temporary_scope(orig_knl, "tmp5", "local")
+
+    # No dependency present - don't add nosync
+    knl = lp.add_nosync(orig_knl, "any", "writes:tmp", "writes:tmp2")
+    assert frozenset() == knl.id_to_insn["insn2"].no_sync_with
+
+    # Dependency present
+    knl = lp.add_nosync(orig_knl, "local", "writes:tmp3", "reads:tmp3")
+    assert frozenset() == knl.id_to_insn["insn3"].no_sync_with
+    assert frozenset([("insn3", "local")]) == knl.id_to_insn["insn4"].no_sync_with
+
+    # Bidirectional
+    knl = lp.add_nosync(
+            orig_knl, "local", "writes:tmp3", "reads:tmp3", bidirectional=True)
+    assert frozenset([("insn4", "local")]) == knl.id_to_insn["insn3"].no_sync_with
+    assert frozenset([("insn3", "local")]) == knl.id_to_insn["insn4"].no_sync_with
+
+    # Groups
+    knl = lp.add_nosync(orig_knl, "local", "insn5", "insn6")
+    assert frozenset([("insn5", "local")]) == knl.id_to_insn["insn6"].no_sync_with
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])
-- 
GitLab