From 0f26b19d1a7da69c1303e76a8082df91c7180446 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Mon, 27 Aug 2012 13:46:03 -0400
Subject: [PATCH] Make defines usable in argument shapes and domains.

---
 loopy/kernel.py    | 123 ++++++++++++++++++++++++++++++---------------
 test/test_loopy.py |  41 +++++++++++++++
 2 files changed, 124 insertions(+), 40 deletions(-)

diff --git a/loopy/kernel.py b/loopy/kernel.py
index 439b6f19d..3067041d9 100644
--- a/loopy/kernel.py
+++ b/loopy/kernel.py
@@ -114,7 +114,7 @@ def parse_tag(tag):
 
 # {{{ arguments
 
-class _ShapedArg(object):
+class _ShapedArg(Record):
     def __init__(self, name, dtype, shape=None, strides=None, order="C",
             offset=0):
         """
@@ -127,34 +127,29 @@ class _ShapedArg(object):
         :arg offset: Offset from the beginning of the vector from which
             the strides are counted.
         """
-        self.name = name
-        self.dtype = np.dtype(dtype)
+        dtype = np.dtype(dtype)
 
-        if strides is not None and shape is not None:
-            raise ValueError("can only specify one of shape and strides")
-
-        if strides is not None:
-            if isinstance(strides, str):
+        def parse_if_necessary(x):
+            if isinstance(x, str):
                 from pymbolic import parse
-                strides = parse(strides)
+                return parse(x)
+            else:
+                return x
 
-            if not isinstance(shape, tuple):
-                shape = (shape,)
+        def process_tuple(x):
+            x = parse_if_necessary(x)
+            if not isinstance(x, tuple):
+                x = (x,)
 
-        if shape is not None:
-            def parse_if_necessary(x):
-                if isinstance(x, str):
-                    from pymbolic import parse
-                    return parse(x)
-                else:
-                    return x
+            return tuple(parse_if_necessary(xi) for xi in x)
 
-            shape = parse_if_necessary(shape)
-            if not isinstance(shape, tuple):
-                shape = (shape,)
+        if strides is not None:
+            strides = process_tuple(strides)
 
-            shape = tuple(parse_if_necessary(si) for si in shape)
+        if shape is not None:
+            shape = process_tuple(shape)
 
+        if strides is None and shape is not None:
             from pyopencl.compyte.array import (
                     f_contiguous_strides,
                     c_contiguous_strides)
@@ -166,10 +161,13 @@ class _ShapedArg(object):
             else:
                 raise ValueError("invalid order: %s" % order)
 
-        self.strides = strides
-        self.offset = offset
-        self.shape = shape
-        self.order = order
+        Record.__init__(self,
+                name=name,
+                dtype=dtype,
+                strides=strides,
+                offset=offset,
+                shape=shape,
+                order=order)
 
     @property
     def dimensions(self):
@@ -409,32 +407,57 @@ class Instruction(Record):
 
 # {{{ expand defines
 
-MACRO_RE = re.compile(r"\{([a-zA-Z0-9_]+)\}")
+WORD_RE = re.compile(r"\b([a-zA-Z0-9_]+)\b")
 
-def expand_defines(insn, defines):
-    macros = set(match.group(1) for match in MACRO_RE.finditer(insn))
+def expand_defines(insn, defines, single_valued=True):
+    words = set(match.group(1) for match in WORD_RE.finditer(insn))
 
     replacements = [()]
-    for mac in macros:
-        value = defines[mac]
+    for word in words:
+        if word not in defines:
+            continue
+
+        value = defines[word]
         if isinstance(value, list):
+            if single_valued:
+                raise ValueError("multi-valued macro expansion not allowed "
+                        "in this context (when expanding '%s')" % word)
+
             replacements = [
-                    rep+(("{%s}" % mac, subval),)
+                    rep+((r"\b%s\b" % word, subval),)
                     for rep in replacements
                     for subval in value
                     ]
         else:
             replacements = [
-                    rep+(("{%s}" % mac, value),)
+                    rep+((r"\b%s\b" % word, value),)
                     for rep in replacements]
 
     for rep in replacements:
         rep_value = insn
-        for name, val in rep:
-            rep_value = rep_value.replace(name, str(val))
+        for pattern, val in rep:
+            rep_value = re.sub(pattern, str(val), rep_value)
 
         yield rep_value
 
+def expand_defines_in_expr(expr, defines):
+    from pymbolic.primitives import Variable
+    from loopy.symbolic import parse
+
+    def subst_func(var):
+        if isinstance(var, Variable):
+            try:
+                var_value = defines[var.name]
+            except KeyError:
+                return None
+            else:
+                return parse(str(var_value))
+        else:
+            return None
+
+    from loopy.symbolic import SubstitutionMapper
+    return SubstitutionMapper(subst_func)(expr)
+
 # }}}
 
 # {{{ function manglers / dtype getters
@@ -564,13 +587,15 @@ _IDENTIFIER_RE = re.compile(r"\b([a-zA-Z_][a-zA-Z0-9_]*)\b")
 def _gather_identifiers(s):
     return set(_IDENTIFIER_RE.findall(s))
 
-def _parse_domains(ctx, args_and_vars, domains):
+def _parse_domains(ctx, args_and_vars, domains, defines):
     result = []
     available_parameters = args_and_vars.copy()
     used_inames = set()
 
     for dom in domains:
         if isinstance(dom, str):
+            dom, = expand_defines(dom, defines)
+
             if not dom.lstrip().startswith("["):
                 # i.e. if no parameters are already given
                 ids = _gather_identifiers(dom)
@@ -639,10 +664,13 @@ class LoopKernel(Record):
         evaluated.
     :ivar defines: a dictionary of replacements to be made in instructions given
         as strings before parsing. A macro instance intended to be replaced should
-        look like "{MACRO}" in the instruction code. The expansion given in this
+        look like "MACRO" in the instruction code. The expansion given in this
         parameter is allowed to be a list. In this case, instructions are generated
         for *each* combination of macro values.
 
+        These defines may also be used in the domain and in argument shapes and
+        strides. They are expanded only upon kernel creation.
+
     The following arguments are not user-facing:
 
     :ivar iname_slab_increments: a dictionary mapping inames to (lower_incr,
@@ -864,7 +892,7 @@ class LoopKernel(Record):
                         "instance or a parseable string. got '%s' instead."
                         % type(insn))
 
-            for insn in expand_defines(insn, defines):
+            for insn in expand_defines(insn, defines, single_valued=False):
                 parse_insn(insn)
 
         parsed_instructions = []
@@ -902,7 +930,8 @@ class LoopKernel(Record):
                 | set(insn.get_assignee_var_name()
                     for insn in parsed_instructions
                     if insn.temp_var_type is not None))
-        domains = _parse_domains(isl_context, scalar_arg_names | var_names, domains)
+        domains = _parse_domains(isl_context, scalar_arg_names | var_names, domains,
+                defines)
 
         # }}}
 
@@ -936,6 +965,20 @@ class LoopKernel(Record):
 
         # }}}
 
+        # {{{ expand macros in arg shapes
+
+        processed_args = []
+        for arg in args:
+            if isinstance(arg, _ShapedArg):
+                if arg.shape is not None:
+                    arg = arg.copy(shape=expand_defines_in_expr(arg.shape, defines))
+                if arg.strides is not None:
+                    arg = arg.copy(strides=expand_defines_in_expr(arg.strides, defines))
+
+            processed_args.append(arg)
+
+        # }}}
+
         index_dtype = np.dtype(index_dtype)
         if index_dtype.kind != 'i':
             raise TypeError("index_dtype must be an integer")
@@ -945,7 +988,7 @@ class LoopKernel(Record):
         Record.__init__(self,
                 device=device, domains=domains,
                 instructions=parsed_instructions,
-                args=args,
+                args=processed_args,
                 schedule=schedule,
                 name=name,
                 preambles=preambles,
diff --git a/test/test_loopy.py b/test/test_loopy.py
index ffad62f6a..db2b5ca9d 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -520,6 +520,47 @@ def test_bare_data_dependency(ctx_factory):
 
 
 
+def test_split(ctx_factory):
+    dtype = np.float32
+    ctx = ctx_factory()
+
+    order = "C"
+
+    K = 10000
+    Np = 36
+    Nq = 50
+
+    knl = lp.make_kernel(ctx.devices[0],
+            "[K] -> {[i,j,k,ii,jj]: 0<=k<K and 0<= i,j,ii < Np and 0 <= jj < Nq}",
+            [
+                "<> temp[ii] = sum(jj, d[ii, jj]*f[k, jj])",
+                "result[k, i] = sum(j, d2[i, j]*temp[j])"
+                ],
+            [
+                lp.GlobalArg("d", dtype, shape="Np, Nq", order=order),
+                lp.GlobalArg("d2", dtype, shape="Np, Np", order=order),
+                lp.GlobalArg("f", dtype, shape="K, Nq", order=order),
+                lp.GlobalArg("result", dtype, shape="K, Np", order=order),
+                lp.ValueArg("K", np.int32, approximately=1000),
+                ],
+            name="batched_matvec", assumptions="K>=1",
+            defines=dict(Np=Np, Nq=Nq))
+
+    seq_knl = knl
+
+    knl = lp.add_prefetch(knl, 'd[:,:]')
+    knl = lp.add_prefetch(knl, 'd2[:,:]')
+
+    kernel_gen = lp.generate_loop_schedules(knl)
+    kernel_gen = lp.check_kernels(kernel_gen, dict(K=K))
+
+    lp.auto_test_vs_ref(seq_knl, ctx, kernel_gen,
+            op_count=[K*2*(Np**2+Np*Nq)/1e9], op_label=["GFlops"],
+            parameters=dict(K=K), print_ref_code=True)
+
+
+
+
 if __name__ == "__main__":
     import sys
     if len(sys.argv) > 1:
-- 
GitLab