From 93645d5535eca5e2305402650952e46da3506723 Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Fri, 15 Dec 2017 13:36:26 -0600 Subject: [PATCH] Fix arg inference for predicates (closes #114). --- loopy/kernel/creation.py | 3 +++ test/test_loopy.py | 12 ++++++++++++ 2 files changed, 15 insertions(+) diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py index f7667ca63..7acb53f8e 100644 --- a/loopy/kernel/creation.py +++ b/loopy/kernel/creation.py @@ -1106,6 +1106,9 @@ class ArgumentGuesser: self.all_written_names = set() from loopy.symbolic import get_dependencies for insn in instructions: + for pred in insn.predicates: + self.all_names.update(get_dependencies(self.submap(pred))) + if isinstance(insn, MultiAssignmentBase): for assignee_var_name in insn.assignee_var_names(): self.all_written_names.add(assignee_var_name) diff --git a/test/test_loopy.py b/test/test_loopy.py index e33412001..b78c754bd 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -2741,6 +2741,18 @@ def test_preamble_with_separate_temporaries(ctx_factory): queue, data=data.flatten('C'))[1][0], data[offsets[:-1] + 1]) +def test_arg_inference_for_predicates(): + knl = lp.make_kernel("{[i]: 0 <= i < 10}", + """ + if incr[i] + a = a + 1 + end + """) + + assert "incr" in knl.arg_dict + assert knl.arg_dict["incr"].shape == (10,) + + def test_add_prefetch_works_in_lhs_index(): knl = lp.make_kernel( "{ [n,k,l,k1,l1,k2,l2]: " -- GitLab