diff --git a/doc/source/array.rst b/doc/source/array.rst index 56ab19951417432519c8bd3b990e08807b726d58..7e70dd824c6e0fdbcfdc860e04ae7d32f1597a24 100644 --- a/doc/source/array.rst +++ b/doc/source/array.rst @@ -358,7 +358,7 @@ Generating Arrays of Random Numbers .. module:: pyopencl.clrandom -.. function:: rand(context, queue, shape, dtype) +.. function:: rand(queue, shape, dtype) Return an array of `shape` filled with random values of `dtype` in the range [0,1). diff --git a/pyopencl/clrandom.py b/pyopencl/clrandom.py index c2cea99cfe41dff89009d538a17fdafbba9954ce..2b13047eef792e23bd023d238f74aa03b604c41f 100644 --- a/pyopencl/clrandom.py +++ b/pyopencl/clrandom.py @@ -241,13 +241,23 @@ def _rand(output, seed): def fill_rand(result): _rand(result, np.random.randint(2**31-1)) -def rand(context, queue, shape, dtype): - from pyopencl.array import Array - - result = Array(queue, shape, dtype) - _rand(result, np.random.randint(2**31-1)) - return result - +def rand(*args, **kwargs): + def inner_rand(queue, shape, dtype): + from pyopencl.array import Array + + result = Array(queue, shape, dtype) + _rand(result, np.random.randint(2**31-1)) + return result + + if isinstance(args[0], cl.Context): + from warnings import warn + warn("Passing a context as first argument is deprecated. " + "This will be continue to be accepted througout " + "versions 2011.x of PyOpenCL.", + DeprecationWarning, 2) + args = args[1:] + + return inner_rand(*args, **kwargs) if __name__ == "__main__": import sys