diff --git a/pyopencl/elementwise.py b/pyopencl/elementwise.py index 60cc88566db5f06f77971c58489b1d4be99a505d..9a6d794eeaeffde8a60f7a44aa4ee3f80cca9a79 100644 --- a/pyopencl/elementwise.py +++ b/pyopencl/elementwise.py @@ -36,7 +36,7 @@ import numpy as np import pyopencl as cl from pytools import memoize_method from pyopencl.tools import (dtype_to_ctype, VectorArg, ScalarArg, - KernelTemplateBase) + KernelTemplateBase, dtype_to_c_struct) # {{{ elementwise kernel code generator @@ -740,6 +740,7 @@ def get_fill_kernel(context, dtype): "tp": dtype_to_ctype(dtype), }, "z[i] = a", + preamble=dtype_to_c_struct(context.devices[0], dtype), name="fill") diff --git a/pyopencl/tools.py b/pyopencl/tools.py index 2e24890bc476c25ad580e966ab17b88687addd5c..528d45cdf0f95b64ff3fface83b8e633a2ec51ce 100644 --- a/pyopencl/tools.py +++ b/pyopencl/tools.py @@ -646,6 +646,9 @@ def match_dtype_to_c_struct(device, name, dtype, context=None): @memoize def dtype_to_c_struct(device, dtype): + if dtype.fields is None: + return "" + matched_dtype, c_decl = match_dtype_to_c_struct( device, dtype_to_ctype(dtype), dtype) diff --git a/test/test_array.py b/test/test_array.py index 1876b081f4977d9935076ab49c6a69a9f345e65b..71bd1a9045d8f214f731d6a4a1462afd389a4edb 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -260,6 +260,35 @@ def test_custom_type_zeros(ctx_factory): assert np.array_equal(np.zeros(n, dtype), z) + +def test_custom_type_fill(ctx_factory): + context = ctx_factory() + queue = cl.CommandQueue(context) + + from pyopencl.characterize import has_struct_arg_count_bug + if has_struct_arg_count_bug(queue.device): + pytest.skip("device has LLVM arg counting bug") + + dtype = np.dtype([ + ("cur_min", np.int32), + ("cur_max", np.int32), + ("pad", np.int32), + ]) + + from pyopencl.tools import get_or_register_dtype, match_dtype_to_c_struct + + name = "mmc_type" + dtype, c_decl = match_dtype_to_c_struct(queue.device, name, dtype) + dtype = get_or_register_dtype(name, dtype) + + n = 1000 + z_dev = cl.array.empty(queue, n, dtype=dtype) + z_dev.fill(np.zeros((), dtype)) + + z = z_dev.get() + + assert np.array_equal(np.zeros(n, dtype), z) + # }}}