diff --git a/loopy/frontend/fortran/__init__.py b/loopy/frontend/fortran/__init__.py index 15c6a7dc31f33d0897c4d6cd5293a273a52d2ee0..c5c1943ca6a8e43be1f218920f21b4dc1f16c4e1 100644 --- a/loopy/frontend/fortran/__init__.py +++ b/loopy/frontend/fortran/__init__.py @@ -356,6 +356,9 @@ def parse_fortran(source, filename="", free_form=None, strict=None, # guesssing in the case of only one function prog = prog.with_entrypoints(all_kernels[0].name) + from loopy.frontend.fortran.translator import specialize_fortran_division + prog = specialize_fortran_division(prog) + parse_plog.done() return prog diff --git a/loopy/frontend/fortran/translator.py b/loopy/frontend/fortran/translator.py index af701cf8e352659e690cbd8c21c49ba70f582f01..f13f630985a750b3476d9c6e0521a190620c5bbe 100644 --- a/loopy/frontend/fortran/translator.py +++ b/loopy/frontend/fortran/translator.py @@ -31,7 +31,6 @@ from loopy.frontend.fortran.tree import FTreeWalkerBase from loopy.diagnostic import warn_with_kernel from loopy.frontend.fortran.diagnostic import ( TranslationError, TranslatorWarning) -from loopy.translation_unit import for_each_kernel import islpy as isl from islpy import dim_type from loopy.symbolic import (IdentityMapper, RuleAwareIdentityMapper, @@ -266,10 +265,10 @@ class FortranDivisionToFloorDiv(IdentityMapper): class FortranDivisionSpecializer(RuleAwareIdentityMapper): - def __init__(self, rule_mapping_context, kernel): + def __init__(self, rule_mapping_context, kernel, callables): super().__init__(rule_mapping_context) - from loopy.type_inference import TypeInferenceMapper - self.infer_type = TypeInferenceMapper(kernel, None) + from loopy.type_inference import TypeReader + self.infer_type = TypeReader(kernel, callables) self.kernel = kernel def map_fortran_division(self, expr, *args): @@ -293,11 +292,31 @@ class FortranDivisionSpecializer(RuleAwareIdentityMapper): self.rec(expr.denominator, *args)) -@for_each_kernel -def specialize_fortran_division(knl): +def _specialize_fortran_division_for_kernel(knl, callables): rmc = SubstitutionRuleMappingContext( knl.substitutions, knl.get_var_name_generator()) - return FortranDivisionSpecializer(rmc, knl).map_kernel(knl) + return FortranDivisionSpecializer(rmc, knl, callables).map_kernel(knl) + + +def specialize_fortran_division(t_unit): + from loopy.translation_unit import TranslationUnit, resolve_callables + from loopy.kernel.function_interface import CallableKernel + from loopy.type_inference import infer_unknown_types + assert isinstance(t_unit, TranslationUnit) + + t_unit = resolve_callables(t_unit) + t_unit = infer_unknown_types(t_unit) + new_callables = {} + + for name, clbl in t_unit.callables_table.items(): + if isinstance(clbl, CallableKernel): + knl = clbl.subkernel + clbl = clbl.copy(subkernel=_specialize_fortran_division_for_kernel( + knl, t_unit.callables_table)) + + new_callables[name] = clbl + + return t_unit.copy(callables_table=new_callables) # }}} @@ -904,9 +923,6 @@ class F2LoopyTranslator(FTreeWalkerBase): seq_dependencies=seq_dependencies, ) - if self.all_names_known: - knl = specialize_fortran_division(knl) - from loopy.loop import merge_loop_domains knl = merge_loop_domains(knl)