From 5500b90574aebceed4433924f298fa4ff006fd72 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Tue, 7 Jun 2022 19:57:24 -0500
Subject: [PATCH] Use NameHint/PrefixNamed to generate better kernel names in
 pytato freeze

---
 arraycontext/impl/pytato/__init__.py | 30 ++++++++++++++++++++++------
 1 file changed, 24 insertions(+), 6 deletions(-)

diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py
index f626345..aacfd82 100644
--- a/arraycontext/impl/pytato/__init__.py
+++ b/arraycontext/impl/pytato/__init__.py
@@ -309,20 +309,38 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
         try:
             pt_prg = self._freeze_prg_cache[normalized_expr]
         except KeyError:
-            if normalized_expr in self._dag_transform_cache:
-                transformed_dag = self._dag_transform_cache[normalized_expr]
-            else:
+            try:
+                transformed_dag, function_name = \
+                        self._dag_transform_cache[normalized_expr]
+            except KeyError:
                 transformed_dag = self.transform_dag(normalized_expr)
-                self._dag_transform_cache[normalized_expr] = transformed_dag
+
+                from pytato.tags import PrefixNamed
+                name_hint_tags = []
+                for subary in key_to_pt_arrays.values():
+                    name_hint_tags.extend(subary.tags_of_type(PrefixNamed))
+
+                from pytools import common_prefix
+                name_hint = common_prefix([nh.prefix for nh in name_hint_tags])
+                if name_hint:
+                    # All name_hint_tags shared at least some common prefix.
+                    function_name = f"frozen_{name_hint}"
+                else:
+                    function_name = "frozen_result"
+
+                self._dag_transform_cache[normalized_expr] = (
+                        transformed_dag, function_name)
 
             pt_prg = pt.generate_loopy(transformed_dag,
                                        options=lp.Options(return_dict=True,
                                                           no_numpy=True),
-                                       cl_device=self.queue.device)
+                                       cl_device=self.queue.device,
+                                       function_name=function_name)
             pt_prg = pt_prg.with_transformed_program(self.transform_loopy_program)
             self._freeze_prg_cache[normalized_expr] = pt_prg
         else:
-            transformed_dag = self._dag_transform_cache[normalized_expr]
+            transformed_dag, function_name = \
+                    self._dag_transform_cache[normalized_expr]
 
         assert len(pt_prg.bound_arguments) == 0
         evt, out_dict = pt_prg(self.queue, **bound_arguments)
-- 
GitLab