From f2c6e250b903ed9e8051b6520bded749ffcbb5aa Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Tue, 11 Aug 2015 23:55:34 -0500
Subject: [PATCH] Implement aliasing of temporaries, fix scheduler for conflict
 groups

---
 doc/reference.rst          |   2 +
 loopy/__init__.py          |  69 ++++++++++++++++++++++++
 loopy/kernel/__init__.py   |  14 ++++-
 loopy/kernel/array.py      |  14 +++--
 loopy/kernel/data.py       |  43 +++++++--------
 loopy/schedule.py          | 105 +++++++++++++++++++++++--------------
 loopy/target/c/__init__.py |  84 ++++++++++++++++++++++++++---
 loopy/version.py           |   2 +-
 test/test_loopy.py         |  33 ++++++++++++
 9 files changed, 294 insertions(+), 72 deletions(-)

diff --git a/doc/reference.rst b/doc/reference.rst
index 6a42ed944..556437ea4 100644
--- a/doc/reference.rst
+++ b/doc/reference.rst
@@ -421,6 +421,8 @@ Caching, Precomputation and Prefetching
 
 .. autofunction:: buffer_array
 
+.. autofunction:: alias_temporaries
+
 Influencing data access
 ^^^^^^^^^^^^^^^^^^^^^^^
 
diff --git a/loopy/__init__.py b/loopy/__init__.py
index a161e5478..3a4674749 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -1845,4 +1845,73 @@ def tag_instructions(kernel, new_tag, within=None):
 
 # }}}
 
+
+# {{{ alias_temporaries
+
+def alias_temporaries(knl, names, base_name_prefix=None):
+    """Set all temporaries given by *names* to be backed by a single piece of
+    storage. Also introduce ordering structures ("groups") that keep the
+    uses of one temporary from interfering with those of another.
+
+    :arg names: an iterable of names of the temporaries to be aliased
+    :arg base_name_prefix: an identifier to be used for the common storage
+        area, defaults to ``"temp_storage"``
+    :returns: a copy of *knl* with updated instructions and temporaries
+    :raises LoopyError: if an instruction refers to more than one of *names*,
+        or if one of the temporaries already has a base storage
+    """
+    names = list(names)  # tolerate sets, tuples and one-shot iterables
+    names_set = frozenset(names)
+
+    gng = knl.get_group_name_generator()
+    name_to_group = dict((name, gng("tmpgrp_"+name)) for name in names)
+    all_group_names = frozenset(six.itervalues(name_to_group))
+
+    if base_name_prefix is None:
+        base_name_prefix = "temp_storage"
+    base_name = knl.get_var_name_generator()(base_name_prefix)
+
+    new_insns = []
+    for insn in knl.instructions:
+        temp_deps = insn.dependency_names() & names_set
+
+        if not temp_deps:
+            new_insns.append(insn)
+            continue
+
+        if len(temp_deps) > 1:
+            raise LoopyError("Instruction {insn} refers to multiple of the "
+                    "temporaries being aliased, namely '{temps}'. Cannot alias."
+                    .format(
+                        insn=insn.id,
+                        temps=", ".join(sorted(temp_deps))))
+
+        temp_name, = temp_deps
+        own_group = frozenset([name_to_group[temp_name]])
+
+        new_insns.append(
+                insn.copy(
+                    groups=insn.groups | own_group,
+                    conflicts_with_groups=(
+                        insn.conflicts_with_groups
+                        | (all_group_names - own_group))))
+
+    new_temporary_variables = {}
+    for tv in six.itervalues(knl.temporary_variables):
+        if tv.name in names_set:
+            if tv.base_storage is not None:
+                raise LoopyError("temporary variable '{tv}' already has "
+                        "a defined storage array -- cannot alias"
+                        .format(tv=tv.name))
+
+            tv = tv.copy(base_storage=base_name)
+
+        new_temporary_variables[tv.name] = tv
+
+    return knl.copy(
+            instructions=new_insns,
+            temporary_variables=new_temporary_variables)
+
+# }}}
+
 # vim: foldmethod=marker
