From 05f97b438b8893292c68e1bfea960bce1571bd6e Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sun, 20 May 2012 16:55:13 -0400
Subject: [PATCH] Initial work towards making loopy work on multiple
 (potentially nested) domains.

---
 MEMO                         |  15 ++
 loopy/codegen/__init__.py    |  12 +-
 loopy/codegen/control.py     |  14 +-
 loopy/codegen/instruction.py |   3 +-
 loopy/codegen/loop.py        |  17 +-
 loopy/creation.py            |  49 +++--
 loopy/isl_helpers.py         |   6 +
 loopy/kernel.py              | 342 +++++++++++++++++++++++++++--------
 loopy/schedule.py            |  34 +++-
 test/test_loopy.py           | 148 +++++++++++++++
 10 files changed, 536 insertions(+), 104 deletions(-)

diff --git a/MEMO b/MEMO
index dc1385727..4ba101b76 100644
--- a/MEMO
+++ b/MEMO
@@ -41,13 +41,21 @@ Things to consider
 To-do
 ^^^^^
 
+- Clean up loopy.kernel.
+
 - Group instructions by dependency/inames for scheduling, to
   increase sched. scalability
 
 - Multi-domain
+  - Incorporate loop-bound-mediated iname dependencies into domain
+    parenthood.
+
+  - Reenable codegen sanity check.
 
 - Kernel splitting (via what variables get computed in a kernel)
 
+- test_loopy.py: test_empty_reduction
+
 - What if no universally valid precompute base index expression is found?
   (test_intel_matrix_mul with n = 6*16, e.g.?)
 
@@ -120,6 +128,13 @@ Future ideas
 Dealt with
 ^^^^^^^^^^
 
+- relating to Multi-Domain
+  - Make sure that variables that enter into loop bounds are only written
+    exactly once. [DONE]
+
+  - Make sure that loop bound writes are scheduled before the relevant
+    loops. [DONE]
+
 - add_prefetch tagging
 
 - nbody GPU
diff --git a/loopy/codegen/__init__.py b/loopy/codegen/__init__.py
index b969d9896..537e821aa 100644
--- a/loopy/codegen/__init__.py
+++ b/loopy/codegen/__init__.py
@@ -122,9 +122,10 @@ class CodeGenerationState(object):
 
         self.c_code_mapper = c_code_mapper
 
-    def intersect(self, set):
+    def intersect(self, other):
+        new_impl, new_other = isl.align_two(self.implemented_domain, other)
         return CodeGenerationState(
-                self.implemented_domain & set,
+                new_impl & new_other,
                 self.c_code_mapper)
 
     def fix(self, iname, aff, space):
@@ -289,8 +290,7 @@ def generate_code(kernel, with_annotation=False,
 
     # }}}
 
-    from islpy import align_spaces
-    initial_implemented_domain = align_spaces(kernel.assumptions, kernel.domain)
+    initial_implemented_domain = kernel.assumptions
     codegen_state = CodeGenerationState(initial_implemented_domain, c_code_mapper=ccm)
 
     from loopy.codegen.loop import set_up_hw_parallel_loops
@@ -331,8 +331,8 @@ def generate_code(kernel, with_annotation=False,
                 )
             """))
 
-    from loopy.check import check_implemented_domains
-    assert check_implemented_domains(kernel, gen_code.implemented_domains)
+    #from loopy.check import check_implemented_domains
+    #assert check_implemented_domains(kernel, gen_code.implemented_domains)
 
     # {{{ handle preambles
 
diff --git a/loopy/codegen/control.py b/loopy/codegen/control.py
index a1a213add..1bd89ed69 100644
--- a/loopy/codegen/control.py
+++ b/loopy/codegen/control.py
@@ -159,14 +159,20 @@ def build_loop_nest(kernel, sched_index, codegen_state):
     from pytools import memoize_method
 
     class BoundsCheckCache:
-        def __init__(self, domain, impl_domain):
-            self.domain = domain
+        def __init__(self, kernel, impl_domain):
+            self.kernel = kernel
             self.impl_domain = impl_domain
 
         @memoize_method
         def __call__(self, check_inames):
+            if not check_inames:
+                return []
+
+            domain = isl.align_spaces(
+                    self.kernel.get_inames_domain(check_inames),
+                    self.impl_domain, obj_bigger_ok=True)
             from loopy.codegen.bounds import generate_bounds_checks
-            return generate_bounds_checks(self.domain,
+            return generate_bounds_checks(domain,
                     check_inames, self.impl_domain)
 
     def build_insn_group(sched_indices_and_cond_inames, codegen_state, done_group_lengths=set()):
@@ -183,7 +189,7 @@ def build_loop_nest(kernel, sched_index, codegen_state):
         # Keep growing schedule item group as long as group fulfills minimum
         # size requirement.
 
-        bounds_check_cache = BoundsCheckCache(kernel.domain, codegen_state.implemented_domain)
+        bounds_check_cache = BoundsCheckCache(kernel, codegen_state.implemented_domain)
 
         current_iname_set = cond_inames
 
diff --git a/loopy/codegen/instruction.py b/loopy/codegen/instruction.py
index 46a29fc50..1cf5a8561 100644
--- a/loopy/codegen/instruction.py
+++ b/loopy/codegen/instruction.py
@@ -18,8 +18,9 @@ def generate_instruction_code(kernel, insn, codegen_state):
     from cgen import Assign
     insn_code = Assign(ccm(insn.assignee), ccm(expr))
     from loopy.codegen.bounds import wrap_in_bounds_checks
+    insn_inames = kernel.insn_inames(insn)
     insn_code, impl_domain = wrap_in_bounds_checks(
-            ccm, kernel.domain, kernel.insn_inames(insn),
+            ccm, kernel.get_inames_domain(insn_inames), insn_inames,
             codegen_state.implemented_domain,
             insn_code)
 
diff --git a/loopy/codegen/loop.py b/loopy/codegen/loop.py
index d1ffea7a3..0267e7f20 100644
--- a/loopy/codegen/loop.py
+++ b/loopy/codegen/loop.py
@@ -9,9 +9,11 @@ from loopy.codegen.control import build_loop_nest
 
 
 def get_simple_loop_bounds(kernel, sched_index, iname, implemented_domain):
+    iname_domain = kernel.get_inames_domain(iname)
+
     from loopy.codegen.bounds import get_bounds_constraints, get_defined_inames
     lower_constraints_orig, upper_constraints_orig, equality_constraints_orig = \
-            get_bounds_constraints(kernel.domain, iname,
+            get_bounds_constraints(iname_domain, iname,
                     frozenset([iname])
                     | frozenset(get_defined_inames(kernel, sched_index+1)),
                     allow_parameters=True)
@@ -35,9 +37,9 @@ def get_slab_decomposition(kernel, iname, sched_index, codegen_state):
     lb_cns_orig, ub_cns_orig = get_simple_loop_bounds(kernel, sched_index, iname,
             codegen_state.implemented_domain)
 
-    lower_incr, upper_incr = kernel.iname_slab_increments.get(iname, (0, 0))
+    space = lb_cns_orig.space
 
-    iname_tp, iname_idx = kernel.iname_to_dim[iname]
+    lower_incr, upper_incr = kernel.iname_slab_increments.get(iname, (0, 0))
 
     if lower_incr or upper_incr:
         bounds = kernel.get_iname_bounds(iname)
@@ -60,9 +62,10 @@ def get_slab_decomposition(kernel, iname, sched_index, codegen_state):
 
         from loopy.isl_helpers import iname_rel_aff
 
+
         if lower_incr:
             assert lower_incr > 0
-            lower_slab = ("initial", isl.BasicSet.universe(kernel.space)
+            lower_slab = ("initial", isl.BasicSet.universe(space)
                     .add_constraint(lb_cns_orig)
                     .add_constraint(ub_cns_orig)
                     .add_constraint(
@@ -78,7 +81,7 @@ def get_slab_decomposition(kernel, iname, sched_index, codegen_state):
 
         if upper_incr:
             assert upper_incr > 0
-            upper_slab = ("final", isl.BasicSet.universe(kernel.space)
+            upper_slab = ("final", isl.BasicSet.universe(space)
                     .add_constraint(lb_cns_orig)
                     .add_constraint(ub_cns_orig)
                     .add_constraint(
@@ -98,7 +101,7 @@ def get_slab_decomposition(kernel, iname, sched_index, codegen_state):
             slabs.append(lower_slab)
         slabs.append((
             ("bulk",
-                (isl.BasicSet.universe(kernel.space)
+                (isl.BasicSet.universe(space)
                     .add_constraint(lower_bulk_bound)
                     .add_constraint(upper_bulk_bound)))))
         if upper_slab:
@@ -108,7 +111,7 @@ def get_slab_decomposition(kernel, iname, sched_index, codegen_state):
 
     else:
         return [("bulk",
-            (isl.BasicSet.universe(kernel.space)
+            (isl.BasicSet.universe(space)
             .add_constraint(lb_cns_orig)
             .add_constraint(ub_cns_orig)))]
 
diff --git a/loopy/creation.py b/loopy/creation.py
index 28be2b095..f2a65064e 100644
--- a/loopy/creation.py
+++ b/loopy/creation.py
@@ -5,7 +5,7 @@ from loopy.symbolic import IdentityMapper
 
 # {{{ sanity checking
 
-def check_kernel(knl):
+def check_for_nonexistent_iname_deps(knl):
     for insn in knl.instructions:
         if not set(insn.forced_iname_deps) <= knl.all_inames():
             raise ValueError("In instruction '%s': "
@@ -15,6 +15,23 @@ def check_kernel(knl):
                         ",".join(
                             set(insn.forced_iname_deps)-knl.all_inames())))
 
+def check_for_multiple_writes_to_loop_bounds(knl):
+    from isl import dim_type
+
+    domain_parameters = set()
+    for dom in knl.domains:
+        domain_parameters.update(dom.get_space().get_var_dict(dim_type.param))
+
+    temp_var_domain_parameters = domain_parameters & set(
+            knl.temporary_variables)
+
+    wmap = knl.writer_map()
+    for tvpar in temp_var_domain_parameters:
+        par_writers = wmap[tvpar]
+        if len(par_writers) != 1:
+            raise RuntimeError("there must be exactly one write to data-dependent "
+                    "domain parameter '%s' (found %d)" % (tvpar, len(par_writers)))
+
 # }}}
 
 # {{{ expand common subexpressions into assignments
@@ -99,9 +116,7 @@ def create_temporaries(knl):
     new_temp_vars = knl.temporary_variables.copy()
 
     for insn in knl.instructions:
-        from loopy.kernel import (
-                find_var_base_indices_and_shape_from_inames,
-                TemporaryVariable)
+        from loopy.kernel import TemporaryVariable
 
         if insn.temp_var_type is not None:
             assignee_name = insn.get_assignee_var_name()
@@ -120,8 +135,8 @@ def create_temporaries(knl):
                 assignee_indices.append(index_expr.name)
 
             base_indices, shape = \
-                    find_var_base_indices_and_shape_from_inames(
-                            knl.domain, assignee_indices, knl.cache_manager)
+                    knl.find_var_base_indices_and_shape_from_inames(
+                            assignee_indices, knl.cache_manager)
 
             new_temp_vars[assignee_name] = TemporaryVariable(
                     name=assignee_name,
@@ -187,7 +202,7 @@ def duplicate_reduction_inames(kernel):
 
     # }}}
 
-    new_domain = kernel.domain
+    new_domains = kernel.domains
     new_insns = []
 
     new_iname_to_tag = kernel.iname_to_tag.copy()
@@ -203,13 +218,13 @@ def duplicate_reduction_inames(kernel):
 
         from loopy.isl_helpers import duplicate_axes
         for old, new in zip(old_insn_inames, new_insn_inames):
-            new_domain = duplicate_axes(new_domain, [old], [new])
+            new_domains = duplicate_axes(new_domains, [old], [new])
             if old in kernel.iname_to_tag:
                 new_iname_to_tag[new] = kernel.iname_to_tag[old]
 
     return kernel.copy(
             instructions=new_insns,
-            domain=new_domain,
+            domains=new_domains,
             iname_to_tag=new_iname_to_tag)
 
 # }}}
@@ -218,7 +233,7 @@ def duplicate_reduction_inames(kernel):
 
 def duplicate_inames(knl):
     new_insns = []
-    new_domain = knl.domain
+    new_domains = knl.domains
     new_iname_to_tag = knl.iname_to_tag.copy()
 
     newly_created_vars = set()
@@ -256,7 +271,7 @@ def duplicate_inames(knl):
             newly_created_vars.update(new_inames)
 
             from loopy.isl_helpers import duplicate_axes
-            new_domain = duplicate_axes(new_domain, inames_to_duplicate, new_inames)
+            new_domains = duplicate_axes(new_domains, inames_to_duplicate, new_inames)
 
             from loopy.symbolic import SubstitutionMapper
             from pymbolic.mapper.substitutor import make_subst_func
@@ -284,7 +299,7 @@ def duplicate_inames(knl):
 
     return knl.copy(
             instructions=new_insns,
-            domain=new_domain,
+            domains=new_domains,
             iname_to_tag=new_iname_to_tag)
 # }}}
 
@@ -304,7 +319,7 @@ def make_kernel(*args, **kwargs):
             knl.iname_to_tag_requests).copy(
                     iname_to_tag_requests=[])
 
-    check_kernel(knl)
+    check_for_nonexistent_iname_deps(knl)
 
     knl = create_temporaries(knl)
     knl = duplicate_reduction_inames(knl)
@@ -320,6 +335,14 @@ def make_kernel(*args, **kwargs):
 
     knl = expand_cses(knl)
 
+    # -------------------------------------------------------------------------
+    # Ordering dependency:
+    # -------------------------------------------------------------------------
+    # Must create temporary before checking for writes to temporary variables
+    # that are domain parameters.
+    # -------------------------------------------------------------------------
+    check_for_multiple_writes_to_loop_bounds(knl)
+
     return knl
 
 # }}}
diff --git a/loopy/isl_helpers.py b/loopy/isl_helpers.py
index a509b8a30..3eb1ccba2 100644
--- a/loopy/isl_helpers.py
+++ b/loopy/isl_helpers.py
@@ -200,6 +200,11 @@ def static_value_of_pw_aff(pw_aff, constants_only, context=None):
 
 
 def duplicate_axes(isl_obj, duplicate_inames, new_inames):
+    if isinstance(isl_obj, list):
+        return [
+                duplicate_axes(i, duplicate_inames, new_inames)
+                for i in isl_obj]
+
     if not duplicate_inames:
         return isl_obj
 
@@ -244,6 +249,7 @@ def duplicate_axes(isl_obj, duplicate_inames, new_inames):
 
 
 
+
 def is_nonnegative(expr, over_set):
     space = over_set.get_space()
     from loopy.symbolic import aff_from_expr
diff --git a/loopy/kernel.py b/loopy/kernel.py
index 305098315..b8b1b4b88 100644
--- a/loopy/kernel.py
+++ b/loopy/kernel.py
@@ -525,11 +525,58 @@ def _generate_unique_possibilities(prefix):
         yield "%s_%d" % (prefix, try_num)
         try_num += 1
 
+_IDENTIFIER_RE = re.compile(r"\b([a-zA-Z_][a-zA-Z0-9_]*)\b")
+
+def _gather_identifiers(s):
+    return set(_IDENTIFIER_RE.findall(s))
+
+def _parse_domains(ctx, args_and_vars, domains):
+    result = []
+    available_parameters = args_and_vars.copy()
+    used_inames = set()
+
+    for dom in domains:
+        if isinstance(dom, str):
+            if not dom.lstrip().startswith("["):
+                # i.e. if no parameters are already given
+                ids = _gather_identifiers(dom)
+                parameters = ids & available_parameters
+                dom = "[%s] -> %s" % (",".join(parameters), dom)
+
+            try:
+                dom = isl.Set.read_from_str(ctx, dom)
+            except:
+                print "failed to parse domain '%s'" % dom
+                raise
+        else:
+            assert isinstance(dom, (isl.Set, isl.BasicSet))
+            # assert dom.get_ctx() == ctx
+
+        for i_iname in xrange(dom.dim(dim_type.set)):
+            iname = dom.get_dim_name(dim_type.set, i_iname)
+
+            if iname is None:
+                raise RuntimeError("domain '%s' provided no iname at index "
+                        "%d (redefined iname?)" % (dom, i_iname))
+
+            if iname in used_inames:
+                raise RuntimeError("domain '%s' redefines iname '%s' "
+                        "that is part of a previous domain" % (dom, iname))
+
+            used_inames.add(iname)
+            available_parameters.add(iname)
+
+        result.append(dom)
+
+    return result
+
+
+
 
 class LoopKernel(Record):
     """
     :ivar device: :class:`pyopencl.Device`
-    :ivar domain: :class:`islpy.BasicSet`
+    :ivar domains: :class:`islpy.BasicSet`
     :ivar instructions:
     :ivar args:
     :ivar schedule:
@@ -580,7 +627,7 @@ class LoopKernel(Record):
     :ivar iname_to_tag_requests:
     """
 
-    def __init__(self, device, domain, instructions, args=None, schedule=None,
+    def __init__(self, device, domains, instructions, args=[], schedule=None,
             name="loopy_kernel",
             preambles=[],
             preamble_generators=[default_preamble_generator],
@@ -614,12 +661,10 @@ class LoopKernel(Record):
         if cache_manager is None:
             cache_manager = SetOperationCacheManager()
 
-        if isinstance(domain, str):
-            ctx = isl.Context()
-            domain = isl.Set.read_from_str(ctx, domain)
-
         iname_to_tag_requests = {}
 
+        # {{{ parse instructions
+
         INAME_ENTRY_RE = re.compile(
                 r"^\s*(?P<iname>\w+)\s*(?:\:\s*(?P<tag>[\w.]+))?\s*$")
         INSN_RE = re.compile(
@@ -720,9 +765,9 @@ class LoopKernel(Record):
                     raise RuntimeError("left hand side of assignment '%s' must "
                             "be variable or subscript" % lhs)
 
-                insns.append(
+                parsed_instructions.append(
                         Instruction(
-                            id=self.make_unique_instruction_id(insns, based_on=label),
+                            id=self.make_unique_instruction_id(parsed_instructions, based_on=label),
                             insn_deps=insn_deps,
                             forced_iname_deps=forced_iname_deps,
                             assignee=lhs, expression=rhs,
@@ -756,8 +801,8 @@ class LoopKernel(Record):
         def parse_if_necessary(insn):
             if isinstance(insn, Instruction):
                 if insn.id is None:
-                    insn = insn.copy(id=self.make_unique_instruction_id(insns))
-                insns.append(insn)
+                    insn = insn.copy(id=self.make_unique_instruction_id(parsed_instructions))
+                parsed_instructions.append(insn)
                 return
 
             if not isinstance(insn, str):
@@ -768,10 +813,7 @@ class LoopKernel(Record):
             for insn in expand_defines(insn, defines):
                 parse_insn(insn)
 
-
-        # }}}
-
-        insns = []
+        parsed_instructions = []
 
         substitutions = substitutions.copy()
 
@@ -779,23 +821,57 @@ class LoopKernel(Record):
             # must construct list one-by-one to facilitate unique id generation
             parse_if_necessary(insn)
 
-        if len(set(insn.id for insn in insns)) != len(insns):
+        if len(set(insn.id for insn in parsed_instructions)) != len(parsed_instructions):
             raise RuntimeError("instruction ids do not appear to be unique")
 
+        # }}}
+
+        # Ordering dependency:
+        # Domain construction needs to know what temporary variables are
+        # available. That information can only be obtained once instructions
+        # are parsed.
+
+        # {{{ construct domains
+
+        if isinstance(domains, str):
+            domains = [domains]
+
+        ctx = isl.Context()
+        scalar_arg_names = set(arg.name for arg in args if isinstance(arg, ScalarArg))
+        var_names = (
+                set(temporary_variables)
+                | set(insn.get_assignee_var_name()
+                    for insn in parsed_instructions
+                    if insn.temp_var_type is not None))
+        domains = _parse_domains(ctx, scalar_arg_names | var_names, domains)
+
+        # }}}
+
+        # {{{ process assumptions
+
         if assumptions is None:
-            assumptions_space = domain.get_space().params()
+            assumptions_space = domains[0].get_space()
             assumptions = isl.Set.universe(assumptions_space)
-
         elif isinstance(assumptions, str):
-            s = domain.get_space()
-            assumptions = isl.BasicSet.read_from_str(domain.get_ctx(),
-                    "[%s] -> { : %s}"
-                    % (",".join(s.get_dim_name(dim_type.param, i)
-                        for i in range(s.dim(dim_type.param))),
-                        assumptions))
+            all_inames = set()
+            all_params = set()
+            for dom in domains:
+                all_inames.update(dom.get_var_names(dim_type.set))
+                all_params.update(dom.get_var_names(dim_type.param))
+
+            domain_parameters = all_params-all_inames
+
+            assumptions_set_str = "[%s] -> { : %s}" \
+                    % (",".join(s for s in domain_parameters),
+                        assumptions)
+            assumptions = isl.BasicSet.read_from_str(domains[0].get_ctx(),
+                    assumptions_set_str)
+
+        # }}}
 
         Record.__init__(self,
-                device=device,  domain=domain, instructions=insns,
+                device=device, domains=domains,
+                instructions=parsed_instructions,
                 args=args,
                 schedule=schedule,
                 name=name,
@@ -831,6 +907,8 @@ class LoopKernel(Record):
 
     # }}}
 
+    # {{{ unique ids
+
     def make_unique_instruction_id(self, insns=None, based_on="insn", extra_used_ids=set()):
         if insns is None:
             insns = self.instructions
@@ -841,16 +919,132 @@ class LoopKernel(Record):
             if id_str not in used_ids:
                 return id_str
 
+    # }}}
+
+    # {{{ name listing
+
     @memoize_method
     def all_inames(self):
-        from islpy import dim_type
-        return set(self.space.get_var_dict(dim_type.set).iterkeys())
+        result = set()
+        for dom in self.domains:
+            result.update(dom.get_var_names(dim_type.set))
+        return result
 
     @memoize_method
     def non_iname_variable_names(self):
         return (set(self.arg_dict.iterkeys())
                 | set(self.temporary_variables.iterkeys()))
 
+    # }}}
+
+    # {{{ domain handling
+
+    @memoize_method
+    def parents_per_domain(self):
+        """Return a list corresponding to self.domains (by index)
+        containing domain indices which are nested around this
+        domain.
+
+        Each domains nest list walks from the leaves of the nesting
+        tree to the root.
+        """
+
+        domain_parents = []
+        iname_set_stack = []
+        result = []
+
+        for dom in self.domains:
+            parameters = set(dom.get_var_names(dim_type.param))
+            inames = set(dom.get_var_names(dim_type.set))
+
+            discard_level_count = 0
+            while discard_level_count < len(iname_set_stack):
+                last_inames = iname_set_stack[-1-discard_level_count]
+
+                if last_inames & parameters:
+                    break
+                else:
+                    discard_level_count += 1
+
+            if discard_level_count:
+                iname_set_stack = iname_set_stack[:-discard_level_count]
+
+            if domain_parents:
+                parent = len(result)-1
+            else:
+                parent = None
+
+            for i in range(discard_level_count):
+                assert parent is not None
+                parent = domain_parents[parent]
+
+            # found this domain's parent
+            domain_parents.append(parent)
+
+            # keep walking up tree to make result
+            dom_result = []
+            while parent is not None:
+                dom_result.insert(0, parent)
+                parent = domain_parents[parent]
+
+            result.append(dom_result)
+
+            if iname_set_stack:
+                parent_inames = iname_set_stack[-1]
+            else:
+                parent_inames = set()
+            iname_set_stack.append(parent_inames | inames)
+
+        return result
+
+    @memoize_method
+    def _get_home_domain_map(self):
+        return dict(
+                (iname, i_domain)
+                for i_domain, dom in enumerate(self.domains)
+                for iname in dom.get_var_names(dim_type.set))
+
+    def get_home_domain_index(self, iname):
+        return self._get_home_domain_map()[iname]
+
+    @memoize_method
+    def combine_domains(self, domains):
+        assert isinstance(domains, frozenset) # for caching
+
+        result = None
+        assert domains
+        for dom_index in domains:
+            dom = self.domains[dom_index]
+            if result is None:
+                result = dom
+            else:
+                aligned_result, aligned_dom = isl.align_two(result, dom)
+                result = aligned_result & aligned_dom
+
+        return result
+
+    def get_effective_domain(self, domain_index):
+        return self.combine_domains(
+                frozenset([domain_index]
+                    + self.get_parents_per_domain()[domain_index]))
+
+    def get_inames_domain(self, inames):
+        if isinstance(inames, str):
+            inames = [inames]
+
+        hdm = self._get_home_domain_map()
+        ppd = self.parents_per_domain()
+
+        domain_indices = set()
+        for iname in inames:
+            home_domain_index = hdm[iname]
+            domain_indices.add(home_domain_index)
+            domain_indices.update(ppd[home_domain_index])
+
+        return self.combine_domains(frozenset(domain_indices))
+
+    # }}}
+
     @memoize_method
     def all_insn_inames(self):
         """Return a mapping from instruction ids to inames inside which
@@ -945,11 +1139,6 @@ class LoopKernel(Record):
 
         return result
 
-    @property
-    @memoize_method
-    def iname_to_dim(self):
-        return self.domain.get_space().get_var_dict()
-
     @memoize_method
     def get_written_variables(self):
         return set(
@@ -962,7 +1151,7 @@ class LoopKernel(Record):
                 set(self.temporary_variables.iterkeys())
                 | set(self.substitutions.iterkeys())
                 | set(arg.name for arg in self.args)
-                | set(self.iname_to_dim.keys()))
+                | set(self.all_inames()))
 
     def make_unique_var_name(self, based_on="var", extra_used_vars=set()):
         used_vars = self.all_variable_names() | extra_used_vars
@@ -989,11 +1178,6 @@ class LoopKernel(Record):
     def id_to_insn(self):
         return dict((insn.id, insn) for insn in self.instructions)
 
-    @property
-    @memoize_method
-    def space(self):
-        return self.domain.get_space()
-
     @property
     @memoize_method
     def arg_dict(self):
@@ -1005,8 +1189,9 @@ class LoopKernel(Record):
         if self.args is None:
             return []
         else:
-            loop_arg_names = [self.space.get_dim_name(dim_type.param, i)
-                    for i in range(self.space.dim(dim_type.param))]
+            from pytools import flatten
+            loop_arg_names = list(flatten(dom.get_var_names(dim_type.param)
+                    for dom in self.domains))
             return [arg.name for arg in self.args if isinstance(arg, ScalarArg)
                     if arg.name in loop_arg_names]
 
@@ -1037,6 +1222,29 @@ class LoopKernel(Record):
                 upper_bound_pw_aff=upper_bound_pw_aff,
                 size=size)
 
+    def find_var_base_indices_and_shape_from_inames(
+            self, inames, cache_manager, context=None):
+        base_indices = []
+        shape = []
+
+        for iname in inames:
+            domain = self.get_inames_domain(iname)
+            iname_to_dim = domain.space.get_var_dict()
+            lower_bound_pw_aff = cache_manager.dim_min(domain, iname_to_dim[iname][1])
+            upper_bound_pw_aff = cache_manager.dim_max(domain, iname_to_dim[iname][1])
+
+            from loopy.isl_helpers import static_max_of_pw_aff, static_value_of_pw_aff
+            from loopy.symbolic import pw_aff_to_expr
+
+            shape.append(pw_aff_to_expr(static_max_of_pw_aff(
+                    upper_bound_pw_aff - lower_bound_pw_aff + 1, constants_only=True,
+                    context=context)))
+            base_indices.append(pw_aff_to_expr(
+                static_value_of_pw_aff(lower_bound_pw_aff, constants_only=False,
+                    context=context)))
+
+        return base_indices, shape
+
     @memoize_method
     def get_constant_iname_length(self, iname):
         from loopy.isl_helpers import static_max_of_pw_aff
@@ -1103,8 +1311,6 @@ class LoopKernel(Record):
             size_list = []
             sorted_axes = sorted(size_dict.iterkeys())
 
-            zero_aff = isl.Aff.zero_on_domain(self.space.params())
-
             while sorted_axes or forced_sizes:
                 if sorted_axes:
                     cur_axis = sorted_axes.pop(0)
@@ -1113,16 +1319,14 @@ class LoopKernel(Record):
 
                 if len(size_list) in forced_sizes:
                     size_list.append(
-                            isl.PwAff.from_aff(
-                                zero_aff + forced_sizes.pop(len(size_list))))
+                           forced_sizes.pop(len(size_list)))
                     continue
 
                 assert cur_axis is not None
 
-                while cur_axis > len(size_list):
+                if cur_axis > len(size_list):
                     raise RuntimeError("%s axis %d unused" % (
                         which, len(size_list)))
-                    size_list.append(zero_aff + 1)
 
                 size_list.append(size_dict[cur_axis])
 
@@ -1161,8 +1365,12 @@ class LoopKernel(Record):
         always nested around them.
         """
         result = {}
+
+        # {{{ examine instructions
+
         iname_to_insns = self.iname_to_insns()
 
+        # examine pairs of all inames--O(n**2), I know.
         for inner_iname in self.all_inames():
             result[inner_iname] = set()
             for outer_iname in self.all_inames():
@@ -1172,6 +1380,20 @@ class LoopKernel(Record):
                 if iname_to_insns[inner_iname] < iname_to_insns[outer_iname]:
                     result[inner_iname].add(outer_iname)
 
+        # }}}
+
+        # {{{ examine domains
+
+        for i_dom, (dom, parent_indices) in enumerate(
+                zip(self.domains, self.parents_per_domain())):
+            for parent_index in parent_indices:
+                for iname in dom.get_var_names(dim_type.set):
+                    parent = self.domains[parent_index]
+                    for parent_iname in parent.get_var_names(dim_type.set):
+                        result[iname].add(parent_iname)
+
+        # }}}
+
         return result
 
     def map_expressions(self, func, exclude_instructions=False):
@@ -1197,8 +1419,9 @@ class LoopKernel(Record):
             lines.append("%s: %s" % (iname, self.iname_to_tag.get(iname)))
 
         lines.append(sep)
-        lines.append("DOMAIN:")
-        lines.append(str(self.domain))
+        lines.append("DOMAINS:")
+        for dom, parents in zip(self.domains, self.parents_per_domain()):
+            lines.append(str(dom))
 
         if self.substitutions:
             lines.append(sep)
@@ -1306,31 +1529,6 @@ def find_all_insn_inames(instructions, all_inames,
 
 
 
-def find_var_base_indices_and_shape_from_inames(
-        domain, inames, cache_manager, context=None):
-    base_indices = []
-    shape = []
-
-    iname_to_dim = domain.get_space().get_var_dict()
-    for iname in inames:
-        lower_bound_pw_aff = cache_manager.dim_min(domain, iname_to_dim[iname][1])
-        upper_bound_pw_aff = cache_manager.dim_max(domain, iname_to_dim[iname][1])
-
-        from loopy.isl_helpers import static_max_of_pw_aff, static_value_of_pw_aff
-        from loopy.symbolic import pw_aff_to_expr
-
-        shape.append(pw_aff_to_expr(static_max_of_pw_aff(
-                upper_bound_pw_aff - lower_bound_pw_aff + 1, constants_only=True,
-                context=context)))
-        base_indices.append(pw_aff_to_expr(
-            static_value_of_pw_aff(lower_bound_pw_aff, constants_only=False,
-                context=context)))
-
-    return base_indices, shape
-
-
-
-
 def get_dot_dependency_graph(kernel, iname_cluster=False, iname_edge=True):
     lines = []
     for insn in kernel.instructions:
diff --git a/loopy/schedule.py b/loopy/schedule.py
index 686f6f8e7..a6bc240c4 100644
--- a/loopy/schedule.py
+++ b/loopy/schedule.py
@@ -446,9 +446,41 @@ def generate_loop_schedules_internal(kernel, loop_priority, schedule=[], allow_b
         useful_loops = []
 
         for iname in needed_inames:
-            if not kernel.loop_nest_map()[iname] <= active_inames_set | parallel_inames:
+
+            # {{{ check if scheduling this iname now is allowed/plausible
+
+            currently_accessible_inames = active_inames_set | parallel_inames
+            if not kernel.loop_nest_map()[iname] <= currently_accessible_inames:
+                continue
+
+            iname_home_domain = kernel.domains[kernel.get_home_domain_index(iname)]
+            from islpy import dim_type
+            iname_home_domain_params = set(iname_home_domain.get_var_names(dim_type.param))
+
+            # The previous check should have ensured this is true, because
+            # Kernel.loop_nest_map takes the domain dependency graph into
+            # consideration.
+            assert (iname_home_domain_params & kernel.all_inames()
+                    <= currently_accessible_inames)
+
+            # Check if any parameters are temporary variables, and if so, if their
+            # writes have already been scheduled.
+
+            data_dep_written = True
+            for domain_par in (
+                    iname_home_domain_params
+                    &
+                    set(kernel.temporary_variables)):
+                writer_insn, = kernel.writer_map()[domain_par]
+                if writer_insn not in scheduled_insn_ids:
+                    data_dep_written = False
+                    break
+
+            if not data_dep_written:
                 continue
 
+            # }}}
+
             # {{{ determine if that gets us closer to being able to schedule an insn
 
             useful = False
diff --git a/test/test_loopy.py b/test/test_loopy.py
index a40a157f4..e07854b7c 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -181,6 +181,154 @@ def test_argmax(ctx_factory):
 
 
 
+def test_empty_reduction(ctx_factory):
+    dtype = np.dtype(np.float32)
+    ctx = ctx_factory()
+    queue = cl.CommandQueue(ctx)
+
+    knl = lp.make_kernel(ctx.devices[0],
+            [
+                "{[i]: 0<=i<20}",
+                "{[j]: 0<=j<0}"
+                ],
+            [
+                "a[i] = sum(j, j)",
+                ],
+            [
+                lp.GlobalArg("a", dtype, (20,)),
+                ])
+    print knl
+
+    cknl = lp.CompiledKernel(ctx, knl)
+    cknl.print_code()
+
+    evt, (a,) = cknl(queue)
+
+
+
+
+
+def test_nested_dependent_reduction(ctx_factory):
+    dtype = np.dtype(np.float32)
+    ctx = ctx_factory()
+    queue = cl.CommandQueue(ctx)
+
+    knl = lp.make_kernel(ctx.devices[0],
+            [
+                "{[i]: 0<=i<20}",
+                "{[j]: 0<=j<i+sumlen}"
+                ],
+            [
+                "<> sumlen = l[i]",
+                "a[i] = sum(j, j)",
+                ],
+            [
+                lp.GlobalArg("a", dtype, (20,)),
+                lp.GlobalArg("l", np.int32, (20,)),
+                ])
+    print knl
+    1/0
+
+    cknl = lp.CompiledKernel(ctx, knl)
+    cknl.print_code()
+
+    evt, (a,) = cknl(queue)
+
+
+
+
+
+def test_dependent_loop_bounds(ctx_factory):
+    dtype = np.dtype(np.float32)
+    ctx = ctx_factory()
+
+    knl = lp.make_kernel(ctx.devices[0],
+            [
+                "{[i]: 0<=i<n}",
+                "{[jj]: 0<=jj<row_len}",
+                ],
+            [
+                "<> row_len = a_rowstarts[i+1] - a_rowstarts[i]",
+                "ax[i] = sum(jj, a_values[a_rowstarts[i]+jj])",
+                ],
+            [
+                lp.GlobalArg("a_rowstarts", np.int32),
+                lp.GlobalArg("a_indices", np.int32),
+                lp.GlobalArg("a_values", dtype),
+                lp.GlobalArg("x", dtype),
+                lp.GlobalArg("ax", dtype),
+                lp.ScalarArg("n", np.int32),
+                ],
+            assumptions="n>=1 and row_len>=1")
+
+    cknl = lp.CompiledKernel(ctx, knl)
+    print "---------------------------------------------------"
+    cknl.print_code()
+    print "---------------------------------------------------"
+
+
+
+
+def test_dependent_loop_bounds_2(ctx_factory):
+    dtype = np.dtype(np.float32)
+    ctx = ctx_factory()
+
+    knl = lp.make_kernel(ctx.devices[0],
+            "[n,row_len] -> {[i,jj]: 0<=i<n and 0<=jj<row_len}",
+            [
+                "<> row_start = a_rowstarts[i]",
+                "<> row_len = a_rowstarts[i+1] - row_start",
+                "ax[i] = sum(jj, a_values[row_start+jj])",
+                ],
+            [
+                lp.GlobalArg("a_rowstarts", np.int32),
+                lp.GlobalArg("a_indices", np.int32),
+                lp.GlobalArg("a_values", dtype),
+                lp.GlobalArg("x", dtype),
+                lp.GlobalArg("ax", dtype),
+                lp.ScalarArg("n", np.int32),
+                ],
+            assumptions="n>=1 and row_len>=1")
+
+    knl = lp.split_dimension(knl, "i", 128, outer_tag="g.0",
+            inner_tag="l.0")
+    cknl = lp.CompiledKernel(ctx, knl)
+    print "---------------------------------------------------"
+    cknl.print_code()
+    print "---------------------------------------------------"
+
+
+
+
+
+def test_dependent_loop_bounds_3(ctx_factory):
+    dtype = np.dtype(np.float32)
+    ctx = ctx_factory()
+
+    knl = lp.make_kernel(ctx.devices[0],
+            "[n,row_len] -> {[i,j]: 0<=i<n and 0<=j<row_len}",
+            [
+                "<> row_len = a_row_lengths[i]",
+                "a[i,j] = 1",
+                ],
+            [
+                lp.GlobalArg("a_row_lengths", np.int32),
+                lp.GlobalArg("a", dtype, shape=("n,n"), order="C"),
+                lp.ScalarArg("n", np.int32),
+                ])
+
+    knl = lp.split_dimension(knl, "i", 128, outer_tag="g.0",
+            inner_tag="l.0")
+    knl = lp.split_dimension(knl, "j", 128, outer_tag="g.1",
+            inner_tag="l.1")
+    cknl = lp.CompiledKernel(ctx, knl)
+    print "---------------------------------------------------"
+    cknl.print_code()
+    print "---------------------------------------------------"
+
+
+
+
 if __name__ == "__main__":
     import sys
     if len(sys.argv) > 1:
-- 
GitLab