From 81b7ee4d8804a193076fcb4437f755040c42e3fa Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Sat, 22 May 2021 18:47:38 -0500
Subject: [PATCH] call specialize_fortran_division on the entire translation
 unit rather than the kernel

---
 loopy/frontend/fortran/__init__.py   |  3 +++
 loopy/frontend/fortran/translator.py | 36 ++++++++++++++++++++--------
 2 files changed, 29 insertions(+), 10 deletions(-)

diff --git a/loopy/frontend/fortran/__init__.py b/loopy/frontend/fortran/__init__.py
index 15c6a7dc3..c5c1943ca 100644
--- a/loopy/frontend/fortran/__init__.py
+++ b/loopy/frontend/fortran/__init__.py
@@ -356,6 +356,9 @@ def parse_fortran(source, filename="<floopy code>", free_form=None, strict=None,
         # guesssing in the case of only one function
         prog = prog.with_entrypoints(all_kernels[0].name)
 
+    from loopy.frontend.fortran.translator import specialize_fortran_division
+    prog = specialize_fortran_division(prog)
+
     parse_plog.done()
 
     return prog
diff --git a/loopy/frontend/fortran/translator.py b/loopy/frontend/fortran/translator.py
index af701cf8e..f13f63098 100644
--- a/loopy/frontend/fortran/translator.py
+++ b/loopy/frontend/fortran/translator.py
@@ -31,7 +31,6 @@ from loopy.frontend.fortran.tree import FTreeWalkerBase
 from loopy.diagnostic import warn_with_kernel
 from loopy.frontend.fortran.diagnostic import (
         TranslationError, TranslatorWarning)
-from loopy.translation_unit import for_each_kernel
 import islpy as isl
 from islpy import dim_type
 from loopy.symbolic import (IdentityMapper, RuleAwareIdentityMapper,
@@ -266,10 +265,10 @@ class FortranDivisionToFloorDiv(IdentityMapper):
 
 
 class FortranDivisionSpecializer(RuleAwareIdentityMapper):
-    def __init__(self, rule_mapping_context, kernel):
+    def __init__(self, rule_mapping_context, kernel, callables):
         super().__init__(rule_mapping_context)
-        from loopy.type_inference import TypeInferenceMapper
-        self.infer_type = TypeInferenceMapper(kernel, None)
+        from loopy.type_inference import TypeReader
+        self.infer_type = TypeReader(kernel, callables)
         self.kernel = kernel
 
     def map_fortran_division(self, expr, *args):
@@ -293,11 +292,31 @@ class FortranDivisionSpecializer(RuleAwareIdentityMapper):
                     self.rec(expr.denominator, *args))
 
 
-@for_each_kernel
-def specialize_fortran_division(knl):
+def _specialize_fortran_division_for_kernel(knl, callables):
     rmc = SubstitutionRuleMappingContext(
             knl.substitutions, knl.get_var_name_generator())
-    return FortranDivisionSpecializer(rmc, knl).map_kernel(knl)
+    return FortranDivisionSpecializer(rmc, knl, callables).map_kernel(knl)
+
+
+def specialize_fortran_division(t_unit):
+    from loopy.translation_unit import TranslationUnit, resolve_callables
+    from loopy.kernel.function_interface import CallableKernel
+    from loopy.type_inference import infer_unknown_types
+    assert isinstance(t_unit, TranslationUnit)
+
+    t_unit = resolve_callables(t_unit)
+    t_unit = infer_unknown_types(t_unit)
+    new_callables = {}
+
+    for name, clbl in t_unit.callables_table.items():
+        if isinstance(clbl, CallableKernel):
+            knl = clbl.subkernel
+            clbl = clbl.copy(subkernel=_specialize_fortran_division_for_kernel(
+                    knl, t_unit.callables_table))
+
+        new_callables[name] = clbl
+
+    return t_unit.copy(callables_table=new_callables)
 
 # }}}
 
@@ -904,9 +923,6 @@ class F2LoopyTranslator(FTreeWalkerBase):
                     seq_dependencies=seq_dependencies,
                     )
 
-            if self.all_names_known:
-                knl = specialize_fortran_division(knl)
-
             from loopy.loop import merge_loop_domains
             knl = merge_loop_domains(knl)
 
-- 
GitLab