Skip to content
Snippets Groups Projects
Commit 0384e073 authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Limit number of tested kernels in auto_test_vs_ref

parent 532b9202
No related branches found
No related tags found
No related merge requests found
......@@ -353,14 +353,15 @@ def auto_test_vs_ref(
ref_knl, ctx, test_knl, op_count=[], op_label=[], parameters={},
print_ref_code=False, print_code=True, warmup_rounds=2,
dump_binary=False,
fills_entire_output=None, do_check=True, check_result=None
):
fills_entire_output=None, do_check=True, check_result=None,
max_test_kernel_count=1):
"""Compare results of `ref_knl` to the kernels generated by
scheduling *test_knl*.
:arg check_result: a callable with :class:`numpy.ndarray` arguments
*(result, reference_result)* returning a a tuple (class:`bool`,
message) indicating correctness/acceptability of the result
:arg max_test_kernel_count: Stop testing after this many *test_knl*
"""
import pyopencl as cl
......@@ -488,28 +489,25 @@ def auto_test_vs_ref(
properties=cl.command_queue_properties.PROFILING_ENABLE)
args = None
from loopy.kernel import LoopKernel
if not isinstance(test_knl, LoopKernel):
warn("Passing an iterable of kernels to auto_test_vs_ref "
"is deprecated--just pass the kernel instead. "
"Scheduling will be performed in auto_test_vs_ref.",
DeprecationWarning, stacklevel=2)
test_kernels = test_knl
from loopy.kernel import kernel_state
if test_knl.state not in [
kernel_state.PREPROCESSED,
kernel_state.SCHEDULED]:
test_knl = lp.preprocess_kernel(test_knl)
if not test_knl.schedule:
test_kernels = lp.generate_loop_schedules(test_knl)
else:
from loopy.kernel import kernel_state
if test_knl.state not in [
kernel_state.PREPROCESSED,
kernel_state.SCHEDULED]:
test_knl = lp.preprocess_kernel(test_knl)
if not test_knl.schedule:
test_kernels = lp.generate_loop_schedules(test_knl)
else:
test_kernels = [test_knl]
test_kernels = [test_knl]
test_kernel_count = 0
from loopy.preprocess import infer_unknown_types
for i, kernel in enumerate(test_kernels):
test_kernel_count += 1
if test_kernel_count > max_test_kernel_count:
break
kernel = infer_unknown_types(kernel, expect_completion=True)
compiled = CompiledKernel(ctx, kernel)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment