From f5ed460dcb98b5cb18848725175f4191df7c12d7 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Mon, 13 Apr 2015 22:45:53 -0500
Subject: [PATCH] Factor array_buffer out of precompute

---
 loopy/array_buffer.py | 408 ++++++++++++++++++++++++++++++++++++
 loopy/precompute.py   | 476 ++++++++----------------------------------
 2 files changed, 496 insertions(+), 388 deletions(-)
 create mode 100644 loopy/array_buffer.py

diff --git a/loopy/array_buffer.py b/loopy/array_buffer.py
new file mode 100644
index 000000000..c15935b73
--- /dev/null
+++ b/loopy/array_buffer.py
@@ -0,0 +1,408 @@
+from __future__ import division, absolute_import
+from six.moves import range, zip
+
+__copyright__ = "Copyright (C) 2012-2015 Andreas Kloeckner"
+
+__license__ = """
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+"""
+
+import islpy as isl
+from islpy import dim_type
+from loopy.symbolic import (get_dependencies, SubstitutionMapper)
+from pymbolic.mapper.substitutor import make_subst_func
+
+from pytools import Record
+from pymbolic import var
+
+
+class AccessDescriptor(Record):
+    """
+    .. attribute:: identifier
+
+        An identifier under user control, used to connect this access descriptor
+        to the access that generated it. Any Python value.
+    """
+
+    __slots__ = [
+            "identifier",
+            "expands_footprint",
+            "storage_axis_exprs",
+            ]
+
+
+def to_parameters_or_project_out(param_inames, set_inames, set):
+    for iname in list(set.get_space().get_var_dict().keys()):
+        if iname in param_inames:
+            dt, idx = set.get_space().get_var_dict()[iname]
+            set = set.move_dims(
+                    dim_type.param, set.dim(dim_type.param),
+                    dt, idx, 1)
+        elif iname in set_inames:
+            pass
+        else:
+            dt, idx = set.get_space().get_var_dict()[iname]
+            set = set.project_out(dt, idx, 1)
+
+    return set
+
+
+# {{{ construct storage->sweep map
+
+def build_per_access_storage_to_domain_map(accdesc, domain,
+        storage_axis_names,
+        prime_sweep_inames):
+
+    map_space = domain.space
+    stor_dim = len(storage_axis_names)
+    rn = map_space.dim(dim_type.out)
+
+    map_space = map_space.add_dims(dim_type.in_, stor_dim)
+    for i, saxis in enumerate(storage_axis_names):
+        # arg names are initially primed, to be replaced with unprimed
+        # base-0 versions below
+
+        map_space = map_space.set_dim_name(dim_type.in_, i, saxis+"'")
+
+    # map_space: [stor_axes'] -> [domain](dup_sweep_index)[dup_sweep](rn)
+
+    set_space = map_space.move_dims(
+            dim_type.out, rn,
+            dim_type.in_, 0, stor_dim).range()
+
+    # set_space: [domain](dup_sweep_index)[dup_sweep](rn)[stor_axes']
+
+    stor2sweep = None
+
+    from loopy.symbolic import aff_from_expr
+
+    for saxis, sa_expr in zip(storage_axis_names, accdesc.storage_axis_exprs):
+        cns = isl.Constraint.equality_from_aff(
+                aff_from_expr(set_space,
+                    var(saxis+"'") - prime_sweep_inames(sa_expr)))
+
+        cns_map = isl.BasicMap.from_constraint(cns)
+        if stor2sweep is None:
+            stor2sweep = cns_map
+        else:
+            stor2sweep = stor2sweep.intersect(cns_map)
+
+    if stor2sweep is not None:
+        stor2sweep = stor2sweep.move_dims(
+                dim_type.in_, 0,
+                dim_type.out, rn, stor_dim)
+
+    # stor2sweep is back in map_space
+    return stor2sweep
+
+
+def move_to_par_from_out(s2smap, except_inames):
+    while True:
+        var_dict = s2smap.get_var_dict(dim_type.out)
+        todo_inames = set(var_dict) - except_inames
+        if todo_inames:
+            iname = todo_inames.pop()
+
+            _, dim_idx = var_dict[iname]
+            s2smap = s2smap.move_dims(
+                    dim_type.param, s2smap.dim(dim_type.param),
+                    dim_type.out, dim_idx, 1)
+        else:
+            return s2smap
+
+
+def build_global_storage_to_sweep_map(kernel, access_descriptors,
+        domain_dup_sweep, dup_sweep_index,
+        storage_axis_names,
+        sweep_inames, primed_sweep_inames, prime_sweep_inames):
+    # The storage map goes from storage axes to the domain.
+    # The first len(arg_names) storage dimensions are the rule's arguments.
+
+    global_stor2sweep = None
+
+    # build footprint
+    for accdesc in access_descriptors:
+        if accdesc.expands_footprint:
+            stor2sweep = build_per_access_storage_to_domain_map(
+                    accdesc, domain_dup_sweep,
+                    storage_axis_names,
+                    prime_sweep_inames)
+
+            if global_stor2sweep is None:
+                global_stor2sweep = stor2sweep
+            else:
+                global_stor2sweep = global_stor2sweep.union(stor2sweep)
+
+    if isinstance(global_stor2sweep, isl.BasicMap):
+        global_stor2sweep = isl.Map.from_basic_map(global_stor2sweep)
+    global_stor2sweep = global_stor2sweep.intersect_range(domain_dup_sweep)
+
+    # space for global_stor2sweep:
+    # [stor_axes'] -> [domain](dup_sweep_index)[dup_sweep](rn)
+
+    return global_stor2sweep
+
+# }}}
+
+
+# {{{ compute storage bounds
+
+def find_var_base_indices_and_shape_from_inames(
+        domain, inames, cache_manager, context=None):
+    base_indices_and_sizes = [
+            cache_manager.base_index_and_length(domain, iname, context)
+            for iname in inames]
+    return list(zip(*base_indices_and_sizes))
+
+
+def compute_bounds(kernel, domain, stor2sweep,
+        primed_sweep_inames, storage_axis_names):
+
+    bounds_footprint_map = move_to_par_from_out(
+            stor2sweep, except_inames=frozenset(primed_sweep_inames))
+
+    # compute bounds for each storage axis
+    storage_domain = bounds_footprint_map.domain().coalesce()
+
+    if not storage_domain.is_bounded():
+        raise RuntimeError("sweep did not result in a bounded storage domain")
+
+    return find_var_base_indices_and_shape_from_inames(
+            storage_domain, [saxis+"'" for saxis in storage_axis_names],
+            kernel.cache_manager, context=kernel.assumptions)
+
+# }}}
+
+
+# {{{ array-to-buffer map
+
+class ArrayToBufferMap(object):
+    def __init__(self, kernel, domain, sweep_inames, access_descriptors,
+            storage_axis_count):
+        self.kernel = kernel
+        self.sweep_inames = sweep_inames
+
+        storage_axis_names = self.storage_axis_names = [
+                "_loopy_storage_%d" % i for i in range(storage_axis_count)]
+
+        # {{{ duplicate sweep inames
+
+        # The duplication is necessary, otherwise the storage fetch
+        # inames remain weirdly tied to the original sweep inames.
+
+        self.primed_sweep_inames = [psin+"'" for psin in sweep_inames]
+
+        from loopy.isl_helpers import duplicate_axes
+        dup_sweep_index = domain.space.dim(dim_type.out)
+        domain_dup_sweep = duplicate_axes(
+                domain, sweep_inames,
+                self.primed_sweep_inames)
+
+        self.prime_sweep_inames = SubstitutionMapper(make_subst_func(
+            dict((sin, var(psin))
+                for sin, psin in zip(sweep_inames, self.primed_sweep_inames))))
+
+        # # }}}
+
+        self.stor2sweep = build_global_storage_to_sweep_map(
+                kernel, access_descriptors,
+                domain_dup_sweep, dup_sweep_index,
+                storage_axis_names,
+                sweep_inames, self.primed_sweep_inames, self.prime_sweep_inames)
+
+        storage_base_indices, storage_shape = compute_bounds(
+                kernel, domain, self.stor2sweep, self.primed_sweep_inames,
+                storage_axis_names)
+
+        # compute augmented domain
+
+        # {{{ filter out unit-length dimensions
+
+        non1_storage_axis_flags = []
+        non1_storage_shape = []
+
+        for saxis, bi, l in zip(
+                storage_axis_names, storage_base_indices, storage_shape):
+            has_length_non1 = l != 1
+
+            non1_storage_axis_flags.append(has_length_non1)
+
+            if has_length_non1:
+                non1_storage_shape.append(l)
+
+        # }}}
+
+        # {{{ subtract off the base indices
+        # add the new, base-0 indices as new in dimensions
+
+        sp = self.stor2sweep.get_space()
+        stor_idx = sp.dim(dim_type.out)
+
+        n_stor = storage_axis_count
+        nn1_stor = len(non1_storage_shape)
+
+        aug_domain = self.stor2sweep.move_dims(
+                dim_type.out, stor_idx,
+                dim_type.in_, 0,
+                n_stor).range()
+
+        # aug_domain space now:
+        # [domain](dup_sweep_index)[dup_sweep](stor_idx)[stor_axes']
+
+        aug_domain = aug_domain.insert_dims(dim_type.set, stor_idx, nn1_stor)
+
+        inew = 0
+        for i, name in enumerate(storage_axis_names):
+            if non1_storage_axis_flags[i]:
+                aug_domain = aug_domain.set_dim_name(
+                        dim_type.set, stor_idx + inew, name)
+                inew += 1
+
+        # aug_domain space now:
+        # [domain](dup_sweep_index)[dup_sweep](stor_idx)[stor_axes'][n1_stor_axes]
+
+        from loopy.symbolic import aff_from_expr
+        for saxis, bi, s in zip(storage_axis_names, storage_base_indices,
+                storage_shape):
+            if s != 1:
+                cns = isl.Constraint.equality_from_aff(
+                        aff_from_expr(aug_domain.get_space(),
+                            var(saxis) - (var(saxis+"'") - bi)))
+
+                aug_domain = aug_domain.add_constraint(cns)
+
+        # }}}
+
+        # eliminate (primed) storage axes with non-zero base indices
+        aug_domain = aug_domain.project_out(dim_type.set, stor_idx+nn1_stor, n_stor)
+
+        # eliminate duplicated sweep_inames
+        nsweep = len(sweep_inames)
+        aug_domain = aug_domain.project_out(dim_type.set, dup_sweep_index, nsweep)
+
+        self.non1_storage_axis_flags = non1_storage_axis_flags
+        self.aug_domain = aug_domain
+        self.storage_base_indices = storage_base_indices
+        self.non1_storage_shape = non1_storage_shape
+
+    def augment_domain_with_sweep(self, domain, new_non1_storage_axis_names,
+            boxify_sweep=False):
+
+        renamed_aug_domain = self.aug_domain
+        first_storage_index = (
+                renamed_aug_domain.dim(dim_type.set)
+                - len(self.non1_storage_shape))
+
+        inon1 = 0
+        for i, old_name in enumerate(self.storage_axis_names):
+            if not self.non1_storage_axis_flags[i]:
+                continue
+
+            new_name = new_non1_storage_axis_names[inon1]
+
+            assert (
+                    renamed_aug_domain.get_dim_name(
+                        dim_type.set, first_storage_index+inon1)
+                    == old_name)
+            renamed_aug_domain = renamed_aug_domain.set_dim_name(
+                    dim_type.set, first_storage_index+inon1, new_name)
+
+            inon1 += 1
+
+        domain, renamed_aug_domain = isl.align_two(domain, renamed_aug_domain)
+        domain = domain & renamed_aug_domain
+
+        from loopy.isl_helpers import convexify, boxify
+        if boxify_sweep:
+            return boxify(self.kernel.cache_manager, domain,
+                    new_non1_storage_axis_names, self.kernel.assumptions)
+        else:
+            return convexify(domain)
+
+    def is_access_descriptor_in_footprint(self, accdesc):
+        if accdesc.expands_footprint:
+            return True
+
+        # Make all inames except the sweep parameters. (The footprint may depend on
+        # those.) (I.e. only leave sweep inames as out parameters.)
+
+        global_s2s_par_dom = move_to_par_from_out(
+                self.stor2sweep,
+                except_inames=frozenset(self.primed_sweep_inames)).domain()
+
+        arg_inames = (
+                set(global_s2s_par_dom.get_var_names(dim_type.param))
+                & self.kernel.all_inames())
+
+        for arg in accdesc.args:
+            arg_inames.update(get_dependencies(arg))
+        arg_inames = frozenset(arg_inames)
+
+        from loopy.kernel import CannotBranchDomainTree
+        try:
+            usage_domain = self.kernel.get_inames_domain(arg_inames)
+        except CannotBranchDomainTree:
+            return False
+
+        for i in range(usage_domain.dim(dim_type.set)):
+            iname = usage_domain.get_dim_name(dim_type.set, i)
+            if iname in self.sweep_inames:
+                usage_domain = usage_domain.set_dim_name(
+                        dim_type.set, i, iname+"'")
+
+        stor2sweep = build_per_access_storage_to_domain_map(accdesc,
+                usage_domain, self.storage_axis_names,
+                self.prime_sweep_inames)
+
+        if isinstance(stor2sweep, isl.BasicMap):
+            stor2sweep = isl.Map.from_basic_map(stor2sweep)
+
+        stor2sweep = stor2sweep.intersect_range(usage_domain)
+
+        stor2sweep = move_to_par_from_out(stor2sweep,
+                except_inames=frozenset(self.primed_sweep_inames))
+
+        s2s_domain = stor2sweep.domain()
+        s2s_domain, aligned_g_s2s_parm_dom = isl.align_two(
+                s2s_domain, global_s2s_par_dom)
+
+        arg_restrictions = (
+                aligned_g_s2s_parm_dom
+                .eliminate(dim_type.set, 0,
+                    aligned_g_s2s_parm_dom.dim(dim_type.set))
+                .remove_divs())
+
+        return (arg_restrictions & s2s_domain).is_subset(
+                aligned_g_s2s_parm_dom)
+
+
+class NoOpArrayToBufferMap(object):
+    non1_storage_axis_names = ()
+    storage_base_indices = ()
+    non1_storage_shape = ()
+
+    def is_access_descriptor_in_footprint(self, accdesc):
+        # no index dependencies--every reference to the subst rule
+        # is necessarily in the footprint.
+
+        return True
+
+# }}}
+
+# vim: foldmethod=marker
diff --git a/loopy/precompute.py b/loopy/precompute.py
index 8bd049b12..de5cd9e02 100644
--- a/loopy/precompute.py
+++ b/loopy/precompute.py
@@ -1,8 +1,6 @@
-from __future__ import division
-from __future__ import absolute_import
+from __future__ import division, absolute_import
 import six
-from six.moves import range
-from six.moves import zip
+from six.moves import range, zip
 
 __copyright__ = "Copyright (C) 2012 Andreas Kloeckner"
 
@@ -28,337 +26,35 @@ THE SOFTWARE.
 
 
 import islpy as isl
-from islpy import dim_type
 from loopy.symbolic import (get_dependencies, SubstitutionMapper,
         ExpandingIdentityMapper)
 from pymbolic.mapper.substitutor import make_subst_func
 import numpy as np
 
-from pytools import Record
 from pymbolic import var
 
+from loopy.array_buffer import (ArrayToBufferMap, NoOpArrayToBufferMap,
+        AccessDescriptor)
 
-class InvocationDescriptor(Record):
-    __slots__ = [
-            "args",
-            "expands_footprint",
-            "is_in_footprint",
 
-            # Remember where the invocation happened, in terms of the expansion
-            # call stack.
-            "expansion_stack",
-            ]
+class RuleAccessDescriptor(AccessDescriptor):
+    __slots__ = ["args", "expansion_stack"]
 
 
-def to_parameters_or_project_out(param_inames, set_inames, set):
-    for iname in list(set.get_space().get_var_dict().keys()):
-        if iname in param_inames:
-            dt, idx = set.get_space().get_var_dict()[iname]
-            set = set.move_dims(
-                    dim_type.param, set.dim(dim_type.param),
-                    dt, idx, 1)
-        elif iname in set_inames:
-            pass
-        else:
-            dt, idx = set.get_space().get_var_dict()[iname]
-            set = set.project_out(dt, idx, 1)
-
-    return set
-
-
-# {{{ construct storage->sweep map
-
-def build_per_access_storage_to_domain_map(invdesc, domain,
-        storage_axis_names, storage_axis_sources,
-        prime_sweep_inames):
-
-    map_space = domain.get_space()
-    stor_dim = len(storage_axis_names)
-    rn = map_space.dim(dim_type.out)
-
-    map_space = map_space.add_dims(dim_type.in_, stor_dim)
-    for i, saxis in enumerate(storage_axis_names):
-        # arg names are initially primed, to be replaced with unprimed
-        # base-0 versions below
-
-        map_space = map_space.set_dim_name(dim_type.in_, i, saxis+"'")
+def access_descriptor_id(args, expansion_stack):
+    return (args, expansion_stack)
 
-    # map_space: [stor_axes'] -> [domain](dup_sweep_index)[dup_sweep](rn)
 
-    set_space = map_space.move_dims(
-            dim_type.out, rn,
-            dim_type.in_, 0, stor_dim).range()
+def storage_axis_exprs(storage_axis_sources, args):
+    result = []
 
-    # set_space: [domain](dup_sweep_index)[dup_sweep](rn)[stor_axes']
-
-    stor2sweep = None
-
-    from loopy.symbolic import aff_from_expr
-
-    for saxis, saxis_source in zip(storage_axis_names, storage_axis_sources):
+    for saxis_source in storage_axis_sources:
         if isinstance(saxis_source, int):
-            # an argument
-            cns = isl.Constraint.equality_from_aff(
-                    aff_from_expr(set_space,
-                        var(saxis+"'")
-                        - prime_sweep_inames(invdesc.args[saxis_source])))
-        else:
-            # a 'bare' sweep iname
-            cns = isl.Constraint.equality_from_aff(
-                    aff_from_expr(set_space,
-                        var(saxis+"'")
-                        - prime_sweep_inames(var(saxis_source))))
-
-        cns_map = isl.BasicMap.from_constraint(cns)
-        if stor2sweep is None:
-            stor2sweep = cns_map
+            result.append(args[saxis_source])
         else:
-            stor2sweep = stor2sweep.intersect(cns_map)
-
-    if stor2sweep is not None:
-        stor2sweep = stor2sweep.move_dims(
-                dim_type.in_, 0,
-                dim_type.out, rn, stor_dim)
-
-    # stor2sweep is back in map_space
-    return stor2sweep
-
-
-def move_to_par_from_out(s2smap, except_inames):
-    while True:
-        var_dict = s2smap.get_var_dict(dim_type.out)
-        todo_inames = set(var_dict) - except_inames
-        if todo_inames:
-            iname = todo_inames.pop()
-
-            _, dim_idx = var_dict[iname]
-            s2smap = s2smap.move_dims(
-                    dim_type.param, s2smap.dim(dim_type.param),
-                    dim_type.out, dim_idx, 1)
-        else:
-            return s2smap
-
-
-def build_global_storage_to_sweep_map(kernel, invocation_descriptors,
-        domain_dup_sweep, dup_sweep_index,
-        storage_axis_names, storage_axis_sources,
-        sweep_inames, primed_sweep_inames, prime_sweep_inames):
-    """
-    As a side effect, this fills out is_in_footprint in the
-    invocation descriptors.
-    """
-
-    # The storage map goes from storage axes to the domain.
-    # The first len(arg_names) storage dimensions are the rule's arguments.
-
-    global_stor2sweep = None
-
-    # build footprint
-    for invdesc in invocation_descriptors:
-        if invdesc.expands_footprint:
-            stor2sweep = build_per_access_storage_to_domain_map(
-                    invdesc, domain_dup_sweep,
-                    storage_axis_names, storage_axis_sources,
-                    prime_sweep_inames)
-
-            if global_stor2sweep is None:
-                global_stor2sweep = stor2sweep
-            else:
-                global_stor2sweep = global_stor2sweep.union(stor2sweep)
-
-            invdesc.is_in_footprint = True
-
-    if isinstance(global_stor2sweep, isl.BasicMap):
-        global_stor2sweep = isl.Map.from_basic_map(global_stor2sweep)
-    global_stor2sweep = global_stor2sweep.intersect_range(domain_dup_sweep)
-
-    # space for global_stor2sweep:
-    # [stor_axes'] -> [domain](dup_sweep_index)[dup_sweep](rn)
-
-    # {{{ check if non-footprint-building invocation descriptors fall into footprint
-
-    # Make all inames except the sweep parameters. (The footprint may depend on
-    # those.) (I.e. only leave sweep inames as out parameters.)
-    global_s2s_par_dom = move_to_par_from_out(
-            global_stor2sweep, except_inames=frozenset(primed_sweep_inames)).domain()
-
-    for invdesc in invocation_descriptors:
-        if not invdesc.expands_footprint:
-            arg_inames = (
-                    set(global_s2s_par_dom.get_var_names(dim_type.param))
-                    & kernel.all_inames())
-
-            for arg in invdesc.args:
-                arg_inames.update(get_dependencies(arg))
-            arg_inames = frozenset(arg_inames)
-
-            from loopy.kernel import CannotBranchDomainTree
-            try:
-                usage_domain = kernel.get_inames_domain(arg_inames)
-            except CannotBranchDomainTree:
-                # and that's the end of that.
-                invdesc.is_in_footprint = False
-                continue
-
-            for i in range(usage_domain.dim(dim_type.set)):
-                iname = usage_domain.get_dim_name(dim_type.set, i)
-                if iname in sweep_inames:
-                    usage_domain = usage_domain.set_dim_name(
-                            dim_type.set, i, iname+"'")
-
-            stor2sweep = build_per_access_storage_to_domain_map(invdesc,
-                    usage_domain, storage_axis_names, storage_axis_sources,
-                    prime_sweep_inames)
-
-            if isinstance(stor2sweep, isl.BasicMap):
-                stor2sweep = isl.Map.from_basic_map(stor2sweep)
-
-            stor2sweep = stor2sweep.intersect_range(usage_domain)
-
-            stor2sweep = move_to_par_from_out(stor2sweep,
-                    except_inames=frozenset(primed_sweep_inames))
-
-            s2s_domain = stor2sweep.domain()
-            s2s_domain, aligned_g_s2s_parm_dom = isl.align_two(
-                    s2s_domain, global_s2s_par_dom)
-
-            arg_restrictions = (
-                    aligned_g_s2s_parm_dom
-                    .eliminate(dim_type.set, 0,
-                        aligned_g_s2s_parm_dom.dim(dim_type.set))
-                    .remove_divs())
-
-            is_in_footprint = (arg_restrictions & s2s_domain).is_subset(
-                    aligned_g_s2s_parm_dom)
-
-            invdesc.is_in_footprint = is_in_footprint
-
-    # }}}
-
-    return global_stor2sweep
-
-# }}}
-
-
-# {{{ compute storage bounds
-
-def find_var_base_indices_and_shape_from_inames(
-        domain, inames, cache_manager, context=None):
-    base_indices_and_sizes = [
-            cache_manager.base_index_and_length(domain, iname, context)
-            for iname in inames]
-    return list(zip(*base_indices_and_sizes))
-
-
-def compute_bounds(kernel, domain, stor2sweep,
-        primed_sweep_inames, storage_axis_names):
-
-    bounds_footprint_map = move_to_par_from_out(
-            stor2sweep, except_inames=frozenset(primed_sweep_inames))
-
-    # compute bounds for each storage axis
-    storage_domain = bounds_footprint_map.domain().coalesce()
-
-    if not storage_domain.is_bounded():
-        raise RuntimeError("sweep did not result in a bounded storage domain")
-
-    return find_var_base_indices_and_shape_from_inames(
-            storage_domain, [saxis+"'" for saxis in storage_axis_names],
-            kernel.cache_manager, context=kernel.assumptions)
+            result.append(var(saxis_source))
 
-# }}}
-
-
-def get_access_info(kernel, domain,
-        storage_axis_names, storage_axis_sources,
-        sweep_inames, invocation_descriptors):
-
-    # {{{ duplicate sweep inames
-
-    # The duplication is necessary, otherwise the storage fetch
-    # inames remain weirdly tied to the original sweep inames.
-
-    primed_sweep_inames = [psin+"'" for psin in sweep_inames]
-    from loopy.isl_helpers import duplicate_axes
-    dup_sweep_index = domain.space.dim(dim_type.out)
-    domain_dup_sweep = duplicate_axes(
-            domain, sweep_inames,
-            primed_sweep_inames)
-
-    prime_sweep_inames = SubstitutionMapper(make_subst_func(
-        dict((sin, var(psin))
-            for sin, psin in zip(sweep_inames, primed_sweep_inames))))
-
-    # }}}
-
-    stor2sweep = build_global_storage_to_sweep_map(
-            kernel, invocation_descriptors,
-            domain_dup_sweep, dup_sweep_index,
-            storage_axis_names, storage_axis_sources,
-            sweep_inames, primed_sweep_inames, prime_sweep_inames)
-
-    storage_base_indices, storage_shape = compute_bounds(
-            kernel, domain, stor2sweep, primed_sweep_inames,
-            storage_axis_names)
-
-    # compute augmented domain
-
-    # {{{ filter out unit-length dimensions
-
-    non1_storage_axis_names = []
-    non1_storage_shape = []
-
-    for saxis, bi, l in zip(storage_axis_names, storage_base_indices, storage_shape):
-        if l != 1:
-            non1_storage_axis_names.append(saxis)
-            non1_storage_shape.append(l)
-
-    # }}}
-
-    # {{{ subtract off the base indices
-    # add the new, base-0 indices as new in dimensions
-
-    sp = stor2sweep.get_space()
-    stor_idx = sp.dim(dim_type.out)
-
-    n_stor = len(storage_axis_names)
-    nn1_stor = len(non1_storage_axis_names)
-
-    aug_domain = stor2sweep.move_dims(
-            dim_type.out, stor_idx,
-            dim_type.in_, 0,
-            n_stor).range()
-
-    # aug_domain space now:
-    # [domain](dup_sweep_index)[dup_sweep](stor_idx)[stor_axes']
-
-    aug_domain = aug_domain.insert_dims(dim_type.set, stor_idx, nn1_stor)
-    for i, name in enumerate(non1_storage_axis_names):
-        aug_domain = aug_domain.set_dim_name(dim_type.set, stor_idx+i, name)
-
-    # aug_domain space now:
-    # [domain](dup_sweep_index)[dup_sweep](stor_idx)[stor_axes'][n1_stor_axes]
-
-    from loopy.symbolic import aff_from_expr
-    for saxis, bi, s in zip(storage_axis_names, storage_base_indices, storage_shape):
-        if s != 1:
-            cns = isl.Constraint.equality_from_aff(
-                    aff_from_expr(aug_domain.get_space(),
-                        var(saxis) - (var(saxis+"'") - bi)))
-
-            aug_domain = aug_domain.add_constraint(cns)
-
-    # }}}
-
-    # eliminate (primed) storage axes with non-zero base indices
-    aug_domain = aug_domain.project_out(dim_type.set, stor_idx+nn1_stor, n_stor)
-
-    # eliminate duplicated sweep_inames
-    nsweep = len(sweep_inames)
-    aug_domain = aug_domain.project_out(dim_type.set, dup_sweep_index, nsweep)
-
-    return (non1_storage_axis_names, aug_domain,
-            storage_base_indices, non1_storage_shape)
+    return result
 
 
 def simplify_via_aff(expr):
@@ -369,7 +65,7 @@ def simplify_via_aff(expr):
         expr))
 
 
-class InvocationGatherer(ExpandingIdentityMapper):
+class RuleInvocationGatherer(ExpandingIdentityMapper):
     def __init__(self, kernel, subst_name, subst_tag, within):
         ExpandingIdentityMapper.__init__(self,
                 kernel.substitutions, kernel.get_var_name_generator())
@@ -383,7 +79,7 @@ class InvocationGatherer(ExpandingIdentityMapper):
         self.subst_tag = subst_tag
         self.within = within
 
-        self.invocation_descriptors = []
+        self.access_descriptors = []
 
     def map_substitution(self, name, tag, arguments, expn_state):
         process_me = name == self.subst_name
@@ -423,17 +119,21 @@ class InvocationGatherer(ExpandingIdentityMapper):
             return ExpandingIdentityMapper.map_substitution(
                     self, name, tag, arguments, expn_state)
 
-        self.invocation_descriptors.append(
-                InvocationDescriptor(
-                    args=[arg_context[arg_name] for arg_name in rule.arguments],
-                    expansion_stack=expn_state.stack))
+        args = [arg_context[arg_name] for arg_name in rule.arguments]
+
+        # Do not set expands_footprint here, it is set below.
+        self.access_descriptors.append(
+                RuleAccessDescriptor(
+                    identifier=access_descriptor_id(args, expn_state.stack),
+                    args=args,
+                    ))
 
         return 0  # exact value irrelevant
 
 
-class InvocationReplacer(ExpandingIdentityMapper):
+class RuleInvocationReplacer(ExpandingIdentityMapper):
     def __init__(self, kernel, subst_name, subst_tag, within,
-            invocation_descriptors,
+            access_descriptors, array_base_map,
             storage_axis_names, storage_axis_sources,
             storage_base_indices, non1_storage_axis_names,
             target_var_name):
@@ -449,7 +149,8 @@ class InvocationReplacer(ExpandingIdentityMapper):
         self.subst_tag = subst_tag
         self.within = within
 
-        self.invocation_descriptors = invocation_descriptors
+        self.access_descriptors = access_descriptors
+        self.array_base_map = array_base_map
 
         self.storage_axis_names = storage_axis_names
         self.storage_axis_sources = storage_axis_sources
@@ -477,21 +178,21 @@ class InvocationReplacer(ExpandingIdentityMapper):
             return ExpandingIdentityMapper.map_substitution(
                     self, name, tag, arguments, expn_state)
 
-        matching_invdesc = None
-        for invdesc in self.invocation_descriptors:
-            if invdesc.args == args and expn_state.stack:
+        matching_accdesc = None
+        for accdesc in self.access_descriptors:
+            if accdesc.identifier == access_descriptor_id(args, expn_state.stack):
                 # Could be more than one, that's fine.
-                matching_invdesc = invdesc
+                matching_accdesc = accdesc
                 break
 
-        assert matching_invdesc is not None
+        assert matching_accdesc is not None
 
-        invdesc = matching_invdesc
-        del matching_invdesc
+        accdesc = matching_accdesc
+        del matching_accdesc
 
         # }}}
 
-        if not invdesc.is_in_footprint:
+        if not self.array_base_map.is_access_descriptor_in_footprint(accdesc):
             return ExpandingIdentityMapper.map_substitution(
                     self, name, tag, arguments, expn_state)
 
@@ -646,11 +347,17 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
     from loopy.context_matching import parse_stack_match
     within = parse_stack_match(within)
 
+    from loopy.kernel.data import parse_tag
+    default_tag = parse_tag(default_tag)
+
+    subst = kernel.substitutions[subst_name]
+    c_subst_name = subst_name.replace(".", "_")
+
     # }}}
 
-    # {{{ process invocations in footprint generators, start invocation_descriptors
+    # {{{ process invocations in footprint generators, start access_descriptors
 
-    invocation_descriptors = []
+    access_descriptors = []
 
     if footprint_generators:
         from pymbolic.primitives import Variable, Call
@@ -664,35 +371,29 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
                 raise ValueError("footprint generator must "
                         "be substitution rule invocation")
 
-            invocation_descriptors.append(
-                    InvocationDescriptor(args=args,
+            access_descriptors.append(
+                    RuleAccessDescriptor(
+                        identifier=access_descriptor_id(args, None),
                         expands_footprint=True,
-                        expansion_stack=None))
+                        args=args
+                        ))
 
     # }}}
 
-    c_subst_name = subst_name.replace(".", "_")
+    # {{{ gather up invocations in kernel code, finish access_descriptors
 
-    from loopy.kernel.data import parse_tag
-    default_tag = parse_tag(default_tag)
-
-    subst = kernel.substitutions[subst_name]
-    arg_names = subst.arguments
-
-    # {{{ gather up invocations in kernel code, finish invocation_descriptors
-
-    invg = InvocationGatherer(kernel, subst_name, subst_tag, within)
+    invg = RuleInvocationGatherer(kernel, subst_name, subst_tag, within)
 
     import loopy as lp
     for insn in kernel.instructions:
         if isinstance(insn, lp.ExpressionInstruction):
             invg(insn.expression, insn.id, insn.tags)
 
-    for invdesc in invg.invocation_descriptors:
-        invocation_descriptors.append(
-                invdesc.copy(expands_footprint=footprint_generators is None))
+    for accdesc in invg.access_descriptors:
+        access_descriptors.append(
+                accdesc.copy(expands_footprint=footprint_generators is None))
 
-    if not invocation_descriptors:
+    if not access_descriptors:
         raise RuntimeError("no invocations of '%s' found" % subst_name)
 
     # }}}
@@ -704,9 +405,9 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
 
     expanding_usage_arg_deps = set()
 
-    for invdesc in invocation_descriptors:
-        if invdesc.expands_footprint:
-            for arg in invdesc.args:
+    for accdesc in access_descriptors:
+        if accdesc.expands_footprint:
+            for arg in accdesc.args:
                 expanding_usage_arg_deps.update(
                         get_dependencies(arg) & kernel.all_inames())
 
@@ -735,7 +436,7 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
     if storage_axes is None:
         storage_axes = (
                 list(extra_storage_axes)
-                + list(range(len(arg_names))))
+                + list(range(len(subst.arguments))))
 
     expr_subst_dict = {}
 
@@ -784,6 +485,16 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
 
     # }}}
 
+    # {{{ fill out access_descriptors[...].storage_axis_exprs
+
+    access_descriptors = [
+            accdesc.copy(
+                storage_axis_exprs=storage_axis_exprs(
+                    storage_axis_sources, accdesc.args))
+            for accdesc in access_descriptors]
+
+    # }}}
+
     expanding_inames = sweep_inames_set | frozenset(expanding_usage_arg_deps)
     assert expanding_inames <= kernel.all_inames()
 
@@ -805,37 +516,26 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
 
         # }}}
 
-        (non1_storage_axis_names, new_domain,
-                storage_base_indices, non1_storage_shape) = \
-                        get_access_info(kernel, domch.domain,
-                                storage_axis_names, storage_axis_sources,
-                                sweep_inames, invocation_descriptors)
+        abm = ArrayToBufferMap(kernel, domch.domain, sweep_inames,
+                access_descriptors, len(storage_axis_names))
 
-        from loopy.isl_helpers import convexify, boxify
-        if fetch_bounding_box:
-            new_domain = boxify(kernel.cache_manager, new_domain,
-                    non1_storage_axis_names, kernel.assumptions)
-        else:
-            new_domain = convexify(new_domain)
-
-        for saxis in storage_axis_names:
-            if saxis not in non1_storage_axis_names:
+        non1_storage_axis_names = []
+        for i, saxis in enumerate(storage_axis_names):
+            if abm.non1_storage_axis_flags[i]:
+                non1_storage_axis_names.append(saxis)
+            else:
                 del new_iname_to_tag[saxis]
 
-        new_kernel_domains = domch.get_domains_with(new_domain)
+        new_kernel_domains = domch.get_domains_with(
+                abm.augment_domain_with_sweep(
+                    domch.domain, non1_storage_axis_names,
+                    boxify_sweep=fetch_bounding_box))
+
     else:
         # leave kernel domains unchanged
         new_kernel_domains = kernel.domains
 
-        non1_storage_axis_names = ()
-        storage_base_indices = ()
-        non1_storage_shape = ()
-
-        # no index dependencies--every reference to the subst rule
-        # is necessarily in the footprint.
-
-        for invdesc in invocation_descriptors:
-            invdesc.is_in_footprint = True
+        abm = NoOpArrayToBufferMap()
 
     # {{{ set up compute insn
 
@@ -856,7 +556,7 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
     compute_expr = (SubstitutionMapper(
         make_subst_func(dict(
             (arg_name, zero_length_1_arg(arg_name)+bi)
-            for arg_name, bi in zip(storage_axis_names, storage_base_indices)
+            for arg_name, bi in zip(storage_axis_names, abm.storage_base_indices)
             )))
         (compute_expr))
 
@@ -870,10 +570,10 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
 
     # {{{ substitute rule into expressions in kernel (if within footprint)
 
-    invr = InvocationReplacer(kernel, subst_name, subst_tag, within,
-            invocation_descriptors,
+    invr = RuleInvocationReplacer(kernel, subst_name, subst_tag, within,
+            access_descriptors, abm,
             storage_axis_names, storage_axis_sources,
-            storage_base_indices, non1_storage_axis_names,
+            abm.storage_base_indices, non1_storage_axis_names,
             target_var_name)
 
     kernel = invr.map_kernel(kernel)
@@ -897,8 +597,8 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
     temp_var = lp.TemporaryVariable(
             name=target_var_name,
             dtype=dtype,
-            base_indices=(0,)*len(non1_storage_shape),
-            shape=tuple(non1_storage_shape),
+            base_indices=(0,)*len(abm.non1_storage_shape),
+            shape=tuple(abm.non1_storage_shape),
             is_local=temporary_is_local)
 
     new_temporary_variables[target_var_name] = temp_var
-- 
GitLab