diff --git a/loopy/check.py b/loopy/check.py index c7ba6a76f03b4a5b0ae2d5e2201394abf3990059..65962a8eb4059d7403d06e32ab2684446743e0b5 100644 --- a/loopy/check.py +++ b/loopy/check.py @@ -315,6 +315,29 @@ def check_bounds(kernel): acm(insn.expression) acm(insn.assignee) +def check_write_destinations(kernel): + for insn in kernel.instructions: + wvar = insn.get_assignee_var_name() + + if wvar in kernel.all_inames(): + raise RuntimeError("iname '%s' may not be written" % wvar) + + insn_domain = kernel.get_inames_domain(kernel.insn_inames(insn)) + insn_params = set(insn_domain.get_var_names(dim_type.param)) + + if wvar in kernel.all_params(): + if wvar not in kernel.temporary_variables: + raise RuntimeError("domain parameter '%s' may not be written" + "--it is not a temporary variable" % wvar) + + if wvar in insn_params: + raise RuntimeError("domain parameter '%s' may not be written " + "inside a domain dependent on it" % wvar) + + if not (wvar in kernel.temporary_variables + or wvar in kernel.arg_dict) and wvar not in kernel.all_params(): + raise RuntimeError + # }}} def run_automatic_checks(kernel): @@ -326,6 +349,7 @@ def run_automatic_checks(kernel): check_for_write_races(kernel) check_for_data_dependent_parallel_bounds(kernel) check_bounds(kernel) + check_write_destinations(kernel) except: print 75*"=" print "failing kernel after processing:" diff --git a/test/test_loopy.py b/test/test_loopy.py index 16a096c191edbb80fb3415003d492903ccad3c36..91c77b07440cc4b038089b3f5ff92a91b6a4d21d 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -1007,6 +1007,36 @@ def test_ilp_write_race_avoidance_private(ctx_factory): +def test_write_parameter(ctx_factory): + dtype = np.float32 + ctx = ctx_factory() + + knl = lp.make_kernel(ctx.devices[0], [ + "{[i,j]: 0<=i,j<n }", + ], + """ + a = sum((i,j), i*j) + b = sum(i, sum(j, i*j)) + n = 15 + """, + [ + lp.GlobalArg("a", dtype, shape=()), + lp.GlobalArg("b", dtype, shape=()), + lp.ValueArg("n", np.int32, approximately=1000), + ], + assumptions="n>=1") + + try: + lp.CompiledKernel(ctx, knl).get_code() + except RuntimeError, e: + assert "may not be written" in str(e) + pass # expected! + else: + assert False # expecting an error + + + + if __name__ == "__main__": import sys if len(sys.argv) > 1: