diff --git a/pyopencl/array.py b/pyopencl/array.py index 279705c25fffb3447b7757069dfe5a847bacb34a..f40cfbbcd65d27f776f1fd60e2c37d79d431f516 100644 --- a/pyopencl/array.py +++ b/pyopencl/array.py @@ -28,7 +28,7 @@ OTHER DEALINGS IN THE SOFTWARE. """ import six -from six.moves import range, zip, reduce +from six.moves import range, reduce import numpy as np import pyopencl.elementwise as elementwise @@ -44,9 +44,10 @@ from pyopencl.compyte.array import ( from pyopencl.characterize import has_double_support from pyopencl import cltypes + def _get_common_dtype(obj1, obj2, queue): return _get_common_dtype_base(obj1, obj2, - has_double_support(queue.device)) + has_double_support(queue.device)) # Work around PyPy not currently supporting the object dtype. @@ -70,10 +71,12 @@ class VecLookupWarner(object): DeprecationWarning, 2) return getattr(cltypes, name) + vec = VecLookupWarner() # {{{ helper functionality + def splay(queue, n, kernel_specific_max_wg_size=None): dev = queue.device max_work_items = _builtin_min(128, dev.max_work_group_size) diff --git a/pyopencl/cltypes.py b/pyopencl/cltypes.py index f15ab9d7c39da600ce7b2a811fb55b8538ffebba..c8ff35c378bd1eb395e54dc0efa5ce6a21ff9b85 100644 --- a/pyopencl/cltypes.py +++ b/pyopencl/cltypes.py @@ -43,14 +43,18 @@ half = np.float16 float = np.float32 double = np.float64 + # {{{ vector types def _create_vector_types(): - _mapping = [(k, globals()[k]) for k in ['char', 'uchar', 'short', 'ushort', 'int', - 'uint', 'long', 'ulong', 'float', 'double']] + _mapping = [(k, globals()[k]) for k in + ['char', 'uchar', 'short', 'ushort', 'int', + 'uint', 'long', 'ulong', 'float', 'double']] + def set_global(key, val): globals()[key] = val + field_names = ["x", "y", "z", "w"] set_global('types', {}) @@ -58,7 +62,6 @@ def _create_vector_types(): counts = [2, 3, 4, 8, 16] - for base_name, base_type in _mapping: for count in counts: name = "%s%d" % (base_name, count) @@ -71,15 +74,15 @@ def _create_vector_types(): names = ["s%d" % i for i in range(count)] while len(names) < padded_count: - names.append("padding%d" % (len(names)-count)) + names.append("padding%d" % (len(names) - count)) if len(titles) < len(names): - titles.extend((len(names)-len(titles))*[None]) + titles.extend((len(names) - len(titles)) * [None]) try: dtype = np.dtype(dict( names=names, - formats=[base_type]*padded_count, + formats=[base_type] * padded_count, titles=titles)) except NotImplementedError: try: @@ -97,31 +100,28 @@ def _create_vector_types(): if len(args) < count: from warnings import warn warn("default values for make_xxx are deprecated;" - " instead specify all parameters or use" - " cltypes.zeros_xxx", DeprecationWarning) - padded_args = tuple(list(args)+[0]*(padded_count-len(args))) + " instead specify all parameters or use" + " cltypes.zeros_xxx", DeprecationWarning) + padded_args = tuple(list(args) + [0] * (padded_count - len(args))) array = eval("array(padded_args, dtype=dtype)", - dict(array=np.array, padded_args=padded_args, - dtype=dtype)) + dict(array=np.array, padded_args=padded_args, + dtype=dtype)) for key, val in list(kwargs.items()): array[key] = val return array - set_global("make_"+name, eval( - "lambda *args, **kwargs: create_array(dtype, %i, %i, " - "*args, **kwargs)" % (count, padded_count), - dict(create_array=create_array, dtype=dtype))) - set_global("filled_"+name, eval( - "lambda val: make_%s(*[val]*%i)" % (name, count))) - set_global("zeros_"+name, eval("lambda: filled_%s(0)" % (name))) - set_global("ones_"+name, eval("lambda: filled_%s(1)" % (name))) + set_global("make_" + name, eval( + "lambda *args, **kwargs: create_array(dtype, %i, %i, " + "*args, **kwargs)" % (count, padded_count), + dict(create_array=create_array, dtype=dtype))) + set_global("filled_" + name, eval( + "lambda val: make_%s(*[val]*%i)" % (name, count))) + set_global("zeros_" + name, eval("lambda: filled_%s(0)" % (name))) + set_global("ones_" + name, eval("lambda: filled_%s(1)" % (name))) globals()['types'][np.dtype(base_type), count] = dtype globals()['type_to_scalar_and_count'][dtype] = np.dtype(base_type), count + _create_vector_types() # }}} - - - -