diff --git a/pyopencl/array.py b/pyopencl/array.py index 450aec8e14255dd2f5d08226c0265b3ae2265a4c..30555a9f0aff4def6306e758f3d433eecdfdfb94 100644 --- a/pyopencl/array.py +++ b/pyopencl/array.py @@ -76,22 +76,22 @@ def _get_common_dtype(obj1, obj2): @decorator def elwise_kernel_runner(kernel_getter, *args, **kwargs): - """Take a kernel getter of the same signature as the kernel + """Take a kernel getter of the same signature as the kernel and return a function that invokes that kernel. Assumes that the zeroth entry in *args* is an :class:`Array`. """ - knl = kernel_getter(*args) - + # The decorators module converts kwargs to positional arguments, + # so we pop the queue argument first + args = list(args) repr_ary = args[0] - assert isinstance(repr_ary, Array) + queue = args.pop() or repr_ary.queue - queue = kwargs.pop("queue", None) or repr_ary.queue gs, ls = repr_ary.get_sizes(queue) + knl = kernel_getter(*args) - if kwargs: - raise TypeError("only the 'queue' keyword argument is supported") + assert isinstance(repr_ary, Array) actual_args = [] for arg in args: @@ -205,20 +205,21 @@ class Array(object): raise TypeError("pyopencl arrays are not hashable.") # kernel invocation wrappers ---------------------------------------------- - @elwise_kernel_runner @staticmethod + @elwise_kernel_runner def _axpbyz(out, afac, a, bfac, b, queue=None): """Compute ``out = selffac * self + otherfac*other``, where `other` is a vector..""" - assert self.shape == other.shape + assert out.shape == a.shape return elementwise.get_axpbyz_kernel( - self.context, self.dtype, other.dtype, out.dtype) + out.context, a.dtype, b.dtype, out.dtype) + @staticmethod @elwise_kernel_runner - def _axpbz(self, out, selffac, other, queue=None): - """Compute ``out = selffac * self + other``, where `other` is a scalar.""" - return elementwise.get_axpbz_kernel(self.context, self.dtype) + def _axpbz(out, afac, a, other, queue=None): + """Compute ``out = afac * a + other``, where `other` is a scalar.""" + return elementwise.get_axpbz_kernel(out.context, out.dtype) @elwise_kernel_runner def _elwise_multiply(self, out, other, queue=None): @@ -297,7 +298,7 @@ class Array(object): if isinstance(other, Array): result = self._new_like_me(_get_common_dtype(self, other)) - self._axpbyz(result, 1, other, -1) + self._axpbyz(result, 1, self, -1, other) return result else: if other == 0: