diff --git a/pyopencl/array.py b/pyopencl/array.py index 71ce3d8c2ee82d868315afc2fc4417d139f3f022..8d29b6c1d74db82d4230cb3f513a41e820bad7a8 100644 --- a/pyopencl/array.py +++ b/pyopencl/array.py @@ -1296,8 +1296,14 @@ class Array(object): if isinstance(shape[0], tuple) or isinstance(shape[0], list): shape = tuple(shape[0]) - if any(s < 0 for s in shape): - raise NotImplementedError("negative/automatic shapes not supported") + if -1 in shape: + shape = list(shape) + idx = shape.index(-1) + size = -reduce(lambda x, y: x * y, shape, 1) + shape[idx] = self.size // size + if any(s < 0 for s in shape): + raise ValueError("can only specify one unknown dimension") + shape = tuple(shape) if shape == self.shape: return self._new_with_changes( diff --git a/test/test_array.py b/test/test_array.py index 42d5f74e7ffcb1f3298dc048500ad1518497b4b7..acb82ec8876b0b96867379f9c129a24c3d15ccb9 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -776,6 +776,28 @@ def test_event_management(ctx_factory): assert len(x.events) < 100 +def test_reshape(ctx_factory): + context = ctx_factory() + queue = cl.CommandQueue(context) + + a = np.arange(128).reshape(8, 16).astype(np.float32) + a_dev = cl_array.to_device(queue, a) + + # different ways to specify the shape + a_dev.reshape(4, 32) + a_dev.reshape((4, 32)) + a_dev.reshape([4, 32]) + + # using -1 as unknown dimension + assert a_dev.reshape(-1, 32).shape == (4, 32) + assert a_dev.reshape((32, -1)).shape == (32, 4) + assert a_dev.reshape(((8, -1, 4))).shape == (8, 4, 4) + + import pytest + with pytest.raises(ValueError): + a_dev.reshape(-1, -1, 4) + + if __name__ == "__main__": # make sure that import failures get reported, instead of skipping the # tests.