Skip to content
Snippets Groups Projects
Commit 5500b905 authored by Andreas Klöckner's avatar Andreas Klöckner Committed by Andreas Klöckner
Browse files

Use NameHint/PrefixNamed to generate better kernel names in pytato freeze

parent 37e142e7
No related branches found
No related tags found
No related merge requests found
Pipeline #304468 passed
......@@ -309,20 +309,38 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
try:
pt_prg = self._freeze_prg_cache[normalized_expr]
except KeyError:
if normalized_expr in self._dag_transform_cache:
transformed_dag = self._dag_transform_cache[normalized_expr]
else:
try:
transformed_dag, function_name = \
self._dag_transform_cache[normalized_expr]
except KeyError:
transformed_dag = self.transform_dag(normalized_expr)
self._dag_transform_cache[normalized_expr] = transformed_dag
from pytato.tags import PrefixNamed
name_hint_tags = []
for subary in key_to_pt_arrays.values():
name_hint_tags.extend(subary.tags_of_type(PrefixNamed))
from pytools import common_prefix
name_hint = common_prefix([nh.prefix for nh in name_hint_tags])
if name_hint:
# All name_hint_tags shared at least some common prefix.
function_name = f"frozen_{name_hint}"
else:
function_name = "frozen_result"
self._dag_transform_cache[normalized_expr] = (
transformed_dag, function_name)
pt_prg = pt.generate_loopy(transformed_dag,
options=lp.Options(return_dict=True,
no_numpy=True),
cl_device=self.queue.device)
cl_device=self.queue.device,
function_name=function_name)
pt_prg = pt_prg.with_transformed_program(self.transform_loopy_program)
self._freeze_prg_cache[normalized_expr] = pt_prg
else:
transformed_dag = self._dag_transform_cache[normalized_expr]
transformed_dag, function_name = \
self._dag_transform_cache[normalized_expr]
assert len(pt_prg.bound_arguments) == 0
evt, out_dict = pt_prg(self.queue, **bound_arguments)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment