From 6a9a78f4eaff80c0816171b2c505a2d21312453d Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Tue, 21 Mar 2017 14:45:29 -0500
Subject: [PATCH] Test, fix assignments to struct components (reported by Fred
 Burton)

---
 loopy/codegen/instruction.py |  7 ++++++-
 loopy/kernel/creation.py     | 14 +++++++++-----
 loopy/kernel/instruction.py  | 14 ++++++++++----
 test/test_loopy.py           | 35 +++++++++++++++++++++++++++++++++++
 4 files changed, 60 insertions(+), 10 deletions(-)

diff --git a/loopy/codegen/instruction.py b/loopy/codegen/instruction.py
index 6224d9709..3ef7c8f6a 100644
--- a/loopy/codegen/instruction.py
+++ b/loopy/codegen/instruction.py
@@ -126,10 +126,13 @@ def generate_assignment_instruction_code(codegen_state, insn):
 
     # }}}
 
-    from pymbolic.primitives import Variable, Subscript
+    from pymbolic.primitives import Variable, Subscript, Lookup
     from loopy.symbolic import LinearSubscript
 
     lhs = insn.assignee
+    if isinstance(lhs, Lookup):
+        lhs = lhs.aggregate
+
     if isinstance(lhs, Variable):
         assignee_var_name = lhs.name
         assignee_indices = ()
@@ -145,6 +148,8 @@ def generate_assignment_instruction_code(codegen_state, insn):
     else:
         raise RuntimeError("invalid lvalue '%s'" % lhs)
 
+    del lhs
+
     result = codegen_state.ast_builder.emit_assignment(codegen_state, insn)
 
     # {{{ tracing
diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py
index 6eedfcc20..d81b1f895 100644
--- a/loopy/kernel/creation.py
+++ b/loopy/kernel/creation.py
@@ -448,7 +448,7 @@ def parse_insn(groups, insn_options):
                 "the following error occurred:" % groups["rhs"])
         raise
 
-    from pymbolic.primitives import Variable, Subscript
+    from pymbolic.primitives import Variable, Subscript, Lookup
     from loopy.symbolic import TypeAnnotation
 
     if not isinstance(lhs, tuple):
@@ -469,11 +469,15 @@ def parse_insn(groups, insn_options):
         else:
             temp_var_types.append(None)
 
+        inner_lhs_i = lhs_i
+        if isinstance(inner_lhs_i, Lookup):
+            inner_lhs_i = inner_lhs_i.aggregate
+
         from loopy.symbolic import LinearSubscript
-        if isinstance(lhs_i, Variable):
-            assignee_names.append(lhs_i.name)
-        elif isinstance(lhs_i, (Subscript, LinearSubscript)):
-            assignee_names.append(lhs_i.aggregate.name)
+        if isinstance(inner_lhs_i, Variable):
+            assignee_names.append(inner_lhs_i.name)
+        elif isinstance(inner_lhs_i, (Subscript, LinearSubscript)):
+            assignee_names.append(inner_lhs_i.aggregate.name)
         else:
             raise LoopyError("left hand side of assignment '%s' must "
                     "be variable or subscript" % (lhs_i,))
diff --git a/loopy/kernel/instruction.py b/loopy/kernel/instruction.py
index fdd8f1d37..dfa1df18f 100644
--- a/loopy/kernel/instruction.py
+++ b/loopy/kernel/instruction.py
@@ -455,9 +455,12 @@ class InstructionBase(ImmutableRecord):
 
 
 def _get_assignee_var_name(expr):
-    from pymbolic.primitives import Variable, Subscript
+    from pymbolic.primitives import Variable, Subscript, Lookup
     from loopy.symbolic import LinearSubscript
 
+    if isinstance(expr, Lookup):
+        expr = expr.aggregate
+
     if isinstance(expr, Variable):
         return expr.name
 
@@ -477,9 +480,12 @@ def _get_assignee_var_name(expr):
 
 
 def _get_assignee_subscript_deps(expr):
-    from pymbolic.primitives import Variable, Subscript
+    from pymbolic.primitives import Variable, Subscript, Lookup
     from loopy.symbolic import LinearSubscript, get_dependencies
 
+    if isinstance(expr, Lookup):
+        expr = expr.aggregate
+
     if isinstance(expr, Variable):
         return frozenset()
     elif isinstance(expr, Subscript):
@@ -770,9 +776,9 @@ class Assignment(MultiAssignmentBase):
         if isinstance(expression, str):
             expression = parse(expression)
 
-        from pymbolic.primitives import Variable, Subscript
+        from pymbolic.primitives import Variable, Subscript, Lookup
         from loopy.symbolic import LinearSubscript
-        if not isinstance(assignee, (Variable, Subscript, LinearSubscript)):
+        if not isinstance(assignee, (Variable, Subscript, LinearSubscript, Lookup)):
             raise LoopyError("invalid lvalue '%s'" % assignee)
 
         self.assignee = assignee
diff --git a/test/test_loopy.py b/test/test_loopy.py
index 851a7f076..94cdb499c 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -2108,6 +2108,41 @@ def test_barrier_insertion_near_bottom_of_loop():
     assert_barrier_between(knl, "ainit", "aupdate", ignore_barriers_in_levels=[1])
 
 
+def test_struct_assignment(ctx_factory):
+    ctx = ctx_factory()
+    queue = cl.CommandQueue(ctx)
+
+    bbhit = np.dtype([
+        ("tmin", np.float32),
+        ("tmax", np.float32),
+        ("bi", np.int32),
+        ("hit", np.int32)])
+
+    bbhit, bbhit_c_decl = cl.tools.match_dtype_to_c_struct(
+            ctx.devices[0], "bbhit", bbhit)
+    bbhit = cl.tools.get_or_register_dtype('bbhit', bbhit)
+
+    preamble = bbhit_c_decl
+
+    knl = lp.make_kernel(
+        "{ [i]: 0<=i<N }",
+        """
+        for i
+            result[i].hit = i % 2
+            result[i].tmin = i
+            result[i].tmax = i+10
+            result[i].bi = i
+        end
+        """,
+        [
+            lp.GlobalArg("result", shape=("N",), dtype=bbhit),
+            "..."],
+        preambles=[("000", preamble)])
+
+    knl = lp.set_options(knl, write_cl=True)
+    knl(queue, N=200)
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])
-- 
GitLab