From 3a61e3688785b9e30119c9163dc27d8382a0a0d0 Mon Sep 17 00:00:00 2001 From: Jonathan Mackenzie Date: Thu, 15 Dec 2016 14:24:26 +1030 Subject: [PATCH] fixed some bugs and cleaned up the code --- pyopencl/cltypes.py | 44 ++++++++++++++++++++++++-------------------- 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/pyopencl/cltypes.py b/pyopencl/cltypes.py index fc83e730..e0d19e57 100644 --- a/pyopencl/cltypes.py +++ b/pyopencl/cltypes.py @@ -20,7 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -import numpy as __np +import numpy as np from pyopencl.tools import get_or_register_dtype import warnings @@ -31,21 +31,24 @@ if __file__.endswith('array.py'): This file provides a type mapping from OpenCl type names to their numpy equivalents """ -char = __np.int8 -uchar = __np.uint8 -short = __np.int16 -ushort = __np.uint16 -int = __np.int32 -uint = __np.uint32 -long = __np.int64 -ulong = __np.uint64 -float = __np.float32 -double = __np.float64 +char = np.int8 +uchar = np.uint8 +short = np.int16 +ushort = np.uint16 +int = np.int32 +uint = np.uint32 +long = np.int64 +ulong = np.uint64 +half = np.float16 +float = np.float32 +double = np.float64 # {{{ vector types def _create_vector_types(): + _mapping = [(k, globals()[k]) for k in ['char', 'uchar', 'short', 'ushort', 'int', + 'uint', 'long', 'ulong', 'float', 'double']] def set_global(key, val): globals()[key] = val field_names = ["x", "y", "z", "w"] @@ -55,7 +58,8 @@ def _create_vector_types(): counts = [2, 3, 4, 8, 16] - for base_name, base_type in [(k, v) for k, v in globals().items() if not k.startswith('__')]: + + for base_name, base_type in _mapping: for count in counts: name = "%s%d" % (base_name, count) @@ -73,16 +77,16 @@ def _create_vector_types(): titles.extend((len(names)-len(titles))*[None]) try: - dtype = __np.dtype(dict( + dtype = np.dtype(dict( names=names, formats=[base_type]*padded_count, titles=titles)) except NotImplementedError: try: - dtype = __np.dtype([((n, title), base_type) + dtype = np.dtype([((n, title), base_type) for (n, title) in zip(names, titles)]) except TypeError: - dtype = __np.dtype([(n, base_type) for (n, title) + dtype = np.dtype([(n, base_type) for (n, title) in zip(names, titles)]) get_or_register_dtype(name, dtype) @@ -97,7 +101,7 @@ def _create_vector_types(): " array.vec.zeros_xxx", DeprecationWarning) padded_args = tuple(list(args)+[0]*(padded_count-len(args))) array = eval("array(padded_args, dtype=dtype)", - dict(array=__np.array, padded_args=padded_args, + dict(array=np.array, padded_args=padded_args, dtype=dtype)) for key, val in list(kwargs.items()): array[key] = val @@ -114,12 +118,12 @@ def _create_vector_types(): set_global("ones_"+name, staticmethod(eval("lambda: vec.filled_%s(1)" % (name)))) - globals()['types'][__np.dtype(base_type), count] = dtype - globals()['type_to_scalar_and_count'][dtype] = __np.dtype(base_type), count - + globals()['types'][np.dtype(base_type), count] = dtype + globals()['type_to_scalar_and_count'][dtype] = np.dtype(base_type), count +_create_vector_types() # }}} -half = __np.float16 + -- GitLab