From e9852a20531ece2aa37e2bdbe5dc08219c8e975a Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sat, 28 Nov 2015 17:57:24 -0600
Subject: [PATCH] Document dim_tags syntax, introduce optional dim_tags

---
 doc/reference.rst     | 60 +++++++++++++++++++++++++++++++++++++++++++
 loopy/__init__.py     |  9 ++++---
 loopy/kernel/array.py | 53 +++++++++++++++++++++++++-------------
 3 files changed, 101 insertions(+), 21 deletions(-)

diff --git a/doc/reference.rst b/doc/reference.rst
index bf6f4fbcb..cbd96f67e 100644
--- a/doc/reference.rst
+++ b/doc/reference.rst
@@ -81,6 +81,66 @@ Tag                       Meaning
 * Causes a loop (unrolled or not) to be opened/generated for each
   involved instruction
 
+.. _data-dim-tags:
+
+Data Axis Tags
+--------------
+
+Data axis tags specify how a multi-dimensional array (which is loopy's
+main way of storing data) is represented in (linear, 1D) computer
+memory. This storage format is given as a number of "tags", as listed
+in the table below. Each axis of an array has a tag corresponding to it.
+In the user interface, array dim tags are specified as a tuple of these
+tags or a comma-separated string containing them, such as the following::
+
+    c,vec,sep,c
+
+The interpretation of these tags is order-dependent: they are read
+from left to right.
+
+===================== ====================================================
+Tag                   Meaning
+===================== ====================================================
+``c``                 Nest current axis around the ones that follow
+``f``                 Nest current axis inside the ones that follow
+``N0`` ... ``N9``     Specify an explicit nesting level for this axis
+``stride:EXPR``       A fixed stride
+``sep``               Implement this axis by mapping to separate arrays
+``vec``               Implement this axis as entries in a vector
+===================== ====================================================
+
+``sep`` and ``vec`` obviously require the number of entries
+in the array along their respective axis to be known at code
+generation time.
+
+When the above speaks about 'nesting levels', this means that axes
+"nested inside" others are "faster-moving" when viewed from linear
+memory.
+
+In addition, each tag may be followed by a question mark (``?``),
+which indicates that this axis should be omitted if more dimension
+tags are specified than there are array axes present. Axes
+with question marks are omitted in a left-first manner until the correct
+number of dimension tags is achieved.
+
+Some examples follow, all of which use a three-dimensional array of shape
+*(3, M, 4)*. For simplicity, we assume that array entries have size one.
+
+*   ``c,c,c``: The axes will have strides *(M*4, 4, 1)*,
+    leading to a C-like / row-major layout.
+
+*   ``f,f,f``: The axes will have strides *(1, 3, 3*M)*,
+    leading to a Fortran-like / column-major layout.
+
+*   ``sep,c,c``: The array will be mapped to three arrays of
+    shape *(M, 4)*, each with strides *(4, 1)*.
+
+*   ``c,c,vec``: The array will be mapped to an array of
+    ``float4`` vectors, with (``float4``-based) strides of
+    *(M, 1)*.
+
+*   ``N1,N0,N2``: The axes will have strides *(M, 1, 3*M)*.
+
 .. _creating-kernels:
 
 Creating Kernels
diff --git a/loopy/__init__.py b/loopy/__init__.py
index 1f745b93d..b805c5643 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -1297,6 +1297,7 @@ def tag_data_axes(knl, ary_names, dim_tags):
 
         from loopy.kernel.array import parse_array_dim_tags
         new_dim_tags = parse_array_dim_tags(dim_tags,
+                n_axes=ary.num_user_axes(),
                 use_increasing_target_axes=ary.max_target_axes > 1)
 
         ary = ary.copy(dim_tags=tuple(new_dim_tags))
@@ -1636,13 +1637,15 @@ def make_copy_kernel(new_dim_tags, old_dim_tags=None):
 
     from loopy.kernel.array import (parse_array_dim_tags,
             SeparateArrayArrayDimTag, VectorArrayDimTag)
-    new_dim_tags = parse_array_dim_tags(new_dim_tags)
+    new_dim_tags = parse_array_dim_tags(new_dim_tags, n_axes=None)
 
     rank = len(new_dim_tags)
     if old_dim_tags is None:
-        old_dim_tags = parse_array_dim_tags(",".join(rank * ["c"]))
+        old_dim_tags = parse_array_dim_tags(
+                ",".join(rank * ["c"]), n_axes=None)
     elif isinstance(old_dim_tags, str):
-        old_dim_tags = parse_array_dim_tags(old_dim_tags)
+        old_dim_tags = parse_array_dim_tags(
+                old_dim_tags, n_axes=None)
 
     indices = ["i%d" % i for i in range(rank)]
     shape = ["n%d" % i for i in range(rank)]
diff --git a/loopy/kernel/array.py b/loopy/kernel/array.py
index 34dc9e5c7..291559049 100644
--- a/loopy/kernel/array.py
+++ b/loopy/kernel/array.py
@@ -181,26 +181,31 @@ class VectorArrayDimTag(ArrayDimImplementationTag):
         return self
 
 
-NESTING_LEVEL_RE = re.compile(r"^N([0-9]+)(?::(.*)|)$")
+NESTING_LEVEL_RE = re.compile(r"^N([-0-9]+)(?::(.*)|)$")
 PADDED_STRIDE_TAG_RE = re.compile(r"^([a-zA-Z]*)\(pad=(.*)\)$")
 TARGET_AXIS_RE = re.compile(r"->([0-9])$")
 
 
 def _parse_array_dim_tag(tag, default_target_axis, nesting_levels):
     if isinstance(tag, ArrayDimImplementationTag):
-        return False, tag
+        return False, False, tag
 
     if not isinstance(tag, str):
         raise TypeError("arg dimension implementation tag must be "
                 "string or tag object")
 
     tag = tag.strip()
+    is_optional = False
+    if tag.endswith("?"):
+        tag = tag[:-1]
+        is_optional = True
+
     orig_tag = tag
 
     if tag == "sep":
-        return False, SeparateArrayArrayDimTag()
+        return False, is_optional, SeparateArrayArrayDimTag()
     elif tag == "vec":
-        return False, VectorArrayDimTag()
+        return False, is_optional, VectorArrayDimTag()
 
     nesting_level_match = NESTING_LEVEL_RE.match(tag)
 
@@ -228,14 +233,18 @@ def _parse_array_dim_tag(tag, default_target_axis, nesting_levels):
         fixed_stride_descr = tag[7:]
         if fixed_stride_descr.strip() == "auto":
             import loopy as lp
-            return has_explicit_nesting_level, FixedStrideArrayDimTag(
-                    lp.auto, target_axis,
-                    layout_nesting_level=nesting_level)
+            return (
+                    has_explicit_nesting_level, is_optional,
+                    FixedStrideArrayDimTag(
+                        lp.auto, target_axis,
+                        layout_nesting_level=nesting_level))
         else:
             from loopy.symbolic import parse
-            return has_explicit_nesting_level, FixedStrideArrayDimTag(
+            return (
+                has_explicit_nesting_level, is_optional,
+                FixedStrideArrayDimTag(
                     parse(fixed_stride_descr), target_axis,
-                    layout_nesting_level=nesting_level)
+                    layout_nesting_level=nesting_level))
 
     else:
         padded_stride_match = PADDED_STRIDE_TAG_RE.match(tag)
@@ -274,11 +283,13 @@ def _parse_array_dim_tag(tag, default_target_axis, nesting_levels):
         else:
             raise LoopyError("invalid dim tag: '%s'" % orig_tag)
 
-        return has_explicit_nesting_level, ComputedStrideArrayDimTag(
-                nesting_level, pad_to=pad_to, target_axis=target_axis)
+        return (
+                has_explicit_nesting_level, is_optional,
+                ComputedStrideArrayDimTag(
+                    nesting_level, pad_to=pad_to, target_axis=target_axis))
 
 
-def parse_array_dim_tags(dim_tags, use_increasing_target_axes=False):
+def parse_array_dim_tags(dim_tags, n_axes=None, use_increasing_target_axes=False):
     if isinstance(dim_tags, str):
         dim_tags = dim_tags.split(",")
 
@@ -291,9 +302,15 @@ def parse_array_dim_tags(dim_tags, use_increasing_target_axes=False):
 
     target_axis_to_has_explicit_nesting_level = {}
 
-    for dim_tag in dim_tags:
-        has_explicit_nesting_level, parsed_dim_tag = _parse_array_dim_tag(
-            dim_tag, default_target_axis, nesting_levels)
+    for iaxis, dim_tag in enumerate(dim_tags):
+        has_explicit_nesting_level, is_optional, parsed_dim_tag = (
+                _parse_array_dim_tag(
+                    dim_tag, default_target_axis, nesting_levels))
+
+        if (is_optional
+                and n_axes is not None
+                and len(result) + (len(dim_tags) - iaxis) > n_axes):
+            continue
 
         if isinstance(parsed_dim_tag, _StrideArrayDimTagBase):
             # {{{ check for C/F mixed with explicit layout nesting level specs
@@ -513,9 +530,7 @@ class ArrayBase(Record):
 
     .. attribute:: dim_tags
 
-        a list of :class:`ArrayDimImplementationTag` instances.
-        or a list of strings that :func:`parse_array_dim_tag` understands,
-        or a comma-separated string of such tags.
+        See :ref:`data-dim-tags`.
 
     .. attribute:: offset
 
@@ -644,6 +659,7 @@ class ArrayBase(Record):
 
         if dim_tags is not None:
             dim_tags = parse_array_dim_tags(dim_tags,
+                    n_axes=(len(shape) if shape_known else None),
                     use_increasing_target_axes=self.max_target_axes > 1)
 
         # {{{ determine number of user axes
@@ -675,6 +691,7 @@ class ArrayBase(Record):
 
         if dim_tags is None and num_user_axes is not None and order is not None:
             dim_tags = parse_array_dim_tags(num_user_axes*[order],
+                    n_axes=num_user_axes,
                     use_increasing_target_axes=self.max_target_axes > 1)
             order = None
 
-- 
GitLab