From 8efe8657a3d9a7386d852d8f98c521b8ee2f5072 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Wed, 16 Mar 2016 19:11:54 -0500 Subject: [PATCH] More type system fixery regarding vectors --- loopy/expression.py | 3 ++- loopy/target/cuda.py | 3 ++- loopy/target/opencl.py | 3 ++- loopy/target/pyopencl.py | 3 ++- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/loopy/expression.py b/loopy/expression.py index a363013ea..40d1aab88 100644 --- a/loopy/expression.py +++ b/loopy/expression.py @@ -51,7 +51,8 @@ def dtype_to_type_context(target, dtype): if isinstance(dtype, NumpyType) and dtype.dtype in [np.float32, np.complex64]: return 'f' if target.is_vector_dtype(dtype): - return dtype_to_type_context(target, dtype.fields["x"][0]) + return dtype_to_type_context( + target, NumpyType(dtype.numpy_dtype.fields["x"][0])) return None diff --git a/loopy/target/cuda.py b/loopy/target/cuda.py index 2fa510b5f..89d090b7a 100644 --- a/loopy/target/cuda.py +++ b/loopy/target/cuda.py @@ -202,7 +202,8 @@ class CudaTarget(CTarget): return result def is_vector_dtype(self, dtype): - return dtype.numpy_dtype in list(vec.types.values()) + return (isinstance(dtype, NumpyType) + and dtype.numpy_dtype in list(vec.types.values())) def vector_dtype(self, base, count): return NumpyType(vec.types[base.numpy_dtype, count]) diff --git a/loopy/target/opencl.py b/loopy/target/opencl.py index db53b7cb2..1d60f4c8f 100644 --- a/loopy/target/opencl.py +++ b/loopy/target/opencl.py @@ -243,7 +243,8 @@ class OpenCLTarget(CTarget): return result def is_vector_dtype(self, dtype): - return dtype.numpy_dtype in list(vec.types.values()) + return (isinstance(dtype, NumpyType) + and dtype.numpy_dtype in list(vec.types.values())) def vector_dtype(self, base, count): return NumpyType(vec.types[base.numpy_dtype, count]) diff --git a/loopy/target/pyopencl.py b/loopy/target/pyopencl.py index e14a1dd9d..e17f9515a 100644 --- a/loopy/target/pyopencl.py +++ b/loopy/target/pyopencl.py @@ -296,7 +296,8 @@ class PyOpenCLTarget(OpenCLTarget): def is_vector_dtype(self, dtype): from pyopencl.array import vec - return dtype in list(vec.types.values()) + return (isinstance(dtype, NumpyType) + and dtype.numpy_dtype in list(vec.types.values())) def vector_dtype(self, base, count): from pyopencl.array import vec -- GitLab