Skip to content
Snippets Groups Projects
Commit 69d7badf authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Support 'if/then/else' in Fortran

parent 9720f71c
No related branches found
No related tags found
No related merge requests found
......@@ -221,6 +221,7 @@ class F2LoopyTranslator(FTreeWalkerBase):
self.isl_context = isl.Context()
self.insn_id_counter = 0
self.condition_id_counter = 0
self.kernels = []
......@@ -229,9 +230,34 @@ class F2LoopyTranslator(FTreeWalkerBase):
self.in_transform_code = False
self.instruction_tags = []
self.conditions = []
self.transform_code_lines = []
def add_expression_instruction(self, lhs, rhs):
scope = self.scope_stack[-1]
new_id = "insn%d" % self.insn_id_counter
self.insn_id_counter += 1
if scope.previous_instruction_id:
insn_deps = frozenset([scope.previous_instruction_id])
else:
insn_deps = frozenset()
from loopy.kernel.data import ExpressionInstruction
insn = ExpressionInstruction(
lhs, rhs,
forced_iname_deps=frozenset(
scope.active_loopy_inames),
insn_deps=insn_deps,
id=new_id,
predicates=frozenset(self.conditions),
tags=tuple(self.instruction_tags))
scope.previous_instruction_id = new_id
scope.instructions.append(insn)
# {{{ map_XXX functions
def map_BeginSource(self, node):
......@@ -385,28 +411,9 @@ class F2LoopyTranslator(FTreeWalkerBase):
scope.use_name(lhs_name)
from loopy.kernel.data import ExpressionInstruction
rhs = scope.process_expression_for_loopy(self.parse_expr(node.expr))
new_id = "insn%d" % self.insn_id_counter
self.insn_id_counter += 1
if scope.previous_instruction_id:
insn_deps = frozenset([scope.previous_instruction_id])
else:
insn_deps = frozenset()
insn = ExpressionInstruction(
lhs, rhs,
forced_iname_deps=frozenset(
scope.active_loopy_inames),
insn_deps=insn_deps,
id=new_id,
tags=tuple(self.instruction_tags))
scope.previous_instruction_id = new_id
scope.instructions.append(insn)
self.add_expression_instruction(lhs, rhs)
def map_Allocate(self, node):
raise NotImplementedError("allocate")
......@@ -448,10 +455,31 @@ class F2LoopyTranslator(FTreeWalkerBase):
# node.content[0]
def map_IfThen(self, node):
raise NotImplementedError("if-then")
scope = self.scope_stack[-1]
cond_name = "loopy_cond%d" % self.condition_id_counter
self.condition_id_counter += 1
assert cond_name not in scope.type_map
scope.type_map[cond_name] = np.int32
from pymbolic import var
cond_var = var(cond_name)
self.add_expression_instruction(
cond_var, self.parse_expr(node.expr))
self.conditions.append(cond_name)
for c in node.content:
self.rec(c)
def map_Else(self, node):
cond_name = self.conditions.pop()
self.conditions.append("!" + cond_name)
def map_EndIfThen(self, node):
return []
self.conditions.pop()
def map_Do(self, node):
scope = self.scope_stack[-1]
......@@ -460,7 +488,8 @@ class F2LoopyTranslator(FTreeWalkerBase):
loop_var, loop_bounds = node.loopcontrol.split("=")
loop_var = loop_var.strip()
scope.use_name(loop_var)
loop_bounds = [self.parse_expr(s) for s in loop_bounds.split(",")]
loop_bounds = self.parse_expr(
loop_bounds, min_precedence=self.expr_parser._PREC_FUNC_ARGS)
if len(loop_bounds) == 2:
start, stop = loop_bounds
......@@ -560,7 +589,8 @@ class F2LoopyTranslator(FTreeWalkerBase):
begin_tag_match = self.begin_tag_re.match(stripped_comment_line)
end_tag_match = self.end_tag_re.match(stripped_comment_line)
faulty_loopy_pragma_match = self.faulty_loopy_pragma.match(stripped_comment_line)
faulty_loopy_pragma_match = self.faulty_loopy_pragma.match(
stripped_comment_line)
if stripped_comment_line == "$loopy begin transform":
if self.in_transform_code:
......
......@@ -81,8 +81,8 @@ class FTreeWalkerBase(object):
# {{{ expressions
def parse_expr(self, expr_str):
return self.expr_parser(expr_str)
def parse_expr(self, expr_str, **kwargs):
return self.expr_parser(expr_str, **kwargs)
# }}}
......
......@@ -207,6 +207,40 @@ def test_temporary_to_subst_indices(ctx_factory):
lp.auto_test_vs_ref(ref_knl, ctx, knl)
def test_if(ctx_factory):
fortran_src = """
subroutine fill(out, out2, inp, n)
implicit none
real*8 a, b, out(n), out2(n), inp(n)
integer n
do i = 1, n
a = inp(i)
if (a.ge.3) then
b = 2*a
do j = 1,3
b = 3 * b
end do
out(i) = 5*b
else
out(i) = 4*a
endif
end do
end
"""
from loopy.frontend.fortran import f2loopy
knl, = f2loopy(fortran_src)
ref_knl = knl
knl = lp.temporary_to_subst(knl, "a")
ctx = ctx_factory()
lp.auto_test_vs_ref(ref_knl, ctx, knl, parameters=dict(n=5))
if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment