From 267d772e729a0c61afcc5a477f363b1aca258f82 Mon Sep 17 00:00:00 2001 From: "[6~" Date: Wed, 10 Jun 2020 13:39:35 -0500 Subject: [PATCH] Work around numpy/numpy#16564 in make_obj_array --- pytools/obj_array.py | 7 ++++++- test/test_pytools.py | 24 ++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/pytools/obj_array.py b/pytools/obj_array.py index d91a24c..df827c2 100644 --- a/pytools/obj_array.py +++ b/pytools/obj_array.py @@ -83,7 +83,12 @@ def make_obj_array(res_list): can be undesirable. """ result = np.empty((len(res_list),), dtype=object) - result[:] = res_list + + # 'result[:] = res_list' may look tempting, however: + # https://github.com/numpy/numpy/issues/16564 + for idx in range(len(res_list)): + result[idx] = res_list[idx] + return result diff --git a/test/test_pytools.py b/test/test_pytools.py index 4a47a06..939a0f5 100644 --- a/test/test_pytools.py +++ b/test/test_pytools.py @@ -261,6 +261,30 @@ def test_natsorted(): assert natsorted([10, 1, 9], key=lambda d: "x%d" % d) == [1, 9, 10] +# {{{ object array iteration behavior + +class FakeArray: + nopes = 0 + + def __len__(self): + FakeArray.nopes += 1 + return 10 + + def __getitem__(self, idx): + FakeArray.nopes += 1 + if idx > 10: + raise IndexError() + + +def test_make_obj_array_iteration(): + from pytools.obj_array import make_obj_array + make_obj_array([FakeArray()]) + + assert FakeArray.nopes == 0, FakeArray.nopes + +# }}} + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) -- GitLab