diff --git a/gen_wrap.py b/gen_wrap.py index bfeaf24eb8e119108564a04e077dcc572e520fbd..115097c766080279eb88e357acd126759a228591 100644 --- a/gen_wrap.py +++ b/gen_wrap.py @@ -38,8 +38,6 @@ ISL_SEM_TO_SEM = { "__isl_null": SEM_NULL, } -PY3 = sys.version_info >= (3,) - NON_COPYABLE = ["ctx", "printer", "access_info"] NON_COPYABLE_WITH_ISL_PREFIX = ["isl_"+i for i in NON_COPYABLE] @@ -295,6 +293,10 @@ typedef isl_restriction *(*isl_access_restrict)( PY_PREAMBLE = """ import six +import sys + + +_PY3 = sys.version_info >= (3,) from islpy._isl_cffi import ffi @@ -321,7 +323,7 @@ _context_use_map = {} def _deref_ctx(ctx_data, ctx_iptr): _context_use_map[ctx_iptr] -= 1 - if not _context_use_map[ctx_iptr]: + if _context_use_map[ctx_iptr] == 0: del _context_use_map[ctx_iptr] lib.isl_ctx_free(ctx_data) @@ -1028,7 +1030,14 @@ def gen_callback_wrapper(gen, cb, func_name, has_userptr): def write_method_wrapper(gen, cls_name, meth): pre_call = PythonCodeGenerator() - post_call = PythonCodeGenerator() + + # There are two post-call phases, "safety", and "check". The "safety" + # phase's job is to package up all the data returned by the function + # called. No exceptions may be raised during 'safety'. + # + # Next, the "check" phase will perform error checking and may raise exceptions. + safety = PythonCodeGenerator() + check = PythonCodeGenerator() docs = [] passed_args = [] @@ -1209,7 +1218,7 @@ def write_method_wrapper(gen, cls_name, meth): passed_args.append("ffi.addressof(_retptr_{name})".format(name=arg.name)) py_cls = isl_class_to_py_class(arg.base_type) - post_call(""" + safety(""" if _retptr_{name} == ffi.NULL: _ret_{name} = None else: @@ -1252,14 +1261,14 @@ def write_method_wrapper(gen, cls_name, meth): # {{{ return value processing if meth.return_base_type == "isl_stat" and not meth.return_ptr: - post_call("if _result == lib.isl_stat_error:") - with Indentation(post_call): - post_call('raise Error("call to \\"{0}\\" failed")'.format(meth.c_name)) + check("if _result == lib.isl_stat_error:") + with Indentation(check): + check('raise Error("call to \\"{0}\\" failed")'.format(meth.c_name)) elif meth.return_base_type == "isl_bool" and not meth.return_ptr: - post_call("if _result == lib.isl_bool_error:") - with Indentation(post_call): - post_call('raise Error("call to \\"{0}\\" failed")'.format(meth.c_name)) + check("if _result == lib.isl_bool_error:") + with Indentation(check): + check('raise Error("call to \\"{0}\\" failed")'.format(meth.c_name)) ret_vals.insert(0, "_result == lib.isl_bool_true") ret_descrs.insert(0, "bool") @@ -1282,7 +1291,7 @@ def write_method_wrapper(gen, cls_name, meth): meth.mutator_veto = True raise Retry() - post_call("%s._reset(_result)" % meth.args[0].name) + safety("%s._reset(_result)" % meth.args[0].name) ret_vals.insert(0, meth.args[0].name) ret_descrs.insert(0, @@ -1294,32 +1303,37 @@ def write_method_wrapper(gen, cls_name, meth): if meth.return_semantics is not SEM_GIVE and ret_cls != "ctx": raise SignatureNotSupported("non-give return") - post_call("if _result == ffi.NULL:") - with Indentation(post_call): - post_call( - 'raise Error("call to \\"{0}\\" failed")' - .format(meth.c_name)) - py_ret_cls = isl_class_to_py_class(ret_cls) - ret_vals.insert(0, "{0}(_data=_result)".format(py_ret_cls)) + safety( + "_result = None if _result == ffi.NULL else {0}(_data=_result)" + .format(py_ret_cls)) + + check(""" + if _result is None: + raise Error("call to {c_method} failed") + """ + .format(c_method=meth.c_name)) + + ret_vals.insert(0, "_result") ret_descrs.insert(0, ":class:`%s`" % py_ret_cls) elif meth.return_base_type in ["const char", "char"] and meth.return_ptr == "*": - post_call("if _result != ffi.NULL:") - with Indentation(post_call): - post_call("_str_ret = ffi.string(_result)") - post_call("else:") - with Indentation(post_call): - post_call("_str_ret = None") - - if PY3: - ret_vals.insert(0, - "_str_ret.decode() if _str_ret is not None else _str_ret") - else: - ret_vals.insert(0, "_str_ret") + safety(""" + if _result != ffi.NULL: + _str_ret = ffi.string(_result) + else: + _str_ret = None + """) if meth.return_semantics is SEM_GIVE: - post_call("libc.free(_result)") + safety("libc.free(_result)") + + check(""" + if _PY3 and _str_ret is not None: + _str_ret = _str_ret.decode() + """) + + ret_vals.insert(0, "_str_ret") ret_descrs.insert(0, "string") @@ -1344,16 +1358,16 @@ def write_method_wrapper(gen, cls_name, meth): assert len(ret_vals) == len(ret_descrs) - post_call("") + check("") if len(ret_vals) == 0: ret_descr = "(nothing)" elif len(ret_vals) == 1: - post_call("return " + ret_vals[0]) + check("return " + ret_vals[0]) ret_descr = ret_descrs[0] else: - post_call("return " + ", ".join(ret_vals)) + check("return " + ", ".join(ret_vals)) ret_descr = "(%s)" % ", ".join(ret_descrs) docs = (["%s(%s)" % (meth.name, ", ".join(input_args)), ""] @@ -1366,7 +1380,8 @@ def write_method_wrapper(gen, cls_name, meth): gen(repr("\n".join(docs))) gen("") gen.extend(pre_call) - gen.extend(post_call) + gen.extend(safety) + gen.extend(check) gen.dedent() gen("")