From 6f3e4a476505bdcb2d909148f0032b1e314d3ddc Mon Sep 17 00:00:00 2001 From: Hao Gao Date: Tue, 27 Mar 2018 14:31:27 -0500 Subject: [PATCH] Add reshaped_view function --- pytools/__init__.py | 18 ++++++++++++++++++ test/test_pytools.py | 12 ++++++++++++ 2 files changed, 30 insertions(+) diff --git a/pytools/__init__.py b/pytools/__init__.py index 90f8911..eea0576 100644 --- a/pytools/__init__.py +++ b/pytools/__init__.py @@ -2033,6 +2033,24 @@ def find_module_git_revision(module_file, n_levels_up): # }}} +# {{{ create a reshaped view of a numpy array + +def reshaped_view(a, newshape): + """ Create a new view object with shape ``newshape`` without copying the data of + ``a``. This function is different from ``numpy.reshape`` by raising an + exception when data copy is necessary. + + :arg a: a :class:`numpy.ndarray` object. + :arg newshape: an ``int`` object or a tuple of ``int`` objects. + """ + + newview = a.view() + newview.shape = newshape + return newview + +# }}} + + def _test(): import doctest doctest.testmod() diff --git a/test/test_pytools.py b/test/test_pytools.py index 51d6499..65514dd 100644 --- a/test/test_pytools.py +++ b/test/test_pytools.py @@ -211,6 +211,18 @@ def test_find_module_git_revision(): print(pytools.find_module_git_revision(pytools.__file__, n_levels_up=1)) +def test_reshaped_view(): + import pytools + import numpy as np + + a = np.zeros((10, 2)) + b = a.T + c = pytools.reshaped_view(a, -1) + assert c.shape == (20,) + with pytest.raises(AttributeError): + pytools.reshaped_view(b, -1) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) -- GitLab