From b886882b561b25ce4110ca448d692f7612c399dc Mon Sep 17 00:00:00 2001 From: Mit Kotak Date: Sat, 2 Jul 2022 13:09:01 -0500 Subject: [PATCH 1/2] added dtype default to gpuarray.zeros --- pycuda/gpuarray.py | 2 +- test/test_gpuarray.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/pycuda/gpuarray.py b/pycuda/gpuarray.py index a3dbc100..73fac727 100644 --- a/pycuda/gpuarray.py +++ b/pycuda/gpuarray.py @@ -1240,7 +1240,7 @@ def to_gpu_async(ary, allocator=drv.mem_alloc, stream=None): empty = GPUArray -def zeros(shape, dtype, allocator=drv.mem_alloc, order="C"): +def zeros(shape, dtype=np.float64, allocator=drv.mem_alloc, order="C"): """Returns an array of the given shape and dtype filled with 0's.""" result = GPUArray(shape, dtype, allocator, order=order) zero = np.zeros((), dtype) diff --git a/test/test_gpuarray.py b/test/test_gpuarray.py index 73ec3ade..025d5795 100644 --- a/test/test_gpuarray.py +++ b/test/test_gpuarray.py @@ -1297,6 +1297,13 @@ class TestGPUArray: assert new_z.dtype == np.complex64 assert new_z.shape == arr.shape + @mark_cuda_test + def test_default_zero(self): + # This test was added to make sure that + # GPUArray.zeros was reverting to dtype = np.float64 by default + a_gpu = gpuarray.zeros((50000,)) + assert a_gpu.dtype == np.float64 + if __name__ == "__main__": # make sure that import failures get reported, instead of skipping the tests. -- GitLab From 14252930a1c89633a7f4c7c7ce6ce86e499ada3c Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sat, 2 Jul 2022 23:23:17 +0000 Subject: [PATCH 2/2] Reduce test array size --- test/test_gpuarray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_gpuarray.py b/test/test_gpuarray.py index 025d5795..45531210 100644 --- a/test/test_gpuarray.py +++ b/test/test_gpuarray.py @@ -1301,7 +1301,7 @@ class TestGPUArray: def test_default_zero(self): # This test was added to make sure that # GPUArray.zeros was reverting to dtype = np.float64 by default - a_gpu = gpuarray.zeros((50000,)) + a_gpu = gpuarray.zeros(10) assert a_gpu.dtype == np.float64 -- GitLab