diff --git a/gen_wrap.py b/gen_wrap.py index 4f58acf0fe627c1e0208d7105c19d192ed463b20..c4e0fb9f484ccdcdcb876493a6f59a43cf5b7530 100644 --- a/gen_wrap.py +++ b/gen_wrap.py @@ -905,11 +905,12 @@ def gen_conversions(gen, tgt_cls, name): conversion_method=conversion_method)) -def gen_callback_wrapper(gen, cb, func_name): +def gen_callback_wrapper(gen, cb, func_name, has_userptr): passed_args = [] input_args = [] - assert cb.args[-1].name == "user" + if has_userptr: + assert cb.args[-1].name == "user" pre_call = PythonCodeGenerator() post_call = PythonCodeGenerator() @@ -937,7 +938,40 @@ def gen_callback_wrapper(gen, cb, func_name): raise SignatureNotSupported("unsupported callback arg: %s %s" % ( arg.base_type, arg.ptr)) - input_args.append("user") + if has_userptr: + input_args.append("user") + + if cb.return_base_type in SAFE_IN_TYPES and cb.return_ptr == "": + failure_return = "lib.isl_stat_error" + + post_call(""" + if _result is None: + _result = lib.isl_stat_ok + """) + + elif cb.return_base_type.startswith("isl_") and cb.return_ptr == "*": + failure_return = "ffi.NULL" + + ret_py_cls = isl_class_to_py_class(cb.return_base_type) + pre_call(""" + if not isinstance(_result, {py_cls}): + raise IslTypeError("return value is not a {py_cls}") + """ + .format(py_cls=ret_py_cls)) + + ret_cls = cb.return_base_type[4:] + + if cb.return_semantics is None: + raise SignatureNotSupported("callback return with unspecified semantics") + elif cb.return_semantics is not SEM_GIVE: + raise SignatureNotSupported("callback return with non-GIVE semantics") + if ret_cls in NON_COPYABLE: + raise SignatureNotSupported("noncopyable callback return") + + post_call("_result = _result._release()") + + else: + raise SignatureNotSupported("unsupported callback signature") gen( "def {func_name}({input_args}):" @@ -949,13 +983,13 @@ def gen_callback_wrapper(gen, cb, func_name): gen("try:") with Indentation(gen): gen.extend(pre_call) - gen( "_result = {name}({passed_args})" .format(name=cb.name, passed_args=", ".join(passed_args))) - gen("return lib.isl_stat_ok") - gen("") + gen.extend(post_call) + + gen("return _result") gen(""" except Exception as e: @@ -965,8 +999,8 @@ def gen_callback_wrapper(gen, cb, func_name): import traceback traceback.print_exc() - return lib.isl_stat_error - """) + return {failure_return} + """.format(failure_return=failure_return)) gen("") @@ -987,16 +1021,16 @@ def write_method_wrapper(gen, cls_name, meth): arg = meth.args[arg_idx] if isinstance(arg, CallbackArgument): - if arg.return_base_type not in SAFE_IN_TYPES or arg.return_ptr: - raise SignatureNotSupported("non-int callback") - arg_idx += 1 - if meth.args[arg_idx].name != "user": - raise SignatureNotSupported("unexpected callback signature") + has_userptr = ( + arg_idx + 1 < len(meth.args) + and meth.args[arg_idx+1].name == "user") + if has_userptr: + arg_idx += 1 cb_wrapper_name = "_cb_wrapper_"+arg.name - gen_callback_wrapper(pre_call, arg, cb_wrapper_name) + gen_callback_wrapper(pre_call, arg, cb_wrapper_name, has_userptr) pre_call( '_cb_{name} = ffi.callback("{cb_decl}")({cb_wrapper_name})' @@ -1099,13 +1133,11 @@ def write_method_wrapper(gen, cls_name, meth): gen_conversions(pre_call, arg.base_type, arg.name) arg_py_cls = isl_class_to_py_class(arg.base_type) - pre_call("if not isinstance({name}, {py_cls}):" - .format( - name=arg.name, py_cls=arg_py_cls)) - with Indentation(pre_call): - pre_call('raise IslTypeError("{name} is not a {py_cls}")' - .format( - name=arg.name, py_cls=arg_py_cls)) + pre_call(""" + if not isinstance({name}, {py_cls}): + raise IslTypeError("{name} is not a {py_cls}") + """ + .format(name=arg.name, py_cls=arg_py_cls)) arg_cls = arg.base_type[4:] arg_descr = ":param %s: :class:`%s`" % ( @@ -1439,7 +1471,7 @@ def gen_wrapper(include_dirs, include_barvinok=False, isl_version=None): write_method_header(header_f, meth) - if meth.name == "free": + if meth.name in ["free", "set_free_user"]: continue try: