diff --git a/test/test_isl.py b/test/test_isl.py index a5de274682cb5aa41f42319bf27f07bbf0e23749..fb3ce73b6abd47b1bf9675d643f0bb2bd8ccfc49 100644 --- a/test/test_isl.py +++ b/test/test_isl.py @@ -158,27 +158,27 @@ def test_eval_pw_qpolynomial(): def test_schedule(): - schedule = isl.UnionMap("{A[i,j] -> [i,j]: 0 < i < j < 100}") + schedule = isl.Map("{S[t,i,j] -> [t,i,j]: 0 < t < 20 and 0 < i < j < 100}") + accesses = isl.Map("{S[t,i,j] -> bar[t%2, i+1, j-1]}") context = isl.Set("{:}") build = isl.AstBuild.from_context(context) def callback(node, build): - return None + schedulemap = build.get_schedule() + accessmap = accesses.apply_domain(schedulemap) + aff = isl.PwMultiAff.from_map(isl.Map.from_union_map(accessmap)) + access = build.call_from_pw_multi_aff(aff) + return isl.AstNode.alloc_user(access) - build, callback_handle = build.set_after_each_for(callback) + build, callback_handle = build.set_at_each_domain(callback) - try: - ast = build.ast_from_schedule(schedule) - except isl.Error: - # expected for now -- callback needs to return an AstNode, - # but I don't know how to make one. - pass + ast = build.ast_from_schedule(schedule) - else: - printer = isl.Printer.to_str(isl.DEFAULT_CONTEXT) - printer = printer.set_output_format(4) - printer = printer.print_ast_node(ast) - print(printer.get_str()) + printer = isl.Printer.to_str(isl.DEFAULT_CONTEXT) + printer = printer.set_output_format(isl.format.C) + printer = printer.print_ast_node(ast) + + print(printer.get_str()) if __name__ == "__main__":