From 9329ba9cf26b5c2a729f0d827aa0c95fba00744d Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Wed, 12 Jun 2013 00:38:01 -0400 Subject: [PATCH] PEP8 test_algorithm --- test/test_algorithm.py | 56 +++++++++++++++++++++++++----------------- 1 file changed, 34 insertions(+), 22 deletions(-) diff --git a/test/test_algorithm.py b/test/test_algorithm.py index b2ed6c6f..ef28a015 100644 --- a/test/test_algorithm.py +++ b/test/test_algorithm.py @@ -32,15 +32,12 @@ from test_array import general_clrand import pytest import pyopencl as cl -import pyopencl.array as cl_array # noqa -from pyopencl.tools import pytest_generate_tests_for_pyopencl \ - as pytest_generate_tests +import pyopencl.array as cl_array # noqa +from pyopencl.tools import ( # noqa + pytest_generate_tests_for_pyopencl as pytest_generate_tests) from pyopencl.characterize import has_double_support - - - # {{{ elementwise @pytools.test.mark_test.opencl @@ -121,6 +118,7 @@ def test_ranged_elwise_kernel(ctx_factory): assert (a_cpu == a_gpu.get()).all() + @pytools.test.mark_test.opencl def test_take(ctx_factory): context = ctx_factory() @@ -155,6 +153,7 @@ def test_reverse(ctx_factory): assert (a[::-1] == a_gpu.get()).all() + @pytools.test.mark_test.opencl def test_if_positive(ctx_factory): context = ctx_factory() @@ -229,6 +228,7 @@ def test_astype(ctx_factory): # }}} + # {{{ reduction @pytools.test.mark_test.opencl @@ -350,6 +350,7 @@ def test_dot(ctx_factory): assert abs(vdot_ab_gpu - vdot_ab) / abs(vdot_ab) < 1e-4 + @memoize def make_mmc_dtype(device): dtype = np.dtype([ @@ -366,6 +367,7 @@ def make_mmc_dtype(device): return dtype, c_decl + @pytools.test.mark_test.opencl def test_struct_reduce(ctx_factory): pytest.importorskip("mako") @@ -430,6 +432,7 @@ def test_struct_reduce(ctx_factory): # }}} + # {{{ scan-related def summarize_error(obtained, desired, orig, thresh=1e-5): @@ -458,8 +461,8 @@ def summarize_error(obtained, desired, orig, thresh=1e-5): bad_count += 1 if bad_count < bad_limit: - entries.append("%r (want: %r, got: %r, orig: %r)" % (obtained[i], desired[i], - obtained[i], orig[i])) + entries.append("%r (want: %r, got: %r, orig: %r)" % ( + obtained[i], desired[i], obtained[i], orig[i])) else: if bad_count: summarize_counts() @@ -467,7 +470,6 @@ def summarize_error(obtained, desired, orig, thresh=1e-5): ok_count += 1 - summarize_counts() return " ".join(entries) @@ -491,6 +493,7 @@ scan_test_counts = [ # larger sizes cause out of memory on low-end AMD APUs ] + @pytools.test.mark_test.opencl def test_scan(ctx_factory): from pytest import importorskip @@ -512,7 +515,8 @@ def test_scan(ctx_factory): host_data = np.random.randint(0, 10, n).astype(dtype) dev_data = cl_array.to_device(queue, host_data) - assert (host_data == dev_data.get()).all() # /!\ fails on Nv GT2?? for some drivers + # /!\ fails on Nv GT2?? for some drivers + assert (host_data == dev_data.get()).all() knl(dev_data) @@ -530,6 +534,7 @@ def test_scan(ctx_factory): from gc import collect collect() + @pytools.test.mark_test.opencl def test_copy_if(ctx_factory): from pytest import importorskip @@ -546,13 +551,15 @@ def test_copy_if(ctx_factory): from pyopencl.algorithm import copy_if crit = a_dev.dtype.type(300) - selected = a[a>crit] - selected_dev, count_dev, evt = copy_if(a_dev, "ary[i] > myval", [("myval", crit)]) + selected = a[a > crit] + selected_dev, count_dev, evt = copy_if( + a_dev, "ary[i] > myval", [("myval", crit)]) assert (selected_dev.get()[:count_dev.get()] == selected).all() from gc import collect collect() + @pytools.test.mark_test.opencl def test_partition(ctx_factory): from pytest import importorskip @@ -569,17 +576,19 @@ def test_partition(ctx_factory): a = a_dev.get() crit = a_dev.dtype.type(300) - true_host = a[a>crit] - false_host = a[a<=crit] + true_host = a[a > crit] + false_host = a[a <= crit] from pyopencl.algorithm import partition - true_dev, false_dev, count_true_dev, evt = partition(a_dev, "ary[i] > myval", [("myval", crit)]) + true_dev, false_dev, count_true_dev, evt = partition( + a_dev, "ary[i] > myval", [("myval", crit)]) count_true_dev = count_true_dev.get() assert (true_dev.get()[:count_true_dev] == true_host).all() assert (false_dev.get()[:n-count_true_dev] == false_host).all() + @pytools.test.mark_test.opencl def test_unique(ctx_factory): from pytest import importorskip @@ -606,6 +615,7 @@ def test_unique(ctx_factory): from gc import collect collect() + @pytools.test.mark_test.opencl def test_index_preservation(ctx_factory): from pytest import importorskip @@ -639,6 +649,7 @@ def test_index_preservation(ctx_factory): from gc import collect collect() + @pytools.test.mark_test.opencl def test_segmented_scan(ctx_factory): from pytest import importorskip @@ -660,8 +671,8 @@ def test_segmented_scan(ctx_factory): from pyopencl.scan import GenericScanKernel knl = GenericScanKernel(context, dtype, - arguments="__global %s *ary, __global char *segflags, __global %s *out" - % (ctype, ctype), + arguments="__global %s *ary, __global char *segflags, " + "__global %s *out" % (ctype, ctype), input_expr="ary[i]", scan_expr="across_seg_boundary ? b : (a+b)", neutral="0", is_segment_start_expr="segflags[i]", @@ -685,7 +696,8 @@ def test_segmented_scan(ctx_factory): seg_boundaries_values = [] for i in range(10): seg_boundary_count = max(2, min(100, randrange(0, int(0.4*n)))) - seg_boundaries = [randrange(n) for i in range(seg_boundary_count)] + seg_boundaries = [ + randrange(n) for i in range(seg_boundary_count)] if n >= 1029: seg_boundaries.insert(0, 1028) seg_boundaries.sort() @@ -697,7 +709,8 @@ def test_segmented_scan(ctx_factory): seg_boundary_flags = np.zeros(n, dtype=np.uint8) seg_boundary_flags[seg_boundaries] = 1 - seg_boundary_flags_dev = cl_array.to_device(queue, seg_boundary_flags) + seg_boundary_flags_dev = cl_array.to_device( + queue, seg_boundary_flags) seg_boundaries.insert(0, 0) @@ -778,6 +791,7 @@ def test_sort(ctx_factory): 1e-6*n/dev_elapsed, 1e-6*n/numpy_elapsed, numpy_elapsed/dev_elapsed)) assert (a_dev_sorted.get() == a_sorted).all() + @pytools.test.mark_test.opencl def test_list_builder(ctx_factory): from pytest import importorskip @@ -804,6 +818,7 @@ def test_list_builder(ctx_factory): assert inf.count == 3000 assert (inf.lists.get()[-6:] == [1, 2, 2, 3, 3, 3]).all() + @pytools.test.mark_test.opencl def test_key_value_sorter(ctx_factory): from pytest import importorskip @@ -838,10 +853,7 @@ def test_key_value_sorter(ctx_factory): # }}} - - if __name__ == "__main__": - import sys if len(sys.argv) > 1: exec(sys.argv[1]) else: -- GitLab