diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py index f7667ca639e649a8f25b6e5d8975710742aef9a6..7acb53f8e1e71b8999e9e9123d1cc0a0bee91f02 100644 --- a/loopy/kernel/creation.py +++ b/loopy/kernel/creation.py @@ -1106,6 +1106,9 @@ class ArgumentGuesser: self.all_written_names = set() from loopy.symbolic import get_dependencies for insn in instructions: + for pred in insn.predicates: + self.all_names.update(get_dependencies(self.submap(pred))) + if isinstance(insn, MultiAssignmentBase): for assignee_var_name in insn.assignee_var_names(): self.all_written_names.add(assignee_var_name) diff --git a/test/test_loopy.py b/test/test_loopy.py index e33412001573a56ecf008b26e5849ba6c457dbeb..b78c754bdfe04b6bc18b3059b490e41ab5df1825 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -2741,6 +2741,18 @@ def test_preamble_with_separate_temporaries(ctx_factory): queue, data=data.flatten('C'))[1][0], data[offsets[:-1] + 1]) +def test_arg_inference_for_predicates(): + knl = lp.make_kernel("{[i]: 0 <= i < 10}", + """ + if incr[i] + a = a + 1 + end + """) + + assert "incr" in knl.arg_dict + assert knl.arg_dict["incr"].shape == (10,) + + def test_add_prefetch_works_in_lhs_index(): knl = lp.make_kernel( "{ [n,k,l,k1,l1,k2,l2]: "