diff --git a/pycuda/gpuarray.py b/pycuda/gpuarray.py index 420e5d7ad96e6eca54cdfeda25c1d21e21391146..e4bd10afbae6ecfb4f4c8a6ecf36d86ba5237236 100644 --- a/pycuda/gpuarray.py +++ b/pycuda/gpuarray.py @@ -385,6 +385,13 @@ class GPUArray: def __hash__(self): raise TypeError("GPUArrays are not hashable.") + def __bool__(self): + if self.size == 1: + return bool(self.get()) + else: + raise ValueError("The truth value of an array with " + "more than one element is ambiguous. Use a.any() or a.all()") + @property def ptr(self): return self.gpudata.__int__() @@ -732,7 +739,7 @@ class GPUArray: if len(self.shape): return self.shape[0] else: - return TypeError("scalar has no len()") + raise TypeError("len() of unsized object") def __abs__(self): """Return a `GPUArray` of the absolute values of the elements @@ -1114,7 +1121,10 @@ class GPUArray: ) def __setitem__(self, index, value): - _memcpy_discontig(self[index], value) + if np.isscalar(value): + self[index].fill(value) + else: + _memcpy_discontig(self[index], value) # }}} diff --git a/test/test_gpuarray.py b/test/test_gpuarray.py index 1d618e1ddde85c93f38be00092f536e8d96cb0dd..cce2de92f93bcc31d508b226655b06ca3a5b9cde 100644 --- a/test/test_gpuarray.py +++ b/test/test_gpuarray.py @@ -1366,6 +1366,20 @@ class TestGPUArray: gpuarray.logical_not(gpuarray.ones(10)).get(), np.logical_not(np.ones(10))) + def test_truth_value(self): + for i in range(5): + shape = (1,)*i + zeros = gpuarray.zeros(shape, dtype="float32") + ones = gpuarray.ones(shape, dtype="float32") + assert bool(ones) + assert not bool(zeros) + + def test_setitem_scalar(self): + a = gpuarray.zeros(5, "float64") + 42 + np.testing.assert_allclose(a.get(), 42) + a[...] = 1729 + np.testing.assert_allclose(a.get(), 1729) + if __name__ == "__main__": # make sure that import failures get reported, instead of skipping the tests.