From 6a9a78f4eaff80c0816171b2c505a2d21312453d Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Tue, 21 Mar 2017 14:45:29 -0500 Subject: [PATCH] Test, fix assignments to struct components (reported by Fred Burton) --- loopy/codegen/instruction.py | 7 ++++++- loopy/kernel/creation.py | 14 +++++++++----- loopy/kernel/instruction.py | 14 ++++++++++---- test/test_loopy.py | 35 +++++++++++++++++++++++++++++++++++ 4 files changed, 60 insertions(+), 10 deletions(-) diff --git a/loopy/codegen/instruction.py b/loopy/codegen/instruction.py index 6224d9709..3ef7c8f6a 100644 --- a/loopy/codegen/instruction.py +++ b/loopy/codegen/instruction.py @@ -126,10 +126,13 @@ def generate_assignment_instruction_code(codegen_state, insn): # }}} - from pymbolic.primitives import Variable, Subscript + from pymbolic.primitives import Variable, Subscript, Lookup from loopy.symbolic import LinearSubscript lhs = insn.assignee + if isinstance(lhs, Lookup): + lhs = lhs.aggregate + if isinstance(lhs, Variable): assignee_var_name = lhs.name assignee_indices = () @@ -145,6 +148,8 @@ def generate_assignment_instruction_code(codegen_state, insn): else: raise RuntimeError("invalid lvalue '%s'" % lhs) + del lhs + result = codegen_state.ast_builder.emit_assignment(codegen_state, insn) # {{{ tracing diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py index 6eedfcc20..d81b1f895 100644 --- a/loopy/kernel/creation.py +++ b/loopy/kernel/creation.py @@ -448,7 +448,7 @@ def parse_insn(groups, insn_options): "the following error occurred:" % groups["rhs"]) raise - from pymbolic.primitives import Variable, Subscript + from pymbolic.primitives import Variable, Subscript, Lookup from loopy.symbolic import TypeAnnotation if not isinstance(lhs, tuple): @@ -469,11 +469,15 @@ def parse_insn(groups, insn_options): else: temp_var_types.append(None) + inner_lhs_i = lhs_i + if isinstance(inner_lhs_i, Lookup): + inner_lhs_i = inner_lhs_i.aggregate + from loopy.symbolic import LinearSubscript - if isinstance(lhs_i, Variable): - assignee_names.append(lhs_i.name) - elif isinstance(lhs_i, (Subscript, LinearSubscript)): - assignee_names.append(lhs_i.aggregate.name) + if isinstance(inner_lhs_i, Variable): + assignee_names.append(inner_lhs_i.name) + elif isinstance(inner_lhs_i, (Subscript, LinearSubscript)): + assignee_names.append(inner_lhs_i.aggregate.name) else: raise LoopyError("left hand side of assignment '%s' must " "be variable or subscript" % (lhs_i,)) diff --git a/loopy/kernel/instruction.py b/loopy/kernel/instruction.py index fdd8f1d37..dfa1df18f 100644 --- a/loopy/kernel/instruction.py +++ b/loopy/kernel/instruction.py @@ -455,9 +455,12 @@ class InstructionBase(ImmutableRecord): def _get_assignee_var_name(expr): - from pymbolic.primitives import Variable, Subscript + from pymbolic.primitives import Variable, Subscript, Lookup from loopy.symbolic import LinearSubscript + if isinstance(expr, Lookup): + expr = expr.aggregate + if isinstance(expr, Variable): return expr.name @@ -477,9 +480,12 @@ def _get_assignee_var_name(expr): def _get_assignee_subscript_deps(expr): - from pymbolic.primitives import Variable, Subscript + from pymbolic.primitives import Variable, Subscript, Lookup from loopy.symbolic import LinearSubscript, get_dependencies + if isinstance(expr, Lookup): + expr = expr.aggregate + if isinstance(expr, Variable): return frozenset() elif isinstance(expr, Subscript): @@ -770,9 +776,9 @@ class Assignment(MultiAssignmentBase): if isinstance(expression, str): expression = parse(expression) - from pymbolic.primitives import Variable, Subscript + from pymbolic.primitives import Variable, Subscript, Lookup from loopy.symbolic import LinearSubscript - if not isinstance(assignee, (Variable, Subscript, LinearSubscript)): + if not isinstance(assignee, (Variable, Subscript, LinearSubscript, Lookup)): raise LoopyError("invalid lvalue '%s'" % assignee) self.assignee = assignee diff --git a/test/test_loopy.py b/test/test_loopy.py index 851a7f076..94cdb499c 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -2108,6 +2108,41 @@ def test_barrier_insertion_near_bottom_of_loop(): assert_barrier_between(knl, "ainit", "aupdate", ignore_barriers_in_levels=[1]) +def test_struct_assignment(ctx_factory): + ctx = ctx_factory() + queue = cl.CommandQueue(ctx) + + bbhit = np.dtype([ + ("tmin", np.float32), + ("tmax", np.float32), + ("bi", np.int32), + ("hit", np.int32)]) + + bbhit, bbhit_c_decl = cl.tools.match_dtype_to_c_struct( + ctx.devices[0], "bbhit", bbhit) + bbhit = cl.tools.get_or_register_dtype('bbhit', bbhit) + + preamble = bbhit_c_decl + + knl = lp.make_kernel( + "{ [i]: 0<=i 1: exec(sys.argv[1]) -- GitLab