From 799e2f19df4c462ba02e0910442102a96b63907f Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Tue, 7 Jun 2022 19:26:48 -0500 Subject: [PATCH] Pass name of function being compiled to pt.generate_{jax,loopy} --- arraycontext/impl/pytato/compile.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 20c32ec..9ad952c 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -52,6 +52,15 @@ import logging logger = logging.getLogger(__name__) +def _func_to_kernel_name(f: Callable[..., Any]) -> str: + name = f.__name__ + if not name.isidentifier(): + return "actx_compiled_" + "".join( + ch for ch in name if ch.isidentifier()) + else: + return name + + class FromArrayContextCompile(Tag): """ Tagged to the entrypoint kernel of every translation unit that is generated @@ -374,12 +383,14 @@ class LazilyPyOpenCLCompilingFunctionCaller(BaseLazilyCompilingFunctionCaller): for name, out in pt_dict_of_named_arrays._data.items()} with ProcessLogger(logger, f"generate_loopy for '{self.f}'"): - pytato_program = pt.generate_loopy(pt_dict_of_named_arrays, - options=lp.Options( - return_dict=True, - no_numpy=True), - # pylint: disable=no-member - cl_device=self.actx.queue.device) + pytato_program = pt.generate_loopy( + pt_dict_of_named_arrays, + options=lp.Options( + return_dict=True, + no_numpy=True), + function_name=_func_to_kernel_name(self.f), + # pylint: disable=no-member + cl_device=self.actx.queue.device) assert isinstance(pytato_program, BoundPyOpenCLProgram) with ProcessLogger(logger, f"transform_loopy_program for '{self.f}'"): @@ -441,7 +452,10 @@ class LazilyJAXCompilingFunctionCaller(BaseLazilyCompilingFunctionCaller): for name, out in pt_dict_of_named_arrays._data.items()} with ProcessLogger(logger, "generate_jax"): - pytato_program = pt.generate_jax(pt_dict_of_named_arrays, jit=True) + pytato_program = pt.generate_jax( + pt_dict_of_named_arrays, + jit=True, + function_name=_func_to_kernel_name(self.f)) return pytato_program, name_in_program_to_tags, name_in_program_to_axes -- GitLab