From 6009050800d64e9f19ce7d3cc56858bae5e2a41b Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Thu, 26 Aug 2021 16:18:07 -0500 Subject: [PATCH] BoundPyOpenCLProgram: accomodate more valid arguments --- pytato/target/loopy/__init__.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/pytato/target/loopy/__init__.py b/pytato/target/loopy/__init__.py index 976a5f6..d800216 100644 --- a/pytato/target/loopy/__init__.py +++ b/pytato/target/loopy/__init__.py @@ -124,7 +124,9 @@ class BoundPyOpenCLProgram(BoundProgram): return self.copy(program=f(self.program)) def __call__(self, queue: "pyopencl.CommandQueue", # type: ignore - *args: Any, **kwargs: Any) -> Any: + allocator=None, wait_for=None, out_host=None, + entrypoint="_pt_kernel", + **kwargs: Any) -> Any: """Convenience function for launching a :mod:`pyopencl` computation.""" if set(kwargs.keys()) & set(self.bound_arguments.keys()): @@ -133,8 +135,6 @@ class BoundPyOpenCLProgram(BoundProgram): updated_kwargs = dict(self.bound_arguments) updated_kwargs.update(kwargs) - if not isinstance(self. program, loopy.LoopKernel): - updated_kwargs.setdefault("entrypoint", "_pt_kernel") # final DAG might be independent of certain placeholders, for ex. # '0 * x' results in a final loopy t-unit that is independent of the @@ -143,7 +143,10 @@ class BoundPyOpenCLProgram(BoundProgram): for kw, arg in updated_kwargs.items() if kw in self.program.default_entrypoint.arg_dict} - return self.program(queue, *args, **updated_kwargs) + return self.program(queue, + allocator=allocator, wait_for=wait_for, + out_host=out_host, entrypoint=entrypoint, + **updated_kwargs) @property def kernel(self) -> "loopy.LoopKernel": -- GitLab