From 344473d758dffdcf95bafb736697c77cfef332d4 Mon Sep 17 00:00:00 2001
From: Matthias Diener <mdiener@illinois.edu>
Date: Wed, 26 Jul 2023 16:57:53 -0500
Subject: [PATCH] Disable bounds checking when no loopy call present (#449)

* Disable bounds checking when no loopy call present

* Add comment describing when bounds checking is disabled
---
 pytato/target/loopy/codegen.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/pytato/target/loopy/codegen.py b/pytato/target/loopy/codegen.py
index e76498b..3901f9d 100644
--- a/pytato/target/loopy/codegen.py
+++ b/pytato/target/loopy/codegen.py
@@ -342,6 +342,7 @@ class CodeGenMapper(Mapper):
     """A mapper for generating code for nodes in the computation graph.
     """
     exprgen_mapper: InlinedExpressionGenMapper
+    has_loopy_call: bool
 
     def __init__(self,
                  array_tag_t_to_not_propagate: FrozenSet[Type[Tag]],
@@ -349,6 +350,7 @@ class CodeGenMapper(Mapper):
         self.exprgen_mapper = InlinedExpressionGenMapper(self)
         self.array_tag_t_to_not_propagate = array_tag_t_to_not_propagate
         self.axis_tag_t_to_not_propagate = axis_tag_t_to_not_propagate
+        self.has_loopy_call = False
 
     def map_size_param(self, expr: SizeParam,
             state: CodeGenState) -> ImplementedResult:
@@ -463,6 +465,7 @@ class CodeGenMapper(Mapper):
         return state.results[expr]
 
     def map_loopy_call(self, expr: LoopyCall, state: CodeGenState) -> None:
+        self.has_loopy_call = True
         from loopy.kernel.instruction import make_assignment
         from loopy.symbolic import SubArrayRef
 
@@ -1084,6 +1087,11 @@ def generate_loopy(result: Union[Array, DictOfNamedArrays, Dict[str, Array]],
     # avoid such reduction iname collisions.
     t_unit = lp.make_reduction_inames_unique(state.t_unit)
 
+    # Disable bounds checking if there is no hand-written LoopyCall in the DAG.
+    if not cg_mapper.has_loopy_call:
+        t_unit = lp.set_options(t_unit,
+                                enforce_array_accesses_within_bounds="no_check")
+
     return target.bind_program(
             program=t_unit,
             bound_arguments=preproc_result.bound_arguments)
-- 
GitLab