diff --git a/loopy/frontend/fortran/translator.py b/loopy/frontend/fortran/translator.py index 783dc32313883f092f15baacccf3e68d75af3b93..04c8941455bfb5c5413def494bdf8ec33a29c65f 100644 --- a/loopy/frontend/fortran/translator.py +++ b/loopy/frontend/fortran/translator.py @@ -221,6 +221,7 @@ class F2LoopyTranslator(FTreeWalkerBase): self.isl_context = isl.Context() self.insn_id_counter = 0 + self.condition_id_counter = 0 self.kernels = [] @@ -229,9 +230,34 @@ class F2LoopyTranslator(FTreeWalkerBase): self.in_transform_code = False self.instruction_tags = [] + self.conditions = [] self.transform_code_lines = [] + def add_expression_instruction(self, lhs, rhs): + scope = self.scope_stack[-1] + + new_id = "insn%d" % self.insn_id_counter + self.insn_id_counter += 1 + + if scope.previous_instruction_id: + insn_deps = frozenset([scope.previous_instruction_id]) + else: + insn_deps = frozenset() + + from loopy.kernel.data import ExpressionInstruction + insn = ExpressionInstruction( + lhs, rhs, + forced_iname_deps=frozenset( + scope.active_loopy_inames), + insn_deps=insn_deps, + id=new_id, + predicates=frozenset(self.conditions), + tags=tuple(self.instruction_tags)) + + scope.previous_instruction_id = new_id + scope.instructions.append(insn) + # {{{ map_XXX functions def map_BeginSource(self, node): @@ -385,28 +411,9 @@ class F2LoopyTranslator(FTreeWalkerBase): scope.use_name(lhs_name) - from loopy.kernel.data import ExpressionInstruction - rhs = scope.process_expression_for_loopy(self.parse_expr(node.expr)) - new_id = "insn%d" % self.insn_id_counter - self.insn_id_counter += 1 - - if scope.previous_instruction_id: - insn_deps = frozenset([scope.previous_instruction_id]) - else: - insn_deps = frozenset() - - insn = ExpressionInstruction( - lhs, rhs, - forced_iname_deps=frozenset( - scope.active_loopy_inames), - insn_deps=insn_deps, - id=new_id, - tags=tuple(self.instruction_tags)) - - scope.previous_instruction_id = new_id - scope.instructions.append(insn) + self.add_expression_instruction(lhs, rhs) def map_Allocate(self, node): raise NotImplementedError("allocate") @@ -448,10 +455,31 @@ class F2LoopyTranslator(FTreeWalkerBase): # node.content[0] def map_IfThen(self, node): - raise NotImplementedError("if-then") + scope = self.scope_stack[-1] + + cond_name = "loopy_cond%d" % self.condition_id_counter + self.condition_id_counter += 1 + assert cond_name not in scope.type_map + + scope.type_map[cond_name] = np.int32 + + from pymbolic import var + cond_var = var(cond_name) + + self.add_expression_instruction( + cond_var, self.parse_expr(node.expr)) + + self.conditions.append(cond_name) + + for c in node.content: + self.rec(c) + + def map_Else(self, node): + cond_name = self.conditions.pop() + self.conditions.append("!" + cond_name) def map_EndIfThen(self, node): - return [] + self.conditions.pop() def map_Do(self, node): scope = self.scope_stack[-1] @@ -460,7 +488,8 @@ class F2LoopyTranslator(FTreeWalkerBase): loop_var, loop_bounds = node.loopcontrol.split("=") loop_var = loop_var.strip() scope.use_name(loop_var) - loop_bounds = [self.parse_expr(s) for s in loop_bounds.split(",")] + loop_bounds = self.parse_expr( + loop_bounds, min_precedence=self.expr_parser._PREC_FUNC_ARGS) if len(loop_bounds) == 2: start, stop = loop_bounds @@ -560,7 +589,8 @@ class F2LoopyTranslator(FTreeWalkerBase): begin_tag_match = self.begin_tag_re.match(stripped_comment_line) end_tag_match = self.end_tag_re.match(stripped_comment_line) - faulty_loopy_pragma_match = self.faulty_loopy_pragma.match(stripped_comment_line) + faulty_loopy_pragma_match = self.faulty_loopy_pragma.match( + stripped_comment_line) if stripped_comment_line == "$loopy begin transform": if self.in_transform_code: diff --git a/loopy/frontend/fortran/tree.py b/loopy/frontend/fortran/tree.py index fe25435b79d73d89f84f3a7dc1e5e9a474ef8b6a..4291d98749cb84bf743c3e8c745052190000cd03 100644 --- a/loopy/frontend/fortran/tree.py +++ b/loopy/frontend/fortran/tree.py @@ -81,8 +81,8 @@ class FTreeWalkerBase(object): # {{{ expressions - def parse_expr(self, expr_str): - return self.expr_parser(expr_str) + def parse_expr(self, expr_str, **kwargs): + return self.expr_parser(expr_str, **kwargs) # }}} diff --git a/test/test_fortran.py b/test/test_fortran.py index bf6ffe140289c13e5c04df02b0c226f44832d41b..b080293c9a94cd86d502887fd3a1c6e3bea82071 100644 --- a/test/test_fortran.py +++ b/test/test_fortran.py @@ -207,6 +207,40 @@ def test_temporary_to_subst_indices(ctx_factory): lp.auto_test_vs_ref(ref_knl, ctx, knl) +def test_if(ctx_factory): + fortran_src = """ + subroutine fill(out, out2, inp, n) + implicit none + + real*8 a, b, out(n), out2(n), inp(n) + integer n + + do i = 1, n + a = inp(i) + if (a.ge.3) then + b = 2*a + do j = 1,3 + b = 3 * b + end do + out(i) = 5*b + else + out(i) = 4*a + endif + end do + end + """ + + from loopy.frontend.fortran import f2loopy + knl, = f2loopy(fortran_src) + + ref_knl = knl + + knl = lp.temporary_to_subst(knl, "a") + + ctx = ctx_factory() + lp.auto_test_vs_ref(ref_knl, ctx, knl, parameters=dict(n=5)) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])