diff --git a/pytools/obj_array.py b/pytools/obj_array.py index d91a24cda76412a0567fde0dbaf12d74045d55f7..df827c21349107b9205efc90f98545c2c74d38a9 100644 --- a/pytools/obj_array.py +++ b/pytools/obj_array.py @@ -83,7 +83,12 @@ def make_obj_array(res_list): can be undesirable. """ result = np.empty((len(res_list),), dtype=object) - result[:] = res_list + + # 'result[:] = res_list' may look tempting, however: + # https://github.com/numpy/numpy/issues/16564 + for idx in range(len(res_list)): + result[idx] = res_list[idx] + return result diff --git a/test/test_pytools.py b/test/test_pytools.py index 4a47a067c2af96fbcbc47214f9b245301ce16bca..939a0f57970ffb0cba1bc59bf56d8b510f068910 100644 --- a/test/test_pytools.py +++ b/test/test_pytools.py @@ -261,6 +261,30 @@ def test_natsorted(): assert natsorted([10, 1, 9], key=lambda d: "x%d" % d) == [1, 9, 10] +# {{{ object array iteration behavior + +class FakeArray: + nopes = 0 + + def __len__(self): + FakeArray.nopes += 1 + return 10 + + def __getitem__(self, idx): + FakeArray.nopes += 1 + if idx > 10: + raise IndexError() + + +def test_make_obj_array_iteration(): + from pytools.obj_array import make_obj_array + make_obj_array([FakeArray()]) + + assert FakeArray.nopes == 0, FakeArray.nopes + +# }}} + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])