From 0b7362ae74d37e3983c57e82c082caf83d4b3204 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Thu, 29 Mar 2012 01:36:16 -0400
Subject: [PATCH] Fixes to auto-test.

---
 loopy/compiled.py | 39 ++++++++++++++++++++++++++++-----------
 1 file changed, 28 insertions(+), 11 deletions(-)

diff --git a/loopy/compiled.py b/loopy/compiled.py
index cc78d2b24..92997594a 100644
--- a/loopy/compiled.py
+++ b/loopy/compiled.py
@@ -11,10 +11,10 @@ import numpy as np
 
 class CompiledKernel:
     def __init__(self, context, kernel, size_args=None, options=[],
-             edit_code=False, with_annotation=False):
+             edit_code=False, codegen_kwargs={}):
         self.kernel = kernel
         from loopy.codegen import generate_code
-        self.code = generate_code(kernel, with_annotation=with_annotation)
+        self.code = generate_code(kernel, **codegen_kwargs)
 
         if edit_code:
             from pytools import invoke_editor
@@ -283,9 +283,9 @@ def _default_check_result(result, ref_result):
 
 
 
-def auto_test_vs_ref(ref_knl, ctx, kernel_gen, op_count, op_label, parameters,
+def auto_test_vs_ref(ref_knl, ctx, kernel_gen, op_count=[], op_label=[], parameters={},
         print_ref_code=False, print_code=True, warmup_rounds=2,
-        edit_code=False, dump_binary=False, with_annotation=False,
+        edit_code=False, dump_binary=False, codegen_kwargs={},
         fills_entire_output=True, check_result=None):
     """Compare results of `ref_knl` to the kernels generated by the generator
     `kernel_gen`.
@@ -294,6 +294,16 @@ def auto_test_vs_ref(ref_knl, ctx, kernel_gen, op_count, op_label, parameters,
         *(result, reference_result)* returning a a tuple (class:`bool`, message)
         indicating correctness/acceptability of the result
     """
+
+    if isinstance(op_count, (int, float)):
+        from warnings import warn
+        warn("op_count should be a list", stacklevel=2)
+        op_count = [op_count]
+    if isinstance(op_label, str):
+        from warnings import warn
+        warn("op_label should be a list", stacklevel=2)
+        op_label = [op_label]
+
     from time import time
 
     if check_result is None:
@@ -340,7 +350,7 @@ def auto_test_vs_ref(ref_knl, ctx, kernel_gen, op_count, op_label, parameters,
         break
 
     ref_compiled = CompiledKernel(ref_ctx, ref_sched_kernel,
-            with_annotation=with_annotation)
+            codegen_kwargs=codegen_kwargs)
     if print_ref_code:
         print 75*"-"
         print "Reference Code:"
@@ -385,7 +395,7 @@ def auto_test_vs_ref(ref_knl, ctx, kernel_gen, op_count, op_label, parameters,
                     fill_value=fill_value)
 
         compiled = CompiledKernel(ctx, kernel, edit_code=edit_code,
-                with_annotation=with_annotation)
+                codegen_kwargs=codegen_kwargs)
 
         print 75*"-"
         print "Kernel #%d:" % i
@@ -452,11 +462,18 @@ def auto_test_vs_ref(ref_knl, ctx, kernel_gen, op_count, op_label, parameters,
             else:
                 break
 
-        print "elapsed: %g s event, %s s other-event %g s wall, rate: %g %s/s (%d rounds)" % (
-                elapsed, elapsed_evt_2, elapsed_wall, op_count/elapsed, op_label,
-                timing_rounds)
-        print "ref: elapsed: %g s event, %g s wall, rate: %g %s/s" % (
-                ref_elapsed, ref_elapsed_wall, op_count/ref_elapsed, op_label)
+        rates = ""
+        for cnt, lbl in zip(op_count, op_label):
+            rates += " %g %s/s" % (cnt/elapsed, lbl)
+
+        print "elapsed: %g s event, %s s other-event %g s wall (%d rounds)%s" % (
+                elapsed, elapsed_evt_2, elapsed_wall, timing_rounds, rates)
+
+        ref_rates = ""
+        for cnt, lbl in zip(op_count, op_label):
+            ref_rates += " %g %s/s" % (cnt/ref_elapsed, lbl)
+        print "ref: elapsed: %g s event, %g s wall%s" % (
+                ref_elapsed, ref_elapsed_wall, ref_rates)
 
     # }}}
 
-- 
GitLab