From e499c28b75c95e3910aaca1a8ab14fac00bd534d Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Wed, 17 Aug 2016 15:56:55 -0500 Subject: [PATCH] Add test balloon for custom type array kernels --- pyopencl/elementwise.py | 3 ++- pyopencl/tools.py | 3 +++ test/test_array.py | 29 +++++++++++++++++++++++++++++ 3 files changed, 34 insertions(+), 1 deletion(-) diff --git a/pyopencl/elementwise.py b/pyopencl/elementwise.py index 60cc8856..9a6d794e 100644 --- a/pyopencl/elementwise.py +++ b/pyopencl/elementwise.py @@ -36,7 +36,7 @@ import numpy as np import pyopencl as cl from pytools import memoize_method from pyopencl.tools import (dtype_to_ctype, VectorArg, ScalarArg, - KernelTemplateBase) + KernelTemplateBase, dtype_to_c_struct) # {{{ elementwise kernel code generator @@ -740,6 +740,7 @@ def get_fill_kernel(context, dtype): "tp": dtype_to_ctype(dtype), }, "z[i] = a", + preamble=dtype_to_c_struct(context.devices[0], dtype), name="fill") diff --git a/pyopencl/tools.py b/pyopencl/tools.py index 2e24890b..528d45cd 100644 --- a/pyopencl/tools.py +++ b/pyopencl/tools.py @@ -646,6 +646,9 @@ def match_dtype_to_c_struct(device, name, dtype, context=None): @memoize def dtype_to_c_struct(device, dtype): + if dtype.fields is None: + return "" + matched_dtype, c_decl = match_dtype_to_c_struct( device, dtype_to_ctype(dtype), dtype) diff --git a/test/test_array.py b/test/test_array.py index 1876b081..71bd1a90 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -260,6 +260,35 @@ def test_custom_type_zeros(ctx_factory): assert np.array_equal(np.zeros(n, dtype), z) + +def test_custom_type_fill(ctx_factory): + context = ctx_factory() + queue = cl.CommandQueue(context) + + from pyopencl.characterize import has_struct_arg_count_bug + if has_struct_arg_count_bug(queue.device): + pytest.skip("device has LLVM arg counting bug") + + dtype = np.dtype([ + ("cur_min", np.int32), + ("cur_max", np.int32), + ("pad", np.int32), + ]) + + from pyopencl.tools import get_or_register_dtype, match_dtype_to_c_struct + + name = "mmc_type" + dtype, c_decl = match_dtype_to_c_struct(queue.device, name, dtype) + dtype = get_or_register_dtype(name, dtype) + + n = 1000 + z_dev = cl.array.empty(queue, n, dtype=dtype) + z_dev.fill(np.zeros((), dtype)) + + z = z_dev.get() + + assert np.array_equal(np.zeros(n, dtype), z) + # }}} -- GitLab