diff --git a/bin/loopy b/bin/loopy index b47127740b4d016d823de3874fd7398c031d8160..0ac8ee78a5da228bb7628cd0c8866bf5f29abeef 100644 --- a/bin/loopy +++ b/bin/loopy @@ -57,6 +57,7 @@ def main(): parser.add_argument("--target") parser.add_argument("--name") parser.add_argument("--transform") + parser.add_argument("--edit-code", action="store_true") parser.add_argument("--occa-defines") parser.add_argument("--occa-add-dummy-arg", action="store_true") parser.add_argument("--print-ir", action="store_true") @@ -191,11 +192,20 @@ def main(): else: outfile = "-" + code = "\n\n".join(codes) + + import os + edit_kernel_env = os.environ.get("LOOPY_EDIT_KERNEL") + if (args.edit_code + or any(edit_kernel_env.lower() in k.name.lower() for k in kernels)): + from pytools import invoke_editor + code = invoke_editor(code, filename="edit.cl") + if outfile == "-": - sys.stdout.write("\n\n".join(codes)) + sys.stdout.write(code) else: with open(outfile, "w") as outfile_fd: - outfile_fd.write("\n\n".join(codes)) + outfile_fd.write(code) if __name__ == "__main__":