Skip to content
Snippets Groups Projects
Commit bd0d4974 authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Prevent writes to inames (always), domain parameters (when inappropriate).

parent 5619ce81
No related branches found
No related tags found
No related merge requests found
......@@ -315,6 +315,29 @@ def check_bounds(kernel):
acm(insn.expression)
acm(insn.assignee)
def check_write_destinations(kernel):
for insn in kernel.instructions:
wvar = insn.get_assignee_var_name()
if wvar in kernel.all_inames():
raise RuntimeError("iname '%s' may not be written" % wvar)
insn_domain = kernel.get_inames_domain(kernel.insn_inames(insn))
insn_params = set(insn_domain.get_var_names(dim_type.param))
if wvar in kernel.all_params():
if wvar not in kernel.temporary_variables:
raise RuntimeError("domain parameter '%s' may not be written"
"--it is not a temporary variable" % wvar)
if wvar in insn_params:
raise RuntimeError("domain parameter '%s' may not be written "
"inside a domain dependent on it" % wvar)
if not (wvar in kernel.temporary_variables
or wvar in kernel.arg_dict) and wvar not in kernel.all_params():
raise RuntimeError
# }}}
def run_automatic_checks(kernel):
......@@ -326,6 +349,7 @@ def run_automatic_checks(kernel):
check_for_write_races(kernel)
check_for_data_dependent_parallel_bounds(kernel)
check_bounds(kernel)
check_write_destinations(kernel)
except:
print 75*"="
print "failing kernel after processing:"
......
......@@ -1007,6 +1007,36 @@ def test_ilp_write_race_avoidance_private(ctx_factory):
def test_write_parameter(ctx_factory):
dtype = np.float32
ctx = ctx_factory()
knl = lp.make_kernel(ctx.devices[0], [
"{[i,j]: 0<=i,j<n }",
],
"""
a = sum((i,j), i*j)
b = sum(i, sum(j, i*j))
n = 15
""",
[
lp.GlobalArg("a", dtype, shape=()),
lp.GlobalArg("b", dtype, shape=()),
lp.ValueArg("n", np.int32, approximately=1000),
],
assumptions="n>=1")
try:
lp.CompiledKernel(ctx, knl).get_code()
except RuntimeError, e:
assert "may not be written" in str(e)
pass # expected!
else:
assert False # expecting an error
if __name__ == "__main__":
import sys
if len(sys.argv) > 1:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment