From f6fe685cb37085cb2c681e9da7e7f126f33a94b8 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sun, 15 Feb 2015 13:10:32 -0600
Subject: [PATCH] Support unspecified shape axes (Fortran '*' axes)

---
 loopy/auto_test.py                   | 23 ++++++++++++----
 loopy/check.py                       | 30 ++++++++++++++-------
 loopy/compiled.py                    | 40 ++++++++++++++++++++++++----
 loopy/frontend/fortran/translator.py | 14 +++++++---
 loopy/kernel/array.py                | 17 ++++++++++--
 loopy/kernel/creation.py             |  8 +++++-
 loopy/symbolic.py                    |  2 +-
 test/test_fortran.py                 | 27 +++++++++++++++++++
 8 files changed, 135 insertions(+), 26 deletions(-)

diff --git a/loopy/auto_test.py b/loopy/auto_test.py
index dc304780c..ecd74c1a6 100644
--- a/loopy/auto_test.py
+++ b/loopy/auto_test.py
@@ -44,6 +44,19 @@ def is_dtype_supported(dtype):
     return dtype.kind in "biufc"
 
 
+def evaluate_shape(shape, context):
+    from pymbolic import evaluate
+
+    result = []
+    for saxis in shape:
+        if saxis is None:
+            result.append(saxis)
+        else:
+            result.append(evaluate(saxis, context))
+
+    return tuple(result)
+
+
 # {{{ create random argument arrays for testing
 
 def fill_rand(ary):
@@ -96,11 +109,11 @@ def make_ref_args(kernel, impl_arg_info, queue, parameters, fill_value):
             ref_arg_data.append(None)
 
         elif arg.arg_class is GlobalArg or arg.arg_class is ImageArg:
-            if arg.shape is None:
-                raise LoopyError("arrays need known shape to use automatic "
-                        "testing")
+            if arg.shape is None or any(saxis is None for saxis in arg.shape):
+                raise LoopyError("array '%s' needs known shape to use automatic "
+                        "testing" % arg.name)
 
-            shape = evaluate(arg.unvec_shape, parameters)
+            shape = evaluate_shape(arg.unvec_shape, parameters)
             dtype = kernel_arg.dtype
 
             is_output = arg.base_name in kernel.get_written_variables()
@@ -206,7 +219,7 @@ def make_args(kernel, impl_arg_info, queue, ref_arg_data, parameters,
                 raise NotImplementedError("write-mode images not supported in "
                         "automatic testing")
 
-            shape = evaluate(arg.unvec_shape, parameters)
+            shape = evaluate_shape(arg.unvec_shape, parameters)
             assert shape == arg_desc.ref_shape
 
             # must be contiguous
diff --git a/loopy/check.py b/loopy/check.py
index 09ca0c576..477a6336f 100644
--- a/loopy/check.py
+++ b/loopy/check.py
@@ -275,8 +275,13 @@ class _AccessCheckMapper(WalkMapper):
             from loopy.symbolic import get_dependencies, get_access_range
 
             available_vars = set(self.domain.get_var_dict())
+            shape_deps = set()
+            for shape_axis in shape:
+                if shape_axis is not None:
+                    shape_deps.update(get_dependencies(shape_axis))
+
             if not (get_dependencies(subscript) <= available_vars
-                    and get_dependencies(shape) <= available_vars):
+                    and shape_deps <= available_vars):
                 return
 
             if len(subscript) != len(shape):
@@ -297,12 +302,15 @@ class _AccessCheckMapper(WalkMapper):
 
             shape_domain = isl.BasicSet.universe(access_range.get_space())
             for idim in range(len(subscript)):
-                from loopy.isl_helpers import make_slab
-                slab = make_slab(
-                        shape_domain.get_space(), (dim_type.in_, idim),
-                        0, shape[idim])
+                shape_axis = shape[idim]
+
+                if shape_axis is not None:
+                    from loopy.isl_helpers import make_slab
+                    slab = make_slab(
+                            shape_domain.get_space(), (dim_type.in_, idim),
+                            0, shape_axis)
 
-                shape_domain = shape_domain.intersect(slab)
+                    shape_domain = shape_domain.intersect(slab)
 
             if not access_range.is_subset(shape_domain):
                 raise LoopyError("'%s' in instruction '%s' "
@@ -391,11 +399,15 @@ def check_that_shapes_and_strides_are_arguments(kernel):
     for arg in kernel.args:
         if isinstance(arg, ArrayBase):
             if isinstance(arg.shape, tuple):
-                deps = get_dependencies(arg.shape)
-                if not deps <= integer_arg_names:
+                shape_deps = set()
+                for shape_axis in arg.shape:
+                    if shape_axis is not None:
+                        shape_deps.update(get_dependencies(shape_axis))
+
+                if not shape_deps <= integer_arg_names:
                     raise LoopyError("'%s' has a shape that depends on "
                             "non-argument(s): %s" % (
-                                arg.name, ", ".join(deps-integer_arg_names)))
+                                arg.name, ", ".join(shape_deps-integer_arg_names)))
 
             if arg.dim_tags is None:
                 continue
diff --git a/loopy/compiled.py b/loopy/compiled.py
index 463e24921..c01981f61 100644
--- a/loopy/compiled.py
+++ b/loopy/compiled.py
@@ -477,14 +477,44 @@ def generate_array_arg_setup(gen, kernel, impl_arg_info, options):
                             "(got: %%s, expected: %s)\" %% %s.dtype)"
                             % (arg.name, arg.dtype, arg.name))
 
-                if kernel_arg.shape is not None:
+                # {{{ generate shape checking code
+
+                def strify_allowing_none(shape_axis):
+                    if shape_axis is None:
+                        return "None"
+                    else:
+                        return strify(shape_axis)
+
+                shape_mismatch_msg = (
+                        "raise TypeError(\"shape mismatch on argument '%s' "
+                        "(got: %%s, expected: %%s)\" "
+                        "%% (%s.shape, (%s,)))"
+                        % (arg.name, arg.name,
+                            ", ".join(strify_allowing_none(sa)
+                                for sa in arg.unvec_shape)))
+
+                if any(shape_axis is None for shape_axis in kernel_arg.shape):
+                    gen("if len(%s.shape) != %s:"
+                            % (arg.name, len(arg.unvec_shape)))
+                    with Indentation(gen):
+                        gen(shape_mismatch_msg)
+
+                    for i, shape_axis in enumerate(arg.unvec_shape):
+                        if shape_axis is None:
+                            continue
+
+                        gen("if %s.shape[%d] != %s:"
+                                % (arg.name, i, strify(shape_axis)))
+                        with Indentation(gen):
+                            gen(shape_mismatch_msg)
+
+                elif kernel_arg.shape is not None:
                     gen("if %s.shape != %s:"
                             % (arg.name, strify(arg.unvec_shape)))
                     with Indentation(gen):
-                        gen("raise TypeError(\"shape mismatch on argument '%s' "
-                                "(got: %%s, expected: %%s)\" "
-                                "%% (%s.shape, %s))"
-                                % (arg.name, arg.name, strify(arg.unvec_shape)))
+                        gen(shape_mismatch_msg)
+
+                # }}}
 
                 if arg.unvec_strides and kernel_arg.dim_tags:
                     itemsize = kernel_arg.dtype.itemsize
diff --git a/loopy/frontend/fortran/translator.py b/loopy/frontend/fortran/translator.py
index e4a24ced7..b7cdeb195 100644
--- a/loopy/frontend/fortran/translator.py
+++ b/loopy/frontend/fortran/translator.py
@@ -35,6 +35,7 @@ from loopy.frontend.fortran.diagnostic import (
 import islpy as isl
 from islpy import dim_type
 from loopy.symbolic import IdentityMapper
+from pymbolic.primitives import Wildcard
 
 
 # {{{ subscript base shifter
@@ -158,9 +159,16 @@ class Scope(object):
         shape = []
         for i, dim in enumerate(dims):
             if len(dim) == 1:
-                shape.append(dim[0])
+                if isinstance(dim[0], Wildcard):
+                    shape.append(None)
+                else:
+                    shape.append(dim[0])
+
             elif len(dim) == 2:
-                shape.append(dim[1]-dim[0]+1)
+                if isinstance(dim[0], Wildcard):
+                    shape.append(None)
+                else:
+                    shape.append(dim[1]-dim[0]+1)
             else:
                 raise TranslationError("dimension axis %d "
                         "of '%s' not understood: %s"
@@ -418,7 +426,7 @@ class F2LoopyTranslator(FTreeWalkerBase):
         raise NotImplementedError
 
     def map_Entry(self, node):
-        raise NotImplementedError
+        raise NotImplementedError("entry")
 
     # {{{ control flow
 
diff --git a/loopy/kernel/array.py b/loopy/kernel/array.py
index 550343755..ee14042d1 100644
--- a/loopy/kernel/array.py
+++ b/loopy/kernel/array.py
@@ -438,7 +438,11 @@ def convert_computed_to_fixed_dim_tags(name, num_user_axes, num_target_axes,
                     # unable to normalize without known shape
                     return None
 
-                stride_so_far *= shape[dim_tag_index]
+                shape_axis = shape[dim_tag_index]
+                if shape_axis is None:
+                    stride_so_far = None
+                else:
+                    stride_so_far *= shape_axis
 
                 if dim_tag.pad_to is not None:
                     from pytools import div_ceil
@@ -548,6 +552,9 @@ class ArrayBase(Record):
               expression involving kernel parameters, or a (potentially-comma
               separated) or a string that can be parsed to such an expression.
 
+              Any element of the shape tuple not used to compute strides
+              may be *None*.
+
             * A string which can be parsed into the previous form.
 
         :arg dim_tags: A comma-separated list of tags as understood by
@@ -818,7 +825,13 @@ class ArrayBase(Record):
         import loopy as lp
 
         if self.shape is not None and self.shape is not lp.auto:
-            kwargs["shape"] = tuple(mapper(s) for s in self.shape)
+            def none_pass_mapper(s):
+                if s is None:
+                    return s
+                else:
+                    return mapper(s)
+
+            kwargs["shape"] = tuple(none_pass_mapper(s) for s in self.shape)
 
         if self.dim_tags is not None:
             kwargs["dim_tags"] = [dim_tag.map_expr(mapper)
diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py
index c96ad5c51..a21d983d7 100644
--- a/loopy/kernel/creation.py
+++ b/loopy/kernel/creation.py
@@ -1095,7 +1095,13 @@ def make_kernel(domains, instructions, kernel_data=["..."], **kwargs):
             continue
 
         if isinstance(dat, ArrayBase) and isinstance(dat.shape, tuple):
-            dat = dat.copy(shape=expand_defines_in_expr(dat.shape, defines))
+            new_shape = []
+            for shape_axis in dat.shape:
+                if shape_axis is not None:
+                    new_shape.append(expand_defines_in_expr(shape_axis, defines))
+                else:
+                    new_shape.append(shape_axis)
+            dat = dat.copy(shape=tuple(new_shape))
 
         for arg_name in dat.name.split(","):
             arg_name = arg_name.strip()
diff --git a/loopy/symbolic.py b/loopy/symbolic.py
index 10acada14..e644ce889 100644
--- a/loopy/symbolic.py
+++ b/loopy/symbolic.py
@@ -362,7 +362,7 @@ def rename_subst_rules_in_instructions(insns, renames):
 
 class ExpandingIdentityMapper(IdentityMapper):
     """Note: the third argument dragged around by this mapper is the
-    current expansion expansion state.
+    current :class:`ExpansionState`.
     """
 
     def __init__(self, old_subst_rules, make_unique_var_name):
diff --git a/test/test_fortran.py b/test/test_fortran.py
index be6658131..d887cd401 100644
--- a/test/test_fortran.py
+++ b/test/test_fortran.py
@@ -96,6 +96,33 @@ def test_fill_const(ctx_factory):
     lp.auto_test_vs_ref(knl, ctx, knl, parameters=dict(n=5, a=5))
 
 
+def test_asterisk_in_shape(ctx_factory):
+    fortran_src = """
+        subroutine fill(out, out2, inp, n)
+          implicit none
+
+          real*8 a, out(n), out2(n), inp(*)
+          integer n
+
+          do i = 1, n
+            a = inp(n)
+            out(i) = 5*a
+            out2(i) = 6*a
+          end do
+        end
+        """
+
+    from loopy.frontend.fortran import f2loopy
+    knl, = f2loopy(fortran_src)
+
+    ctx = ctx_factory()
+    queue = cl.CommandQueue(ctx)
+
+    knl(queue, inp=np.array([1, 2, 3.]), n=3)
+
+    #lp.auto_test_vs_ref(knl, ctx, knl, parameters=dict(n=5))
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])
-- 
GitLab