diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 20c32ec23a3e83fbfadd53df2a12da85cb460f4d..9ad952c8046e90a4448ff4b346e4f429ffa2ccb7 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -52,6 +52,15 @@ import logging logger = logging.getLogger(__name__) +def _func_to_kernel_name(f: Callable[..., Any]) -> str: + name = f.__name__ + if not name.isidentifier(): + return "actx_compiled_" + "".join( + ch for ch in name if ch.isidentifier()) + else: + return name + + class FromArrayContextCompile(Tag): """ Tagged to the entrypoint kernel of every translation unit that is generated @@ -374,12 +383,14 @@ class LazilyPyOpenCLCompilingFunctionCaller(BaseLazilyCompilingFunctionCaller): for name, out in pt_dict_of_named_arrays._data.items()} with ProcessLogger(logger, f"generate_loopy for '{self.f}'"): - pytato_program = pt.generate_loopy(pt_dict_of_named_arrays, - options=lp.Options( - return_dict=True, - no_numpy=True), - # pylint: disable=no-member - cl_device=self.actx.queue.device) + pytato_program = pt.generate_loopy( + pt_dict_of_named_arrays, + options=lp.Options( + return_dict=True, + no_numpy=True), + function_name=_func_to_kernel_name(self.f), + # pylint: disable=no-member + cl_device=self.actx.queue.device) assert isinstance(pytato_program, BoundPyOpenCLProgram) with ProcessLogger(logger, f"transform_loopy_program for '{self.f}'"): @@ -441,7 +452,10 @@ class LazilyJAXCompilingFunctionCaller(BaseLazilyCompilingFunctionCaller): for name, out in pt_dict_of_named_arrays._data.items()} with ProcessLogger(logger, "generate_jax"): - pytato_program = pt.generate_jax(pt_dict_of_named_arrays, jit=True) + pytato_program = pt.generate_jax( + pt_dict_of_named_arrays, + jit=True, + function_name=_func_to_kernel_name(self.f)) return pytato_program, name_in_program_to_tags, name_in_program_to_axes