diff --git a/loopy/kernel/__init__.py b/loopy/kernel/__init__.py
index 0884a3fca..4e31db993 100644
--- a/loopy/kernel/__init__.py
+++ b/loopy/kernel/__init__.py
@@ -335,6 +335,9 @@ class LoopKernel(RecordWithoutPickling):
     def all_variable_names(self):
         return (
                 set(six.iterkeys(self.temporary_variables))
+                | set(tv.base_storage
+                    for tv in six.itervalues(self.temporary_variables)
+                    if tv.base_storage is not None)
                 | set(six.iterkeys(self.substitutions))
                 | set(arg.name for arg in self.args)
                 | set(self.all_inames()))
@@ -362,7 +365,7 @@ class LoopKernel(RecordWithoutPickling):
         return frozenset(result)
 
     def get_group_name_generator(self):
-        return _UniqueVarNameGenerator(self.all_group_names())
+        return _UniqueVarNameGenerator(set(self.all_group_names()))
 
     def get_var_descriptor(self, name):
         try:
@@ -761,6 +764,15 @@ class LoopKernel(RecordWithoutPickling):
                 for insn in self.instructions
                 for var_name, _ in insn.assignees_and_indices())
 
+    @memoize_method
+    def get_temporary_to_base_storage_map(self):
+        """Return a :class:`dict` mapping the name of each temporary that
+        has a base storage to the name of that base storage."""
+        return dict(
+                (tv.name, tv.base_storage)
+                for tv in six.itervalues(self.temporary_variables)
+                if tv.base_storage is not None)
+
     # }}}
 
     # {{{ argument wrangling
diff --git a/loopy/kernel/array.py b/loopy/kernel/array.py
index 2923c945c..98ba0dea7 100644
--- a/loopy/kernel/array.py
+++ b/loopy/kernel/array.py
@@ -874,11 +874,15 @@ class ArrayBase(Record):
 
         return 1
 
-    def decl_info(self, target, is_written, index_dtype):
+    def decl_info(self, target, is_written, index_dtype, shape_override=None):
         """Return a list of :class:`loopy.codegen.ImplementedDataInfo`
         instances corresponding to the array.
         """
 
+        array_shape = self.shape
+        if shape_override is not None:
+            array_shape = shape_override
+
         from loopy.codegen import ImplementedDataInfo
         from loopy.kernel.data import ValueArg
 
@@ -978,10 +982,10 @@ class ArrayBase(Record):
             dim_tag = self.dim_tags[user_axis]
 
             if isinstance(dim_tag, FixedStrideArrayDimTag):
-                if self.shape is None:
+                if array_shape is None:
                     new_shape_axis = None
                 else:
-                    new_shape_axis = self.shape[user_axis]
+                    new_shape_axis = array_shape[user_axis]
 
                 import loopy as lp
                 if dim_tag.stride is lp.auto:
@@ -1004,7 +1008,7 @@ class ArrayBase(Record):
                     yield res
 
             elif isinstance(dim_tag, SeparateArrayArrayDimTag):
-                shape_i = self.shape[user_axis]
+                shape_i = array_shape[user_axis]
                 if not is_integer(shape_i):
                     raise LoopyError("shape of '%s' has non-constant "
                             "integer axis %d (0-based)" % (
@@ -1018,7 +1022,7 @@ class ArrayBase(Record):
                         yield res
 
             elif isinstance(dim_tag, VectorArrayDimTag):
-                shape_i = self.shape[user_axis]
+                shape_i = array_shape[user_axis]
                 if not is_integer(shape_i):
                     raise LoopyError("shape of '%s' has non-constant "
                             "integer axis %d (0-based)" % (
diff --git a/loopy/kernel/data.py b/loopy/kernel/data.py
index 2d9faa4cf..c5cecfde2 100644
--- a/loopy/kernel/data.py
+++ b/loopy/kernel/data.py
@@ -315,6 +315,11 @@ class TemporaryVariable(ArrayBase):
         Whether this is temporary lives in ``local`` memory.
         May be *True*, *False*, or :class:`loopy.auto` if this is
         to be automatically determined.
+
+    .. attribute:: base_storage
+
+        The name of a storage array that is to be used to actually
+        hold the data in this temporary.
     """
 
     min_target_axes = 0
@@ -323,12 +328,14 @@ class TemporaryVariable(ArrayBase):
     allowed_extra_kwargs = [
             "storage_shape",
             "base_indices",
-            "is_local"
+            "is_local",
+            "base_storage"
             ]
 
     def __init__(self, name, dtype=None, shape=(), is_local=auto,
             dim_tags=None, offset=0, strides=None, order=None,
-            base_indices=None, storage_shape=None):
+            base_indices=None, storage_shape=None,
+            base_storage=None):
         """
         :arg dtype: :class:`loopy.auto` or a :class:`numpy.dtype`
         :arg shape: :class:`loopy.auto` or a shape tuple
@@ -346,31 +353,25 @@ class TemporaryVariable(ArrayBase):
                 dtype=dtype, shape=shape,
                 dim_tags=dim_tags, order="C",
                 base_indices=base_indices, is_local=is_local,
-                storage_shape=storage_shape)
+                storage_shape=storage_shape,
+                base_storage=base_storage)
 
     @property
     def nbytes(self):
-        from pytools import product
-        return product(si for si in self.shape)*self.dtype.itemsize
-
-    def get_arg_decl(self, target, name_suffix, shape, dtype, is_written):
-        from cgen import ArrayOf
-        from loopy.codegen import POD  # uses the correct complex type
-        from cgen.opencl import CLLocal
+        shape = self.shape
+        if self.storage_shape is not None:
+            shape = self.storage_shape
 
-        temp_var_decl = POD(target, dtype, self.name)
-
-        # FIXME take into account storage_shape, or something like it
-        storage_shape = shape
-
-        if storage_shape:
-            temp_var_decl = ArrayOf(temp_var_decl,
-                    " * ".join(str(s) for s in storage_shape))
+        from pytools import product
+        return product(si for si in shape)*self.dtype.itemsize
 
-        if self.is_local:
-            temp_var_decl = CLLocal(temp_var_decl)
+    def decl_info(self, target, index_dtype):
+        return super(TemporaryVariable, self).decl_info(
+                target, is_written=True, index_dtype=index_dtype,
+                shape_override=self.storage_shape)
 
-        return temp_var_decl
+    def get_arg_decl(self, target, name_suffix, shape, dtype, is_written):
+        return None
 
     def __str__(self):
         return self.stringify(include_typename=False)
diff --git a/loopy/schedule.py b/loopy/schedule.py
index b44569b97..113819b9a 100644
--- a/loopy/schedule.py
+++ b/loopy/schedule.py
@@ -497,7 +497,7 @@ def generate_loop_schedules_internal(
 
                     else:
                         new_active_group_counts[grp] = (
-                                sched_state.group_insn_counts[grp])
+                                sched_state.group_insn_counts[grp] - 1)
 
             else:
                 new_active_group_counts = sched_state.active_group_counts
@@ -603,6 +603,9 @@ def generate_loop_schedules_internal(
         print("active inames :", ",".join(sched_state.active_inames))
         print("inames entered so far :", ",".join(sched_state.entered_inames))
         print("reachable insns:", ",".join(reachable_insn_ids))
+        print("active groups (with insn counts):", ",".join(
+            "%s: %d" % (grp, c)
+            for grp, c in six.iteritems(sched_state.active_group_counts)))
         print(75*"-")
 
     if needed_inames:
@@ -795,6 +798,11 @@ class DependencyRecord(Record):
 
         A :class:`loopy.InstructionBase` instance.
 
+    .. attribute:: dep_descr
+
+        A string containing a phrase describing the dependency. The variables
+        '{src}' and '{tgt}' will be replaced by their respective instruction IDs.
+
     .. attribute:: variable
 
         A string, the name of the variable that caused the dependency to arise.
@@ -802,23 +810,15 @@ class DependencyRecord(Record):
     .. attribute:: var_kind
 
         "global" or "local"
-
-    .. attribute:: is_forward
-
-        A :class:`bool` indicating whether this is a forward or reverse
-        dependency.
-
-        In a 'forward' dependency, the target depends on the source.
-        In a 'reverse' dependency, the source depends on the target.
     """
 
-    def __init__(self, source, target, variable, var_kind, is_forward):
+    def __init__(self, source, target, dep_descr, variable, var_kind):
         Record.__init__(self,
                 source=source,
                 target=target,
+                dep_descr=dep_descr,
                 variable=variable,
-                var_kind=var_kind,
-                is_forward=is_forward)
+                var_kind=var_kind)
 
 
 def get_barrier_needing_dependency(kernel, target, source, reverse, var_kind):
@@ -827,7 +827,7 @@ def get_barrier_needing_dependency(kernel, target, source, reverse, var_kind):
     at least one write), then the function will return a tuple
     ``(target, source, var_name)``. Otherwise, it will return *None*.
 
-    This function finds  direct or indirect instruction dependencies, but does
+    This function finds direct or indirect instruction dependencies, but does
     not attempt to guess dependencies that exist based on common access to
     variables.
 
@@ -847,11 +847,30 @@ def get_barrier_needing_dependency(kernel, target, source, reverse, var_kind):
     if reverse:
         source, target = target, source
 
-    # Check that a dependency exists.
+    # {{{ check that a dependency exists
+
+    dep_descr = None
+
     target_deps = kernel.recursive_insn_dep_map()[target.id]
-    if source.id not in target_deps:
+    if source.id in target_deps:
+        if reverse:
+            dep_descr = "{src} rev-depends on {tgt}"
+        else:
+            dep_descr = "{tgt} depends on {src}"
+
+    grps = source.groups & target.conflicts_with_groups
+    if grps:
+        dep_descr = "{src} conflicts with {tgt} (via '%s')" % ", ".join(grps)
+
+    grps = target.groups & source.conflicts_with_groups
+    if grps:
+        dep_descr = "{src} conflicts with {tgt} (via '%s')" % ", ".join(grps)
+
+    if not dep_descr:
         return None
 
+    # }}}
+
     if var_kind == "local":
         relevant_vars = kernel.local_var_names()
     elif var_kind == "global":
@@ -859,11 +878,27 @@ def get_barrier_needing_dependency(kernel, target, source, reverse, var_kind):
     else:
         raise ValueError("unknown 'var_kind': %s" % var_kind)
 
-    tgt_write = set(target.assignee_var_names()) & relevant_vars
-    tgt_read = target.read_dependency_names() & relevant_vars
+    temp_to_base_storage = kernel.get_temporary_to_base_storage_map()
+
+    def map_to_base_storage(var_names):
+        result = set(var_names)
+
+        for name in var_names:
+            bs = temp_to_base_storage.get(name)
+            if bs is not None:
+                result.add(bs)
 
-    src_write = set(source.assignee_var_names()) & relevant_vars
-    src_read = source.read_dependency_names() & relevant_vars
+        return result
+
+    tgt_write = map_to_base_storage(
+            set(target.assignee_var_names()) & relevant_vars)
+    tgt_read = map_to_base_storage(
+            target.read_dependency_names() & relevant_vars)
+
+    src_write = map_to_base_storage(
+            set(source.assignee_var_names()) & relevant_vars)
+    src_read = map_to_base_storage(
+            source.read_dependency_names() & relevant_vars)
 
     waw = tgt_write & src_write
     raw = tgt_read & src_write
@@ -873,9 +908,9 @@ def get_barrier_needing_dependency(kernel, target, source, reverse, var_kind):
         return DependencyRecord(
                 source=source,
                 target=target,
+                dep_descr=dep_descr,
                 variable=var_name,
-                var_kind=var_kind,
-                is_forward=not reverse)
+                var_kind=var_kind)
 
     if source is target:
         return None
@@ -884,9 +919,9 @@ def get_barrier_needing_dependency(kernel, target, source, reverse, var_kind):
         return DependencyRecord(
                 source=source,
                 target=target,
+                dep_descr=dep_descr,
                 variable=var_name,
-                var_kind=var_kind,
-                is_forward=not reverse)
+                var_kind=var_kind)
 
     return None
 
@@ -998,12 +1033,9 @@ def insert_barriers(kernel, schedule, reverse, kind, level=0):
 
         comment = None
         if dep is not None:
-            if dep.is_forward:
-                comment = "for %s (%s depends on %s)" % (
-                        dep.variable, dep.target.id, dep.source.id)
-            else:
-                comment = "for %s (%s rev-depends on %s)" % (
-                        dep.variable, dep.source.id, dep.target.id)
+            comment = "for %s (%s)" % (
+                    dep.variable, dep.dep_descr.format(
+                        tgt=dep.target.id, src=dep.source.id))
 
         result.append(Barrier(comment=comment, kind=dep.var_kind))
 
@@ -1047,10 +1079,6 @@ def insert_barriers(kernel, schedule, reverse, kind, level=0):
             # (for leading (before-first-barrier) bit of loop body)
             for insn_id in insn_ids_from_schedule(subresult[:first_barrier_index]):
                 search_set = candidates
-                if not reverse:
-                    # can limit search set in case of forward dep
-                    search_set = search_set \
-                            & kernel.recursive_insn_dep_map()[insn_id]
 
                 for dep_src_insn_id in search_set:
                     dep = get_barrier_needing_dependency(
@@ -1090,10 +1118,6 @@ def insert_barriers(kernel, schedule, reverse, kind, level=0):
             i += 1
 
             search_set = candidates
-            if not reverse:
-                # can limit search set in case of forward dep
-                search_set = search_set \
-                        & kernel.recursive_insn_dep_map()[sched_item.insn_id]
 
             for dep_src_insn_id in search_set:
                 dep = get_barrier_needing_dependency(
@@ -1190,10 +1214,15 @@ def generate_loop_schedules(kernel, debug_args={}):
             #         raise LoopyError("kernel requires a global barrier %s"
             #                 % sched_item.comment)
 
+            debug.stop()
+
+            logger.info("%s: barrier insertion: start" % kernel.name)
+
             gen_sched = insert_barriers(kernel, gen_sched,
                     reverse=False, kind="local")
 
-            debug.stop()
+            logger.info("%s: barrier insertion: done" % kernel.name)
+
             yield kernel.copy(
                     schedule=gen_sched,
                     state=kernel_state.SCHEDULED)
diff --git a/loopy/target/c/__init__.py b/loopy/target/c/__init__.py
index 28943644e..8620996b0 100644
--- a/loopy/target/c/__init__.py
+++ b/loopy/target/c/__init__.py
@@ -85,14 +85,86 @@ class CTarget(TargetBase):
         from cgen import Block
         body = Block()
 
+        temp_decls = []
+
         # {{{ declare temporaries
 
-        body.extend(
-                idi.cgen_declarator
-                for tv in six.itervalues(kernel.temporary_variables)
-                for idi in tv.decl_info(
-                    kernel.target,
-                    is_written=True, index_dtype=kernel.index_dtype))
+        base_storage_sizes = {}
+        base_storage_to_is_local = {}
+
+        from cgen import ArrayOf, Pointer, Initializer
+        from loopy.codegen import POD  # uses the correct complex type
+        from cgen.opencl import CLLocal
+
+        class ConstRestrictPointer(Pointer):
+            def get_decl_pair(self):
+                sub_tp, sub_decl = self.subdecl.get_decl_pair()
+                return sub_tp, ("*const restrict %s" % sub_decl)
+
+        for tv in six.itervalues(kernel.temporary_variables):
+            decl_info = tv.decl_info(self, index_dtype=kernel.index_dtype)
+
+            if not tv.base_storage:
+                for idi in decl_info:
+                    temp_var_decl = POD(self, idi.dtype, idi.name)
+
+                    if idi.shape:
+                        temp_var_decl = ArrayOf(temp_var_decl,
+                                " * ".join(str(s) for s in idi.shape))
+
+                    if tv.is_local:
+                        temp_var_decl = CLLocal(temp_var_decl)
+
+                    temp_decls.append(temp_var_decl)
+
+            else:
+                offset = 0
+                base_storage_sizes.setdefault(tv.base_storage, []).append(
+                        tv.nbytes)
+                base_storage_to_is_local.setdefault(tv.base_storage, []).append(
+                        tv.is_local)
+
+                for idi in decl_info:
+                    cast_decl = POD(self, idi.dtype, "")
+                    temp_var_decl = POD(self, idi.dtype, idi.name)
+
+                    if tv.is_local:
+                        cast_decl = CLLocal(cast_decl)
+                        temp_var_decl = CLLocal(temp_var_decl)
+
+                    # The 'restrict' part of this is a complete lie--of course
+                    # all these temporaries are aliased. But we're promising to
+                    # not use them to shovel data from one representation to the
+                    # other. That counts, right?
+
+                    cast_decl = ConstRestrictPointer(cast_decl)
+                    temp_var_decl = ConstRestrictPointer(temp_var_decl)
+
+                    cast_tp, cast_d = cast_decl.get_decl_pair()
+                    temp_var_decl = Initializer(
+                            temp_var_decl,
+                            "(%s %s) (%s + %s)" % (
+                                " ".join(cast_tp), cast_d,
+                                tv.base_storage,
+                                offset))
+
+                    temp_decls.append(temp_var_decl)
+
+                    from pytools import product
+                    offset += (
+                            idi.dtype.itemsize
+                            * product(si for si in idi.shape))
+
+        for bs_name, bs_sizes in six.iteritems(base_storage_sizes):
+            bs_var_decl = POD(self, np.int8, bs_name)
+            # The map holds a (never empty, hence always truthy) *list* of
+            # per-temporary flags, so it must be reduced before testing.
+            if any(base_storage_to_is_local[bs_name]):
+                bs_var_decl = CLLocal(bs_var_decl)
+
+            body.append(ArrayOf(bs_var_decl, max(bs_sizes)))
+
+        body.extend(temp_decls)
 
         # }}}
 
diff --git a/loopy/version.py b/loopy/version.py
index 9f1378f16..9598697b0 100644
--- a/loopy/version.py
+++ b/loopy/version.py
@@ -32,4 +32,4 @@ except ImportError:
 else:
     _islpy_version = islpy.version.VERSION_TEXT
 
-DATA_MODEL_VERSION = "v10-islpy%s" % _islpy_version
+DATA_MODEL_VERSION = "v11-islpy%s" % _islpy_version
diff --git a/test/test_loopy.py b/test/test_loopy.py
index 17e0cc543..874bb55fb 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -2094,6 +2094,39 @@ def test_vectorize(ctx_factory):
             ref_knl, ctx, knl,
             parameters=dict(n=30))
 
+
+def test_alias_temporaries(ctx_factory):
+    """Check that three aliased precompute temporaries give correct results."""
+    ctx = ctx_factory()
+
+    ref_knl = lp.make_kernel(
+        "{[i]: 0<=i<n}",
+        """
+        times2(i) := 2*a[i]
+        times3(i) := 3*a[i]
+        times4(i) := 4*a[i]
+
+        x[i] = times2(i)
+        y[i] = times3(i)
+        z[i] = times4(i)
+        """)
+
+    ref_knl = lp.add_and_infer_dtypes(ref_knl, {"a": np.float32})
+
+    knl = lp.split_iname(ref_knl, "i", 16, outer_tag="g.0", inner_tag="l.0")
+
+    subst_names = ["times2", "times3", "times4"]
+    for subst_name in subst_names:
+        knl = lp.precompute(knl, subst_name, "i_inner")
+
+    # this test relies on precompute naming its temporaries '<rule>_0'
+    knl = lp.alias_temporaries(knl, [name + "_0" for name in subst_names])
+
+    lp.auto_test_vs_ref(
+            ref_knl, ctx, knl,
+            parameters=dict(n=30))
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])
-- 
GitLab