diff --git a/gen_wrap.py b/gen_wrap.py index 590c700551a201bab09b17e0741f03f31f612ad0..675b01f4eb90f47bc18d3e8429f4ed315113b08b 100644 --- a/gen_wrap.py +++ b/gen_wrap.py @@ -362,6 +362,10 @@ import six import sys import logging import threading +import operator as _operator + +# isl has parameters called type which end up shadowing the built-in function. +_type = type _PY3 = sys.version_info >= (3,) @@ -1391,7 +1395,13 @@ def write_method_wrapper(gen, cls_name, meth): pre_call("{val_name} = {name}._copy()".format(**fmt_args)) pre_call(""" - elif isinstance({name}, six.integer_types): + else: + try: + {name} = _operator.index({name}) + except TypeError: + raise IslTypeError("{name} is a %s and cannot " + "be cast to a Val" % _type({name})) + _cdata_{name} = lib.isl_val_int_from_si( {arg0_name}._get_ctx_data(), {name}) @@ -1399,10 +1409,6 @@ def write_method_wrapper(gen, cls_name, meth): raise Error("isl_val_int_from_si failed") {val_name} = Val(_data=_cdata_{name}) - - else: - raise IslTypeError("{name} is a %s and cannot " - "be cast to a Val" % type({name})) """ .format(**fmt_args)) diff --git a/test/test_isl.py b/test/test_isl.py index 0fe13d35f6522d00ec97a889eb38bab9193098dd..9e18953878a8df2912404c45326678ab6ece074d 100644 --- a/test/test_isl.py +++ b/test/test_isl.py @@ -310,6 +310,16 @@ def test_align_spaces(): assert result.get_var_dict() == m2.get_var_dict() +def test_pass_numpy_int(): + np = pytest.importorskip("numpy") + + s = isl.BasicMap("{[i,j]: 0<=i,j<15}") + c0 = s.get_constraints()[0] + + c1 = c0.set_constant_val(np.int32(5)) + print(c1) + + if __name__ == "__main__": import sys if len(sys.argv) > 1: