diff --git a/doc/ref_kernel.rst b/doc/ref_kernel.rst index 97d71f3e04051d45a2f911eb0f7b2eca7147b96b..2d754dec23762b289d3bf30ed6a7740326b11817 100644 --- a/doc/ref_kernel.rst +++ b/doc/ref_kernel.rst @@ -292,6 +292,8 @@ Loopy's expressions are a slight superset of the expressions supported by :mod:`pymbolic`. * ``if`` +* ``elif`` (following an ``if``) +* ``else`` (following an ``if`` / ``elif``) * ``reductions`` * duplication of reduction inames * ``reduce`` vs ``simul_reduce`` diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py index 6c5491384d4fc37dc48604aa52753d11ac10fc55..024d97c3fed14e9917f9c21be0f17f555947f600 100644 --- a/loopy/kernel/creation.py +++ b/loopy/kernel/creation.py @@ -274,7 +274,26 @@ def parse_insn_options(opt_dict, options_str, assignee_names=None): opt_value.split(":")) elif opt_key == "if" and opt_value is not None: - result["predicates"] = intern_frozenset_of_ids(opt_value.split(":")) + predicates = opt_value.split(":") + new_predicates = set() + + for pred in predicates: + from pymbolic.primitives import LogicalNot + from loopy.symbolic import parse + if pred.startswith("!"): + from warnings import warn + warn("predicates starting with '!' are deprecated. " + "Simply use 'not' instead") + pred = LogicalNot(parse(pred[1:])) + else: + pred = parse(pred) + + new_predicates.add(pred) + + result["predicates"] = frozenset(new_predicates) + + del predicates + del new_predicates elif opt_key == "tags" and opt_value is not None: result["tags"] = frozenset( @@ -334,9 +353,17 @@ FOR_RE = re.compile( IF_RE = re.compile( "^" "\s*if\s+" - "(?P(?:not\s+)?\w+(?:\[[ ,\w\d]+\])?)" + "(?P.+)" "\s*$") +ELIF_RE = re.compile( + "^" + "\s*elif\s+" + "(?P.+)" + "\s*$") + +ELSE_RE = re.compile("^\s*else\s*$") + INSN_RE = re.compile( "^" "\s*" @@ -583,7 +610,8 @@ def parse_instructions(instructions, defines): new_instructions.append( insn.copy( id=intern(insn.id) if isinstance(insn.id, str) else insn.id, - depends_on=frozenset(intern_if_str(dep) for dep in insn.depends_on), + depends_on=frozenset(intern_if_str(dep) + for dep in insn.depends_on), groups=frozenset(intern(grp) for grp in insn.groups), conflicts_with_groups=frozenset( intern(grp) for grp in insn.conflicts_with_groups), @@ -667,6 +695,7 @@ def parse_instructions(instructions, defines): # {{{ pass 4: parsing insn_options_stack = [get_default_insn_options_dict()] + if_predicates_stack = [{'predicates' : frozenset(), 'insn_predicates' : frozenset()}] for insn in instructions: if isinstance(insn, InstructionBase): @@ -755,16 +784,85 @@ def parse_instructions(instructions, defines): if not predicate: raise LoopyError("'if' without predicate encountered") + from loopy.symbolic import parse + predicate = parse(predicate) + options["predicates"] = ( options.get("predicates", frozenset()) | frozenset([predicate])) insn_options_stack.append(options) + + #add to the if_stack + if_options = options.copy() + if_options['insn_predicates'] = options["predicates"] + if_predicates_stack.append(if_options) del options + del predicate continue - if insn == "end": + elif_match = ELIF_RE.match(insn) + else_match = ELSE_RE.match(insn) + if elif_match is not None or else_match is not None: + prev_predicates = insn_options_stack[-1].get( + "predicates", frozenset()) + last_if_predicates = if_predicates_stack[-1].get( + "predicates", frozenset()) insn_options_stack.pop() + if_predicates_stack.pop() + + outer_predicates = insn_options_stack[-1].get( + "predicates", frozenset()) + last_if_predicates = last_if_predicates - outer_predicates + + + if elif_match is not None: + predicate = elif_match.group("predicate") + if not predicate: + raise LoopyError("'elif' without predicate encountered") + from loopy.symbolic import parse + predicate = parse(predicate) + + additional_preds = frozenset([predicate]) + del predicate + + else: + assert else_match is not None + if not last_if_predicates: + raise LoopyError("'else' without 'if'/'elif' encountered") + additional_preds = frozenset() + + options = insn_options_stack[-1].copy() + if_options = insn_options_stack[-1].copy() + + from pymbolic.primitives import LogicalNot + options["predicates"] = ( + options.get("predicates", frozenset()) + | outer_predicates + | prev_predicates - last_if_predicates + | frozenset( + LogicalNot(pred) for pred in last_if_predicates) + | additional_preds + ) + if_options["predicates"] = additional_preds + #hold on to this for comparison / stack popping later + if_options["insn_predicates"] = options["predicates"] + + insn_options_stack.append(options) + if_predicates_stack.append(if_options) + + del options + del additional_preds + del last_if_predicates + + continue + + if insn == "end": + obj = insn_options_stack.pop() + #if this object is the end of an if statement + if obj['predicates'] == if_predicates_stack[-1]["insn_predicates"] and\ + if_predicates_stack[-1]["insn_predicates"]: + if_predicates_stack.pop() continue insn_match = SPECIAL_INSN_RE.match(insn) diff --git a/loopy/transform/ilp.py b/loopy/transform/ilp.py index f3b7d3f0e8f9fae1847ad4eea42175cddadfe5d9..77840753258fa545aa01ef3e8c58cbc36e66ed72 100644 --- a/loopy/transform/ilp.py +++ b/loopy/transform/ilp.py @@ -193,7 +193,7 @@ def realize_ilp(kernel, iname): on a per-iname basis (so that, for instance, data layout of the duplicated storage can be controlled explicitly. """ - from loopy.ilp import add_axes_to_temporaries_for_ilp_and_vec + from loopy.transform.ilp import add_axes_to_temporaries_for_ilp_and_vec return add_axes_to_temporaries_for_ilp_and_vec(kernel, iname) # }}} diff --git a/test/test_loopy.py b/test/test_loopy.py index af4269047539b800a5fd389f9293f11551c9a291..4221feb792550a6c8b243f0f648dfff6d4654fac 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -1635,6 +1635,28 @@ def test_ilp_and_conditionals(ctx_factory): lp.auto_test_vs_ref(ref_knl, ctx, knl) +def test_unr_and_conditionals(ctx_factory): + ctx = ctx_factory() + + knl = lp.make_kernel('{[k]: 0<=k Tcond[k] = T[k] < 0.5 + if Tcond[k] + cp[k] = 2 * T[k] + Tcond[k] + end + end + """) + + knl = lp.fix_parameters(knl, n=200) + knl = lp.add_and_infer_dtypes(knl, {"T": np.float32}) + + ref_knl = knl + + knl = lp.split_iname(knl, 'k', 2, inner_tag='unr') + + lp.auto_test_vs_ref(ref_knl, ctx, knl) + def test_unr_and_conditionals(ctx_factory): ctx = ctx_factory() @@ -1751,6 +1773,62 @@ def test_scalars_with_base_storage(ctx_factory): knl(queue, out_host=True) +def test_if_else(ctx_factory): + ctx = ctx_factory() + queue = cl.CommandQueue(ctx) + + knl = lp.make_kernel( + "{ [i]: 0<=i<50}", + """ + if i % 3 == 0 + a[i] = 15 + elif i % 3 == 1 + a[i] = 11 + else + a[i] = 3 + end + """ + ) + + evt, (out,) = knl(queue, out_host=True) + + out_ref = np.empty(50) + out_ref[::3] = 15 + out_ref[1::3] = 11 + out_ref[2::3] = 3 + + assert np.array_equal(out_ref, out) + + knl = lp.make_kernel( + "{ [i]: 0<=i<50}", + """ + for i + if i % 2 == 0 + if i % 3 == 0 + a[i] = 15 + elif i % 3 == 1 + a[i] = 11 + else + a[i] = 3 + end + else + a[i] = 4 + end + end + """ + ) + + evt, (out,) = knl(queue, out_host=True) + + out_ref = np.zeros(50) + out_ref[1::2] = 4 + out_ref[0::6] = 15 + out_ref[4::6] = 11 + out_ref[2::6] = 3 + + assert np.array_equal(out_ref, out) + + def test_tight_loop_bounds(ctx_factory): ctx = ctx_factory() queue = cl.CommandQueue(ctx)