diff --git a/examples/fortran/foo.floopy b/examples/fortran/foo.floopy index 7d2d3eef0570e448ad8e9e7113e2470e1d4ef64d..0db51ea86fe2f13f8c8c9d099efedc59c5ec68b9 100644 --- a/examples/fortran/foo.floopy +++ b/examples/fortran/foo.floopy @@ -2,7 +2,7 @@ subroutine fill(out, a, n) implicit none real*8 a, out(n) - integer n + integer n, i do i = 1, n out(i) = a diff --git a/examples/fortran/sparse.floopy b/examples/fortran/sparse.floopy index 924e0aa4abe51c4dd84b01cad0cb83b56122c97d..f3cb50290dd9f5fec78d93717fdd9b911d47e2b9 100644 --- a/examples/fortran/sparse.floopy +++ b/examples/fortran/sparse.floopy @@ -6,6 +6,7 @@ subroutine sparse(rowstarts, colindices, values, m, n, nvals, x, y) real*8 x(n), y(n), rowsum integer m, n, rowstart, rowend, length, nvals + integer i, j do i = 1, m rowstart = rowstarts(i) diff --git a/examples/fortran/tagging.floopy b/examples/fortran/tagging.floopy index f4b4e28eab3ddd544c791a088279718ef5221bb9..e7deb113a7676258fa2392a6b2b5c5e8075930b4 100644 --- a/examples/fortran/tagging.floopy +++ b/examples/fortran/tagging.floopy @@ -2,7 +2,7 @@ subroutine fill(out, a, n) implicit none real*8 a, out(n) - integer n + integer n, i !$loopy begin tagged: init do i = 1, n diff --git a/loopy/frontend/fortran/translator.py b/loopy/frontend/fortran/translator.py index abd925c14fe9ebf14342ac5ac25e2062291feefc..55e507dbca6be065a55eef56e9696c8ddcd6b5e7 100644 --- a/loopy/frontend/fortran/translator.py +++ b/loopy/frontend/fortran/translator.py @@ -35,6 +35,7 @@ from loopy.frontend.fortran.diagnostic import ( import islpy as isl from islpy import dim_type from loopy.symbolic import IdentityMapper +from loopy.diagnostic import LoopyError from pymbolic.primitives import Wildcard @@ -236,6 +237,8 @@ class F2LoopyTranslator(FTreeWalkerBase): self.filename = filename + self.index_dtype = None + def add_expression_instruction(self, lhs, rhs): scope = self.scope_stack[-1] @@ -487,6 +490,16 @@ class F2LoopyTranslator(FTreeWalkerBase): if node.loopcontrol: loop_var, loop_bounds = node.loopcontrol.split("=") loop_var = loop_var.strip() + + iname_dtype = scope.get_type(loop_var) + if self.index_dtype is None: + self.index_dtype = iname_dtype + else: + if self.index_dtype != iname_dtype: + raise LoopyError("type of '%s' (%s) does not agree with prior " + "index type (%s)" + % (loop_var, iname_dtype, self.index_dtype)) + scope.use_name(loop_var) loop_bounds = self.parse_expr( node, @@ -685,7 +698,8 @@ class F2LoopyTranslator(FTreeWalkerBase): sub.instructions, kernel_data, name=sub.subprogram_name, - default_order="F" + default_order="F", + index_dtype=self.index_dtype, ) from loopy.loop import fuse_loop_domains diff --git a/test/test_fortran.py b/test/test_fortran.py index d51411a6e65ba750ddf8cfc2ee844c4bc167ab82..a2167cdafcf152dca282a39fff77d8b829c18f62 100644 --- a/test/test_fortran.py +++ b/test/test_fortran.py @@ -51,7 +51,7 @@ def test_fill(ctx_factory): implicit none real*8 a, out(n) - integer n + integer n, i do i = 1, n out(i) = a @@ -80,7 +80,7 @@ def test_fill_const(ctx_factory): implicit none real*8 a, out(n) - integer n + integer n, i do i = 1, n out(i) = 3.45 @@ -102,7 +102,7 @@ def test_asterisk_in_shape(ctx_factory): implicit none real*8 a, out(n), out2(n), inp(*) - integer n + integer n, i do i = 1, n a = inp(n) @@ -127,7 +127,7 @@ def test_temporary_to_subst(ctx_factory): implicit none real*8 a, out(n), out2(n), inp(n) - integer n + integer n, i do i = 1, n a = inp(i) @@ -154,7 +154,7 @@ def test_temporary_to_subst_two_defs(ctx_factory): implicit none real*8 a, out(n), out2(n), inp(n) - integer n + integer n, i do i = 1, n a = inp(i) @@ -182,7 +182,7 @@ def test_temporary_to_subst_indices(ctx_factory): implicit none real*8 a(n), out(n), out2(n), inp(n) - integer n + integer n, i do i = 1, n a(i) = 6*inp(i) @@ -215,7 +215,7 @@ def test_if(ctx_factory): implicit none real*8 a, b, out(n), out2(n), inp(n) - integer n + integer n, i, j do i = 1, n a = inp(i) @@ -249,7 +249,7 @@ def test_tagged(ctx_factory): implicit none real*8 a, b, r, out(n), out2(n), inp(n), inp2(n) real*8 alpha - integer n + integer n, i do i = 1, n !$loopy begin tagged: input @@ -339,6 +339,7 @@ def test_batched_sparse(): real*8 x(n, nvecs), y(n, nvecs), rowsum(nvecs) integer m, n, rowstart, rowend, length, nvals, nvecs + integer i, j, k do i = 1, m rowstart = rowstarts(i)