From 9a3574b775840ca31eba9b306ff3a0ef4914effc Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Tue, 22 Sep 2015 20:32:48 -0500
Subject: [PATCH] Allow variable-length arrays at least during transformation

---
 loopy/__init__.py  | 18 ++++++++++++++++--
 loopy/check.py     | 12 ------------
 test/test_loopy.py | 17 +++++++++++++++++
 3 files changed, 33 insertions(+), 14 deletions(-)

diff --git a/loopy/__init__.py b/loopy/__init__.py
index 2809157eb..ff3a004d9 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -944,6 +944,18 @@ def _add_kernel_axis(kernel, axis_name, start, stop, base_inames):
             .insert_dims(dim_type.set, new_dim_idx, 1)
             .set_dim_name(dim_type.set, new_dim_idx, axis_name))
 
+    from loopy.symbolic import get_dependencies
+    deps = get_dependencies(start) | get_dependencies(stop)
+    assert deps <= kernel.all_params()
+
+    param_names = domain.get_var_names(dim_type.param)
+    for dep in deps:
+        if dep not in param_names:
+            new_dim_idx = domain.dim(dim_type.param)
+            domain = (domain
+                    .insert_dims(dim_type.param, new_dim_idx, 1)
+                    .set_dim_name(dim_type.param, new_dim_idx, dep))
+
     from loopy.isl_helpers import make_slab
     slab = make_slab(domain.get_space(), axis_name, start, stop)
 
@@ -1023,7 +1035,8 @@ def _process_footprint_subscripts(kernel, rule_name, sweep_inames,
 
 
 def add_prefetch(kernel, var_name, sweep_inames=[], dim_arg_names=None,
-        default_tag="l.auto", rule_name=None, temporary_name=None,
+        default_tag="l.auto", rule_name=None,
+        temporary_name=None, temporary_is_local=None,
         footprint_subscripts=None,
         fetch_bounding_box=False):
     """Prefetch all accesses to the variable *var_name*, with all accesses
@@ -1123,7 +1136,8 @@ def add_prefetch(kernel, var_name, sweep_inames=[], dim_arg_names=None,
             precompute_inames=dim_arg_names,
             default_tag=default_tag, dtype=arg.dtype,
             fetch_bounding_box=fetch_bounding_box,
-            temporary_name=temporary_name)
+            temporary_name=temporary_name,
+            temporary_is_local=temporary_is_local)
 
     # {{{ remove inames that were temporarily added by slice sweeps
 
diff --git a/loopy/check.py b/loopy/check.py
index 354c6bf35..2673e1bdf 100644
--- a/loopy/check.py
+++ b/loopy/check.py
@@ -29,7 +29,6 @@ from islpy import dim_type
 import islpy as isl
 from loopy.symbolic import WalkMapper
 from loopy.diagnostic import LoopyError, WriteRaceConditionWarning, warn
-from loopy.tools import is_integer
 
 import logging
 logger = logging.getLogger(__name__)
@@ -37,16 +36,6 @@ logger = logging.getLogger(__name__)
 
 # {{{ sanity checks run pre-scheduling
 
-def check_temp_variable_shapes_are_constant(kernel):
-    for tv in six.itervalues(kernel.temporary_variables):
-        if any(not is_integer(s_i) for s_i in tv.shape):
-            raise LoopyError("shape of temporary variable '%s' is not "
-                    "constant (but has to be since the size of "
-                    "the temporary needs to be known at build time). "
-                    "Use loopy.fix_parameters to set variables to "
-                    "constant values." % tv.name)
-
-
 def check_insn_attributes(kernel):
     all_insn_ids = set(insn.id for insn in kernel.instructions)
 
@@ -359,7 +348,6 @@ def pre_schedule_checks(kernel):
     try:
         logger.info("pre-schedule check %s: start" % kernel.name)
 
-        check_temp_variable_shapes_are_constant(kernel)
         check_for_orphaned_user_hardware_axes(kernel)
         check_for_double_use_of_hw_axes(kernel)
         check_insn_attributes(kernel)
diff --git a/test/test_loopy.py b/test/test_loopy.py
index 7cad35048..d7d7dc576 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -2174,6 +2174,23 @@ def test_to_batched(ctx_factory):
     bknl(queue, a=a, x=x)
 
 
+def test_variable_size_temporary():
+    knl = lp.make_kernel(
+         ''' { [i,j]: 0<=i,j<n } ''',
+         ''' out[i] = sum(j, a[i,j])''')
+
+    knl = lp.add_and_infer_dtypes(knl, {"a": np.float32})
+
+    knl = lp.add_prefetch(
+            knl, "a[:,:]", default_tag=None)
+
+    # Make sure that code generation succeeds even if
+    # there are variable-length arrays.
+    knl = lp.preprocess_kernel(knl)
+    for k in lp.generate_loop_schedules(knl):
+        lp.generate_code(k)
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])
-- 
GitLab