Skip to content
Snippets Groups Projects
Commit a3cd87ed authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Check block nesting in Fortran frontend

parent f8b05ee4
No related branches found
No related tags found
No related merge requests found
......@@ -221,6 +221,8 @@ class F2LoopyTranslator(FTreeWalkerBase):
self.index_dtype = None
self.block_nest = []
def add_expression_instruction(self, lhs, rhs):
scope = self.scope_stack[-1]
......@@ -261,6 +263,7 @@ class F2LoopyTranslator(FTreeWalkerBase):
scope = Scope(node.name, list(node.args))
self.scope_stack.append(scope)
self.block_nest.append("sub")
for c in node.content:
self.rec(c)
......@@ -269,6 +272,11 @@ class F2LoopyTranslator(FTreeWalkerBase):
self.kernels.append(scope)
def map_EndSubroutine(self, node):
if not self.block_nest:
raise TranslationError("no subroutine started at this point")
if self.block_nest.pop() != "sub":
raise TranslationError("mismatched end subroutine")
return []
def map_Implicit(self, node):
......@@ -459,115 +467,126 @@ class F2LoopyTranslator(FTreeWalkerBase):
for c in node.content:
self.rec(c)
self.block_nest.append("if")
def map_Else(self, node):
cond_name = self.conditions.pop()
self.conditions.append("!" + cond_name)
def map_EndIfThen(self, node):
if not self.block_nest:
raise TranslationError("no if block started at end do")
if self.block_nest.pop() != "if":
raise TranslationError("mismatched end if")
self.conditions.pop()
def map_Do(self, node):
scope = self.scope_stack[-1]
if node.loopcontrol:
loop_var, loop_bounds = node.loopcontrol.split("=")
loop_var = loop_var.strip()
iname_dtype = scope.get_type(loop_var)
if self.index_dtype is None:
self.index_dtype = iname_dtype
else:
if self.index_dtype != iname_dtype:
raise LoopyError("type of '%s' (%s) does not agree with prior "
"index type (%s)"
% (loop_var, iname_dtype, self.index_dtype))
scope.use_name(loop_var)
loop_bounds = self.parse_expr(
node,
loop_bounds, min_precedence=self.expr_parser._PREC_FUNC_ARGS)
if len(loop_bounds) == 2:
start, stop = loop_bounds
step = 1
elif len(loop_bounds) == 3:
start, stop, step = loop_bounds
else:
raise RuntimeError("loop bounds not understood: %s"
% node.loopcontrol)
if not node.loopcontrol:
raise NotImplementedError("unbounded do loop")
if step != 1:
raise NotImplementedError(
"do loops with non-unit stride")
loop_var, loop_bounds = node.loopcontrol.split("=")
loop_var = loop_var.strip()
if not isinstance(step, int):
raise TranslationError(
"non-constant steps not supported: %s" % step)
from loopy.symbolic import get_dependencies
loop_bound_deps = (
get_dependencies(start)
| get_dependencies(stop)
| get_dependencies(step))
# {{{ find a usable loopy-side loop name
loopy_loop_var = loop_var
loop_var_suffix = None
while True:
already_used = False
for iset in scope.index_sets:
if loopy_loop_var in iset.get_var_dict(dim_type.set):
already_used = True
break
if not already_used:
iname_dtype = scope.get_type(loop_var)
if self.index_dtype is None:
self.index_dtype = iname_dtype
else:
if self.index_dtype != iname_dtype:
raise LoopyError("type of '%s' (%s) does not agree with prior "
"index type (%s)"
% (loop_var, iname_dtype, self.index_dtype))
scope.use_name(loop_var)
loop_bounds = self.parse_expr(
node,
loop_bounds, min_precedence=self.expr_parser._PREC_FUNC_ARGS)
if len(loop_bounds) == 2:
start, stop = loop_bounds
step = 1
elif len(loop_bounds) == 3:
start, stop, step = loop_bounds
else:
raise RuntimeError("loop bounds not understood: %s"
% node.loopcontrol)
if step != 1:
raise NotImplementedError(
"do loops with non-unit stride")
if not isinstance(step, int):
raise TranslationError(
"non-constant steps not supported: %s" % step)
from loopy.symbolic import get_dependencies
loop_bound_deps = (
get_dependencies(start)
| get_dependencies(stop)
| get_dependencies(step))
# {{{ find a usable loopy-side loop name
loopy_loop_var = loop_var
loop_var_suffix = None
while True:
already_used = False
for iset in scope.index_sets:
if loopy_loop_var in iset.get_var_dict(dim_type.set):
already_used = True
break
if loop_var_suffix is None:
loop_var_suffix = 0
if not already_used:
break
loop_var_suffix += 1
loopy_loop_var = loop_var + "_%d" % loop_var_suffix
if loop_var_suffix is None:
loop_var_suffix = 0
# }}}
loop_var_suffix += 1
loopy_loop_var = loop_var + "_%d" % loop_var_suffix
space = isl.Space.create_from_names(self.isl_context,
set=[loopy_loop_var], params=list(loop_bound_deps))
from loopy.isl_helpers import iname_rel_aff
from loopy.symbolic import aff_from_expr
index_set = (
isl.BasicSet.universe(space)
.add_constraint(
isl.Constraint.inequality_from_aff(
iname_rel_aff(space,
loopy_loop_var, ">=",
aff_from_expr(space, 0))))
.add_constraint(
isl.Constraint.inequality_from_aff(
iname_rel_aff(space,
loopy_loop_var, "<=",
aff_from_expr(space, stop-start)))))
from pymbolic import var
scope.active_iname_aliases[loop_var] = \
var(loopy_loop_var) + start
scope.active_loopy_inames.add(loopy_loop_var)
scope.index_sets.append(index_set)
for c in node.content:
self.rec(c)
del scope.active_iname_aliases[loop_var]
scope.active_loopy_inames.remove(loopy_loop_var)
# }}}
else:
raise NotImplementedError("unbounded do loop")
space = isl.Space.create_from_names(self.isl_context,
set=[loopy_loop_var], params=list(loop_bound_deps))
from loopy.isl_helpers import iname_rel_aff
from loopy.symbolic import aff_from_expr
index_set = (
isl.BasicSet.universe(space)
.add_constraint(
isl.Constraint.inequality_from_aff(
iname_rel_aff(space,
loopy_loop_var, ">=",
aff_from_expr(space, 0))))
.add_constraint(
isl.Constraint.inequality_from_aff(
iname_rel_aff(space,
loopy_loop_var, "<=",
aff_from_expr(space, stop-start)))))
from pymbolic import var
scope.active_iname_aliases[loop_var] = \
var(loopy_loop_var) + start
scope.active_loopy_inames.add(loopy_loop_var)
scope.index_sets.append(index_set)
self.block_nest.append("do")
for c in node.content:
self.rec(c)
del scope.active_iname_aliases[loop_var]
scope.active_loopy_inames.remove(loopy_loop_var)
def map_EndDo(self, node):
pass
if not self.block_nest:
raise TranslationError("no do loop started at end do")
if self.block_nest.pop() != "do":
raise TranslationError("mismatched end do")
def map_Continue(self, node):
raise NotImplementedError("continue")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment