Skip to content
Snippets Groups Projects
Commit bfbc8ee3 authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Add expression fuzzing test, to check code generation.

parent 5e7738b1
No related branches found
No related tags found
No related merge requests found
......@@ -180,6 +180,103 @@ def test_argmax(ctx_factory):
def make_random_value():
from random import randrange, uniform
v = randrange(3)
if v == 0:
while True:
z = randrange(-1000, 1000)
if z:
return z
elif v == 1:
return uniform(-10, 10)
else:
return uniform(-10, 10) + 1j*uniform(-10, 10)
def make_random_expression(var_values, size):
from random import randrange
import pymbolic.primitives as p
v = randrange(1500)
size[0] += 1
if v < 500 and size[0] < 40:
term_count = randrange(2, 5)
if randrange(2) < 1:
cls = p.Sum
else:
cls = p.Product
return cls(tuple(
make_random_expression(var_values, size)
for i in range(term_count)))
elif v < 750:
return make_random_value()
elif v < 1000:
var_name = "var_%d" % len(var_values)
assert var_name not in var_values
var_values[var_name] = make_random_value()
return p.Variable(var_name)
elif v < 1250:
return make_random_expression(var_values, size) - make_random_expression(var_values, size)
elif v < 1500:
return make_random_expression(var_values, size) / make_random_expression(var_values, size)
def generate_random_fuzz_examples():
for i in xrange(20):
size = [0]
var_values = {}
expr = make_random_expression(var_values, size)
yield expr, var_values
def test_fuzz_code_generator(ctx_factory):
ctx = ctx_factory()
queue = cl.CommandQueue(ctx)
from expr_fuzz import get_fuzz_examples
for expr, var_values in generate_random_fuzz_examples():
#for expr, var_values in get_fuzz_examples():
from pymbolic import evaluate
true_value = evaluate(expr, var_values)
def get_dtype(x):
if isinstance(x, complex):
return np.complex128
else:
return np.float64
knl = lp.make_kernel(ctx.devices[0], "{ : }",
[lp.Instruction(None, "value", expr)],
[lp.GlobalArg("value", np.complex128, shape=())]
+ [
lp.ScalarArg(name, get_dtype(val))
for name, val in var_values.iteritems()
])
ck = lp.CompiledKernel(ctx, knl)
evt, (lp_value,) = ck(queue, **var_values)
err = abs(true_value-lp_value)/abs(true_value)
if abs(err) > 1e-10:
print "---------------------------------------------------------------------"
print "WRONG: rel error=%g" % err
print "true=%r" % true_value
print "loopy=%r" % lp_value
print "---------------------------------------------------------------------"
print ck.code
print "---------------------------------------------------------------------"
print var_values
print "---------------------------------------------------------------------"
print repr(expr)
print "---------------------------------------------------------------------"
print expr
print "---------------------------------------------------------------------"
1/0
if __name__ == "__main__":
import sys
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment