From 8efe8657a3d9a7386d852d8f98c521b8ee2f5072 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 16 Mar 2016 19:11:54 -0500
Subject: [PATCH] More type system fixery regarding vectors

---
 loopy/expression.py      | 3 ++-
 loopy/target/cuda.py     | 3 ++-
 loopy/target/opencl.py   | 3 ++-
 loopy/target/pyopencl.py | 3 ++-
 4 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/loopy/expression.py b/loopy/expression.py
index a363013ea..40d1aab88 100644
--- a/loopy/expression.py
+++ b/loopy/expression.py
@@ -51,7 +51,8 @@ def dtype_to_type_context(target, dtype):
     if isinstance(dtype, NumpyType) and dtype.dtype in [np.float32, np.complex64]:
         return 'f'
     if target.is_vector_dtype(dtype):
-        return dtype_to_type_context(target, dtype.fields["x"][0])
+        return dtype_to_type_context(
+                target, NumpyType(dtype.numpy_dtype.fields["x"][0]))
 
     return None
 
diff --git a/loopy/target/cuda.py b/loopy/target/cuda.py
index 2fa510b5f..89d090b7a 100644
--- a/loopy/target/cuda.py
+++ b/loopy/target/cuda.py
@@ -202,7 +202,8 @@ class CudaTarget(CTarget):
         return result
 
     def is_vector_dtype(self, dtype):
-        return dtype.numpy_dtype in list(vec.types.values())
+        return (isinstance(dtype, NumpyType)
+                and dtype.numpy_dtype in list(vec.types.values()))
 
     def vector_dtype(self, base, count):
         return NumpyType(vec.types[base.numpy_dtype, count])
diff --git a/loopy/target/opencl.py b/loopy/target/opencl.py
index db53b7cb2..1d60f4c8f 100644
--- a/loopy/target/opencl.py
+++ b/loopy/target/opencl.py
@@ -243,7 +243,8 @@ class OpenCLTarget(CTarget):
         return result
 
     def is_vector_dtype(self, dtype):
-        return dtype.numpy_dtype in list(vec.types.values())
+        return (isinstance(dtype, NumpyType)
+                and dtype.numpy_dtype in list(vec.types.values()))
 
     def vector_dtype(self, base, count):
         return NumpyType(vec.types[base.numpy_dtype, count])
diff --git a/loopy/target/pyopencl.py b/loopy/target/pyopencl.py
index e14a1dd9d..e17f9515a 100644
--- a/loopy/target/pyopencl.py
+++ b/loopy/target/pyopencl.py
@@ -296,7 +296,8 @@ class PyOpenCLTarget(OpenCLTarget):
 
     def is_vector_dtype(self, dtype):
         from pyopencl.array import vec
-        return dtype in list(vec.types.values())
+        return (isinstance(dtype, NumpyType)
+                and dtype.numpy_dtype in list(vec.types.values()))
 
     def vector_dtype(self, base, count):
         from pyopencl.array import vec
-- 
GitLab