From 40e33240b531e65cc912d22f4d39fd26860deb5f Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Sat, 22 May 2021 15:40:29 -0500 Subject: [PATCH] Fix up Fortran division specialization for kernel callables --- loopy/frontend/fortran/translator.py | 4 +++- loopy/type_inference.py | 12 +++++++++++- test/test_fortran.py | 8 ++++---- 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/loopy/frontend/fortran/translator.py b/loopy/frontend/fortran/translator.py index 8dcc32e00..af701cf8e 100644 --- a/loopy/frontend/fortran/translator.py +++ b/loopy/frontend/fortran/translator.py @@ -31,6 +31,7 @@ from loopy.frontend.fortran.tree import FTreeWalkerBase from loopy.diagnostic import warn_with_kernel from loopy.frontend.fortran.diagnostic import ( TranslationError, TranslatorWarning) +from loopy.translation_unit import for_each_kernel import islpy as isl from islpy import dim_type from loopy.symbolic import (IdentityMapper, RuleAwareIdentityMapper, @@ -268,7 +269,7 @@ class FortranDivisionSpecializer(RuleAwareIdentityMapper): def __init__(self, rule_mapping_context, kernel): super().__init__(rule_mapping_context) from loopy.type_inference import TypeInferenceMapper - self.infer_type = TypeInferenceMapper(kernel) + self.infer_type = TypeInferenceMapper(kernel, None) self.kernel = kernel def map_fortran_division(self, expr, *args): @@ -292,6 +293,7 @@ class FortranDivisionSpecializer(RuleAwareIdentityMapper): self.rec(expr.denominator, *args)) +@for_each_kernel def specialize_fortran_division(knl): rmc = SubstitutionRuleMappingContext( knl.substitutions, knl.get_var_name_generator()) diff --git a/loopy/type_inference.py b/loopy/type_inference.py index cfc04d096..92df0323a 100644 --- a/loopy/type_inference.py +++ b/loopy/type_inference.py @@ -191,7 +191,13 @@ class TypeInferenceMapper(CombineMapper): instances """ self.kernel = kernel - assert isinstance(clbl_inf_ctx, CallablesInferenceContext) + assert ( + # FIXME: HACK + # only used in kernel-local type inference for division + # specialization in Fortran + clbl_inf_ctx is None + + or isinstance(clbl_inf_ctx, CallablesInferenceContext)) if new_assignments is None: new_assignments = {} self.new_assignments = new_assignments @@ -417,6 +423,10 @@ class TypeInferenceMapper(CombineMapper): arg_id_to_dtype = {i: none_if_empty(self.rec(par)) for (i, par) in enumerate(expr.parameters)} + if self.clbl_inf_ctx is None: + raise LoopyError("TypeInferenceMapper was created without a " + "CallablesInferenceContext, but encountered a function call") + # specializing the known function wrt type in_knl_callable = self.clbl_inf_ctx[expr.function.name] diff --git a/test/test_fortran.py b/test/test_fortran.py index d6bb57162..45e83b384 100644 --- a/test/test_fortran.py +++ b/test/test_fortran.py @@ -667,13 +667,13 @@ def test_division_in_shapes(ctx_factory): end do end subroutine """ - knl, = lp.parse_fortran(fortran_src) - ref_knl = knl + t_unit = lp.parse_fortran(fortran_src) + ref_t_unit = t_unit - print(knl) + print(t_unit) ctx = ctx_factory() - lp.auto_test_vs_ref(ref_knl, ctx, knl, parameters=dict(m=128)) + lp.auto_test_vs_ref(ref_t_unit, ctx, t_unit, parameters=dict(m=128)) if __name__ == "__main__": -- GitLab