diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index f6263452d68b112964fdb78c25ebcb9100f35af5..aacfd826b063dac89490a0df023e77b841f21bfe 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -309,20 +309,38 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): try: pt_prg = self._freeze_prg_cache[normalized_expr] except KeyError: - if normalized_expr in self._dag_transform_cache: - transformed_dag = self._dag_transform_cache[normalized_expr] - else: + try: + transformed_dag, function_name = \ + self._dag_transform_cache[normalized_expr] + except KeyError: transformed_dag = self.transform_dag(normalized_expr) - self._dag_transform_cache[normalized_expr] = transformed_dag + + from pytato.tags import PrefixNamed + name_hint_tags = [] + for subary in key_to_pt_arrays.values(): + name_hint_tags.extend(subary.tags_of_type(PrefixNamed)) + + from pytools import common_prefix + name_hint = common_prefix([nh.prefix for nh in name_hint_tags]) + if name_hint: + # All name_hint_tags shared at least some common prefix. + function_name = f"frozen_{name_hint}" + else: + function_name = "frozen_result" + + self._dag_transform_cache[normalized_expr] = ( + transformed_dag, function_name) pt_prg = pt.generate_loopy(transformed_dag, options=lp.Options(return_dict=True, no_numpy=True), - cl_device=self.queue.device) + cl_device=self.queue.device, + function_name=function_name) pt_prg = pt_prg.with_transformed_program(self.transform_loopy_program) self._freeze_prg_cache[normalized_expr] = pt_prg else: - transformed_dag = self._dag_transform_cache[normalized_expr] + transformed_dag, function_name = \ + self._dag_transform_cache[normalized_expr] assert len(pt_prg.bound_arguments) == 0 evt, out_dict = pt_prg(self.queue, **bound_arguments)