From e44c99c4667f67203403222bf188ce461f07274c Mon Sep 17 00:00:00 2001 From: zzjjbb <31069326+zzjjbb@users.noreply.github.com> Date: Mon, 7 Dec 2020 19:41:45 -0500 Subject: [PATCH 1/2] add "out" parameter to GPUArray.conj(); add equivalent GPUArray.conjugate() method --- pycuda/elementwise.py | 5 +++-- pycuda/gpuarray.py | 11 ++++++++--- test/test_gpuarray.py | 4 ++++ 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/pycuda/elementwise.py b/pycuda/elementwise.py index 633edd78..258bae5c 100644 --- a/pycuda/elementwise.py +++ b/pycuda/elementwise.py @@ -606,11 +606,12 @@ def get_imag_kernel(dtype, real_dtype): @context_dependent_memoize -def get_conj_kernel(dtype): +def get_conj_kernel(dtype, conj_dtype): return get_elwise_kernel( - "%(tp)s *y, %(tp)s *z" + "%(tp)s *y, %(conj_tp)s *z" % { "tp": dtype_to_ctype(dtype), + "conj_tp": dtype_to_ctype(conj_dtype) }, "z[i] = pycuda::conj(y[i])", "conj", diff --git a/pycuda/gpuarray.py b/pycuda/gpuarray.py index a1a3f3f3..f5908a06 100644 --- a/pycuda/gpuarray.py +++ b/pycuda/gpuarray.py @@ -1141,7 +1141,7 @@ class GPUArray: else: return zeros_like(self) - def conj(self): + def conj(self, out=None): dtype = self.dtype if issubclass(self.dtype.type, np.complexfloating): if not self.flags.forc: @@ -1154,9 +1154,12 @@ class GPUArray: order = "F" else: order = "C" - result = self._new_like_me(order=order) + if out is None: + result = self._new_like_me(order=order) + else: + result = out - func = elementwise.get_conj_kernel(dtype) + func = elementwise.get_conj_kernel(dtype, result.dtype) func.prepared_async_call( self._grid, self._block, @@ -1170,6 +1173,8 @@ class GPUArray: else: return self + conjugate = conj + # }}} # {{{ rich comparisons diff --git a/test/test_gpuarray.py b/test/test_gpuarray.py index fc7b6736..d5d09251 100644 --- a/test/test_gpuarray.py +++ b/test/test_gpuarray.py @@ -732,6 +732,10 @@ class TestGPUArray: assert la.norm(z.get().real - z.real.get()) == 0 assert la.norm(z.get().imag - z.imag.get()) 
== 0 assert la.norm(z.get().conj() - z.conj().get()) == 0 + # verify conj with out parameter + z_out = z.astype(np.complex64) + assert z_out is z.conj(out=z_out) + assert la.norm(z.get().conj() - z_out.get()) < 1e-7 # verify contiguity is preserved for order in ["C", "F"]: -- GitLab From a5a34fc25ac0b465c197c9490b233556d2d4c4d4 Mon Sep 17 00:00:00 2001 From: Jiabei Zhu Date: Wed, 10 Feb 2021 04:18:37 -0500 Subject: [PATCH 2/2] change doc for conj/conjugate --- doc/source/array.rst | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/doc/source/array.rst b/doc/source/array.rst index 34efb1cf..3b12f04a 100644 --- a/doc/source/array.rst +++ b/doc/source/array.rst @@ -189,12 +189,26 @@ The :class:`GPUArray` Array Class .. versionadded: 0.94 - .. method :: conj() + .. method :: conj(out=None) - Return the complex conjugate of *self*, or *self* if it is real. + Return the complex conjugate of *self*, or *self* if it is real. If *out* + is not given, a newly allocated :class:`GPUArray` will be returned. Use + *out=self* to get the conjugate in-place. .. versionadded: 0.94 + .. versionchanged:: 2020.1.1 + + add *out* parameter + + + .. method :: conjugate(out=None) + + alias of :meth:`conj` + + .. versionadded:: 2020.1.1 + + .. method:: bind_to_texref(texref, allow_offset=False) Bind *self* to the :class:`pycuda.driver.TextureReference` *texref*. -- GitLab