diff --git a/pyopencl/array.py b/pyopencl/array.py index 0e20645c438a98b07739c7c84c23a2807b285471..d3ef426fb1b8410167dd8a6240a6060d4a9161b5 100644 --- a/pyopencl/array.py +++ b/pyopencl/array.py @@ -161,7 +161,7 @@ def elwise_kernel_runner(kernel_getter): else: wait_for = list(wait_for) - knl = kernel_getter(*args) + knl = kernel_getter(*args, **kwargs) gs, ls = repr_ary.get_sizes(queue, knl.get_work_group_info( @@ -352,6 +352,26 @@ class Array(object): .. automethod :: setitem + .. rubric:: Comparisons, conditionals, any, all + + .. versionadded:: 2013.2 + + Boolean arrays are stored as :class:`numpy.int8` because ``bool`` + has an unspecified size in the OpenCL spec. + + .. automethod :: __bool__ + + Only works for device scalars. (i.e. "arrays" with ``shape == ()``.) + + .. automethod :: any + .. automethod :: all + + .. automethod :: __eq__ + .. automethod :: __ne__ + .. automethod :: __lt__ + .. automethod :: __le__ + .. automethod :: __gt__ + .. automethod :: __ge__ """ __array_priority__ = 100 @@ -985,25 +1005,98 @@ class Array(object): self._copy(result, self, queue=queue) return result - # {{{ rich comparisons (or rather, lack thereof) + # {{{ rich comparisons, any, all + + def __nonzero__(self): + if self.shape == (): + return bool(self.get()) + else: + raise ValueError("The truth value of an array with " + "more than one element is ambiguous. Use a.any() or a.all()") + + def any(self, queue=None, wait_for=None): + from pyopencl.reduction import get_any_kernel + krnl = get_any_kernel(self.context, self.dtype) + return krnl(self, queue=queue, wait_for=wait_for) + + def all(self, queue=None, wait_for=None): + from pyopencl.reduction import get_all_kernel + krnl = get_all_kernel(self.context, self.dtype) + return krnl(self, queue=queue, wait_for=wait_for) + + @staticmethod + @elwise_kernel_runner + def _scalar_comparison(out, a, b, queue=None, op=None): + return elementwise.get_array_scalar_comparison_kernel( + out.context, op, a.dtype) + + @staticmethod + @elwise_kernel_runner + def _array_comparison(out, a, b, queue=None, op=None): + if a.shape != b.shape: + raise ValueError("shapes of comparison arguments do not match") + return elementwise.get_array_comparison_kernel( + out.context, op, a.dtype, b.dtype) def __eq__(self, other): - raise NotImplementedError + if isinstance(other, Array): + result = self._new_like_me(np.int8) + self._array_comparison(result, self, other, op="==") + return result + else: + result = self._new_like_me(np.int8) + self._scalar_comparison(result, self, other, op="==") + return result def __ne__(self, other): - raise NotImplementedError + if isinstance(other, Array): + result = self._new_like_me(np.int8) + self._array_comparison(result, self, other, op="!=") + return result + else: + result = self._new_like_me(np.int8) + self._scalar_comparison(result, self, other, op="!=") + return result def __le__(self, other): - raise NotImplementedError + if isinstance(other, Array): + result = self._new_like_me(np.int8) + self._array_comparison(result, self, other, op="<=") + return result + else: + result = self._new_like_me(np.int8) + self._scalar_comparison(result, self, other, op="<=") + return result def __ge__(self, other): - raise NotImplementedError + if isinstance(other, Array): + result = self._new_like_me(np.int8) + self._array_comparison(result, self, other, op=">=") + return result + else: + result = self._new_like_me(np.int8) + self._scalar_comparison(result, self, other, op=">=") + return result def __lt__(self, other): - raise NotImplementedError + if isinstance(other, Array): + result = self._new_like_me(np.int8) + self._array_comparison(result, self, other, op="<") + return result + else: + result = self._new_like_me(np.int8) + self._scalar_comparison(result, self, other, op="<") + return result def __gt__(self, other): - raise NotImplementedError + if isinstance(other, Array): + result = self._new_like_me(np.int8) + self._array_comparison(result, self, other, op=">") + return result + else: + result = self._new_like_me(np.int8) + self._scalar_comparison(result, self, other, op=">") + return result # }}} diff --git a/pyopencl/elementwise.py b/pyopencl/elementwise.py index 3e63f4e0b1fa46b3b6ef6687b4c5adf76621efde..8bcd174dd5a7741fd043ad46cacb2c46a55a2bfb 100644 --- a/pyopencl/elementwise.py +++ b/pyopencl/elementwise.py @@ -786,6 +786,28 @@ def get_pow_kernel(context, dtype_x, dtype_y, dtype_z, name="pow_method") +@context_dependent_memoize +def get_array_scalar_comparison_kernel(context, operator, dtype_a): + return get_elwise_kernel(context, [ + VectorArg(np.int8, "out", with_offset=True), + VectorArg(dtype_a, "a", with_offset=True), + ScalarArg(dtype_a, "b"), + ], + "out[i] = a[i] %s b" % operator, + name="scalar_comparison_kernel") + + +@context_dependent_memoize +def get_array_comparison_kernel(context, operator, dtype_a, dtype_b): + return get_elwise_kernel(context, [ + VectorArg(np.int8, "out", with_offset=True), + VectorArg(dtype_a, "a", with_offset=True), + VectorArg(dtype_b, "b", with_offset=True), + ], + "out[i] = a[i] %s b[i]" % operator, + name="comparison_kernel") + + @context_dependent_memoize def get_fmod_kernel(context): return get_elwise_kernel(context, diff --git a/pyopencl/reduction.py b/pyopencl/reduction.py index d09e612ea19a83376ff5253d66ebfacdc9d35fa6..ac07fbd4e3d324e66132b5e1dc8f98e9bf36cce3 100644 --- a/pyopencl/reduction.py +++ b/pyopencl/reduction.py @@ -35,7 +35,7 @@ import pyopencl as cl from pyopencl.tools import ( context_dependent_memoize, dtype_to_ctype, KernelTemplateBase, - _process_code_for_macro) + _process_code_for_macro, VectorArg) import numpy as np @@ -229,8 +229,7 @@ def get_reduction_kernel(stage, map_expr = "in[i]" from pyopencl.tools import ( - parse_arg_list, get_arg_list_scalar_arg_dtypes, - VectorArg) + parse_arg_list, get_arg_list_scalar_arg_dtypes) if arguments is not None: arguments = parse_arg_list(arguments) @@ -299,7 +298,6 @@ class ReductionKernel: max_group_size=max_group_size) from pytools import any - from pyopencl.tools import VectorArg assert any( isinstance(arg_tp, VectorArg) for arg_tp in self.stage_1_inf.arg_types), \ @@ -327,7 +325,6 @@ class ReductionKernel: invocation_args = [] vectors = [] - from pyopencl.tools import VectorArg for arg, arg_tp in zip(args, stage_inf.arg_types): if isinstance(arg_tp, VectorArg): if not arg.flags.forc: @@ -429,6 +426,20 @@ class ReductionTemplate(KernelTemplateBase): # {{{ array reduction kernel getters +@context_dependent_memoize +def get_any_kernel(ctx, dtype_in): + return ReductionKernel(ctx, np.int8, "false", "a || b", + map_expr="(bool) (in[i])", + arguments=[VectorArg(dtype_in, "in")]) + + +@context_dependent_memoize +def get_all_kernel(ctx, dtype_in): + return ReductionKernel(ctx, np.int8, "true", "a && b", + map_expr="(bool) (in[i])", + arguments=[VectorArg(dtype_in, "in")]) + + @context_dependent_memoize def get_sum_kernel(ctx, dtype_out, dtype_in): if dtype_out is None: diff --git a/test/test_array.py b/test/test_array.py index bbc63f3a36f8c90a1ee96bc525bec5c45dda80fb..3ff5030b9fb405729886bf77d2eca116e3415a6a 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -555,6 +555,8 @@ def test_view(ctx_factory): # }}} +# {{{ slices, concatenation + @pytools.test.mark_test.opencl def test_slice(ctx_factory): context = ctx_factory() @@ -616,6 +618,66 @@ def test_concatenate(ctx_factory): assert la.norm(cat - cat_dev.get()) == 0 +# }}} + + +# {{{ conditionals, any, all + +@pytools.test.mark_test.opencl +def test_comparisons(ctx_factory): + context = ctx_factory() + queue = cl.CommandQueue(context) + + from pyopencl.clrandom import rand as clrand + + l = 20000 + a_dev = clrand(queue, (l,), dtype=np.float32) + b_dev = clrand(queue, (l,), dtype=np.float32) + + a = a_dev.get() + b = b_dev.get() + + import operator as o + for op in [o.eq, o.ne, o.le, o.lt, o.ge, o.gt]: + res_dev = op(a_dev, b_dev) + res = op(a, b) + + assert (res_dev.get() == res).all() + + res_dev = op(a_dev, 0) + res = op(a, 0) + + assert (res_dev.get() == res).all() + + res_dev = op(0, b_dev) + res = op(0, b) + + assert (res_dev.get() == res).all() + + +@pytools.test.mark_test.opencl +def test_any_all(ctx_factory): + context = ctx_factory() + queue = cl.CommandQueue(context) + + l = 20000 + a_dev = cl_array.zeros(queue, (l,), dtype=np.int8) + + assert not a_dev.all() + assert not a_dev.any() + + a_dev[15213] = 1 + + assert not a_dev.all() + assert a_dev.any() + + a_dev.fill(1) + + assert a_dev.all() + assert a_dev.any() + +# }}} + if __name__ == "__main__": # make sure that import failures get reported, instead of skipping the