From 7c3ef798be5305ddf637550776d01c66c98e51f7 Mon Sep 17 00:00:00 2001 From: Emanuel Rietveld Date: Tue, 20 Feb 2018 08:04:30 +0900 Subject: [PATCH] Add get_texref() to ElementwiseKernel --- pycuda/elementwise.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/pycuda/elementwise.py b/pycuda/elementwise.py index feab0a6b..695817fa 100644 --- a/pycuda/elementwise.py +++ b/pycuda/elementwise.py @@ -149,7 +149,7 @@ def get_elwise_kernel_and_types(arguments, operation, func = mod.get_function(name) func.prepare("".join(arg.struct_char for arg in arguments)) - return func, arguments + return mod, func, arguments def get_elwise_kernel(arguments, operation, @@ -157,7 +157,7 @@ def get_elwise_kernel(arguments, operation, """Return a L{pycuda.driver.Function} that performs the same scalar operation on one or several vectors. """ - func, arguments = get_elwise_kernel_and_types( + mod, func, arguments = get_elwise_kernel_and_types( arguments, operation, name, keep, options, **kwargs) return func @@ -171,9 +171,13 @@ class ElementwiseKernel: self.gen_kwargs.update(dict(keep=keep, options=options, name=name, operation=operation, arguments=arguments)) + def get_texref(self, name, use_range=False): + mod, knl, arguments = self.generate_stride_kernel_and_types(use_range=use_range) + return mod.get_texref(name) + @memoize_method def generate_stride_kernel_and_types(self, use_range): - knl, arguments = get_elwise_kernel_and_types(use_range=use_range, + mod, knl, arguments = get_elwise_kernel_and_types(use_range=use_range, **self.gen_kwargs) assert [i for i, arg in enumerate(arguments) @@ -181,7 +185,7 @@ class ElementwiseKernel: "ElementwiseKernel can only be used with functions that " \ "have at least one vector argument" - return knl, arguments + return mod, knl, arguments def __call__(self, *args, **kwargs): vectors = [] @@ -195,7 +199,7 @@ class ElementwiseKernel: + ", ".join(six.iterkeys(kwargs))) invocation_args = [] - func, arguments = self.generate_stride_kernel_and_types( + mod, func, arguments = self.generate_stride_kernel_and_types( range_ is not None or slice_ is not None) for arg, arg_descr in zip(args, arguments): -- GitLab