diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 32823286677d836872879b607589f740fe355e88..27037ac65de59a5abe7ad51c9dc06cc32c9ccbbb 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -105,6 +105,8 @@ class LeafArrayDescriptor(AbstractInputDescriptor): # }}} +# {{{ utilities + def _ary_container_key_stringifier(keys: Tuple[Any, ...]) -> str: """ Helper for :meth:`BaseLazilyCompilingFunctionCaller.__call__`. Stringifies an @@ -236,6 +238,10 @@ def _get_f_placeholder_args(arg, kw, arg_id_to_name, actx): else: raise NotImplementedError(type(arg)) +# }}} + + +# {{{ BaseLazilyCompilingFunctionCaller @dataclass class BaseLazilyCompilingFunctionCaller: @@ -366,6 +372,10 @@ class BaseLazilyCompilingFunctionCaller: self.program_cache[arg_id_to_descr] = compiled_func return compiled_func(arg_id_to_arg) +# }}} + + +# {{{ LazilyPyOpenCLCompilingFunctionCaller class LazilyPyOpenCLCompilingFunctionCaller(BaseLazilyCompilingFunctionCaller): @property @@ -440,6 +450,8 @@ class LazilyPyOpenCLCompilingFunctionCaller(BaseLazilyCompilingFunctionCaller): return pytato_program, name_in_program_to_tags, name_in_program_to_axes +# }}} + # {{{ preserve back compat @@ -461,6 +473,8 @@ class LazilyCompilingFunctionCaller(LazilyPyOpenCLCompilingFunctionCaller): # }}} +# {{{ LazilyJAXCompilingFunctionCaller + class LazilyJAXCompilingFunctionCaller(BaseLazilyCompilingFunctionCaller): @property def compiled_function_returning_array_container_class( @@ -553,6 +567,10 @@ def _args_to_cl_buffers(actx, input_id_to_name_in_program, arg_id_to_arg): return _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg) +# }}} + + +# {{{ compiled function class CompiledFunction(abc.ABC): """ @@ -582,6 +600,10 @@ class CompiledFunction(abc.ABC): """ pass +# }}} + + +# {{{ copmiled pyopencl function @dataclass(frozen=True) class CompiledPyOpenCLFunctionReturningArrayContainer(CompiledFunction): @@ -670,7 +692,10 @@ class CompiledPyOpenCLFunctionReturningArray(CompiledFunction): self.output_axes), tags=self.output_tags)) +# }}} + +# {{{ comiled jax function @dataclass(frozen=True) class CompiledJAXFunctionReturningArrayContainer(CompiledFunction): """ @@ -732,3 +757,7 @@ class CompiledJAXFunctionReturningArray(CompiledFunction): evt, out_dict = self.pytato_program(**input_kwargs_for_loopy) return self.actx.thaw(out_dict[self.output_name]) + +# }}} + +# vim: foldmethod=marker