From c88e1685257b7471f0667cfef7d1706cd84d3228 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Tue, 15 Nov 2016 18:41:44 -0600 Subject: [PATCH] Add elif/else --- loopy/kernel/creation.py | 82 ++++++++++++++++++++++++++++++++++++++-- test/test_loopy.py | 32 +++++++++++++++- 2 files changed, 110 insertions(+), 4 deletions(-) diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py index ab3035be0..744700502 100644 --- a/loopy/kernel/creation.py +++ b/loopy/kernel/creation.py @@ -270,7 +270,26 @@ def parse_insn_options(opt_dict, options_str, assignee_names=None): opt_value.split(":")) elif opt_key == "if" and opt_value is not None: - result["predicates"] = intern_frozenset_of_ids(opt_value.split(":")) + predicates = opt_value.split(":") + new_predicates = set() + + for pred in predicates: + from pymbolic.primitives import LogicalNot + from loopy.symbolic import parse + if pred.startswith("!"): + from warnings import warn + warn("predicates starting with '!' are deprecated. " + "Simply use 'not' instead") + pred = LogicalNot(parse(pred[1:])) + else: + pred = parse(pred) + + new_predicates.add(pred) + + result["predicates"] = frozenset(new_predicates) + + del predicates + del new_predicates elif opt_key == "tags" and opt_value is not None: result["tags"] = frozenset( @@ -330,9 +349,17 @@ FOR_RE = re.compile( IF_RE = re.compile( "^" "\s*if\s+" - "(?P(?:not\s+)?\w+(?:\[[ ,\w\d]+\])?)" + "(?P.+)" + "\s*$") + +ELIF_RE = re.compile( + "^" + "\s*elif\s+" + "(?P.+)" "\s*$") +ELSE_RE = re.compile("^\s*else\s*$") + INSN_RE = re.compile( "^" "\s*" @@ -579,7 +606,8 @@ def parse_instructions(instructions, defines): new_instructions.append( insn.copy( id=intern(insn.id) if isinstance(insn.id, str) else insn.id, - depends_on=frozenset(intern_if_str(dep) for dep in insn.depends_on), + depends_on=frozenset(intern_if_str(dep) + for dep in insn.depends_on), groups=frozenset(intern(grp) for grp in insn.groups), conflicts_with_groups=frozenset( intern(grp) for grp in insn.conflicts_with_groups), @@ -751,12 +779,60 @@ def parse_instructions(instructions, defines): if not predicate: raise LoopyError("'if' without predicate encountered") + from loopy.symbolic import parse + predicate = parse(predicate) + options["predicates"] = ( options.get("predicates", frozenset()) | frozenset([predicate])) insn_options_stack.append(options) del options + del predicate + continue + + elif_match = ELIF_RE.match(insn) + else_match = ELSE_RE.match(insn) + if elif_match is not None or else_match is not None: + prev_predicates = insn_options_stack[-1].get( + "predicates", frozenset()) + insn_options_stack.pop() + + outer_predicates = insn_options_stack[-1].get( + "predicates", frozenset()) + last_if_predicates = prev_predicates - outer_predicates + + if elif_match is not None: + predicate = elif_match.group("predicate") + if not predicate: + raise LoopyError("'elif' without predicate encountered") + from loopy.symbolic import parse + predicate = parse(predicate) + + additional_preds = frozenset([predicate]) + del predicate + + else: + assert else_match is not None + additional_preds = frozenset() + + options = insn_options_stack[-1].copy() + + from pymbolic.primitives import LogicalNot + options["predicates"] = ( + options.get("predicates", frozenset()) + | outer_predicates + | frozenset( + LogicalNot(pred) for pred in last_if_predicates) + | additional_preds + ) + + insn_options_stack.append(options) + + del options + del additional_preds + del last_if_predicates + continue if insn == "end": diff --git a/test/test_loopy.py b/test/test_loopy.py index 347c08d0d..0eba70940 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -1601,6 +1601,36 @@ def test_scalars_with_base_storage(ctx_factory): knl(queue, out_host=True) +def test_if_else(ctx_factory): + ctx = ctx_factory() + queue = cl.CommandQueue(ctx) + + knl = lp.make_kernel( + "{ [i]: 0<=i<50}", + """ + if i % 3 == 0 + a[i] = 15 + elif i % 3 == 1 + a[i] = 11 + else + a[i] = 3 + end + """ + ) + print(knl) + + knl = lp.set_options(knl, write_cl=True) + + evt, (out,) = knl(queue, out_host=True) + + out_ref = np.empty(50) + out_ref[::3] = 15 + out_ref[1::3] = 11 + out_ref[2::3] = 3 + + assert np.array_equal(out_ref, out) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) @@ -1608,4 +1638,4 @@ if __name__ == "__main__": from py.test.cmdline import main main([__file__]) -# vim: foldmethod=marker \ No newline at end of file +# vim: foldmethod=marker -- GitLab