diff --git a/pytools/__init__.py b/pytools/__init__.py index 90f8911bf35e3ee1ad9e1dc7e3584c6b10cc8f8b..0894f8ef6a22ff3837feeb2084af6a9d126b1d83 100644 --- a/pytools/__init__.py +++ b/pytools/__init__.py @@ -129,6 +129,12 @@ Functions for dealing with (large) auxiliary files -------------------------------------------------- .. autofunction:: download_from_web_if_not_present + +Helpers for :mod:`numpy` +------------------------ + +.. autofunction:: reshaped_view + """ @@ -2033,6 +2039,26 @@ def find_module_git_revision(module_file, n_levels_up): # }}} +# {{{ create a reshaped view of a numpy array + +def reshaped_view(a, newshape): + """ Create a new view object with shape ``newshape`` without copying the data of + ``a``. This function is different from ``numpy.reshape`` by raising an + exception when data copy is necessary. + + :arg a: a :class:`numpy.ndarray` object. + :arg newshape: an ``int`` object or a tuple of ``int`` objects. + + .. versionadded:: 2018.4 + """ + + newview = a.view() + newview.shape = newshape + return newview + +# }}} + + def _test(): import doctest doctest.testmod() diff --git a/test/test_pytools.py b/test/test_pytools.py index 51d64999c0b62a7a41ceb35d17a3c88f204bb770..65514dd927c09b3b7a69e76d309e004cb5098d8b 100644 --- a/test/test_pytools.py +++ b/test/test_pytools.py @@ -211,6 +211,18 @@ def test_find_module_git_revision(): print(pytools.find_module_git_revision(pytools.__file__, n_levels_up=1)) +def test_reshaped_view(): + import pytools + import numpy as np + + a = np.zeros((10, 2)) + b = a.T + c = pytools.reshaped_view(a, -1) + assert c.shape == (20,) + with pytest.raises(AttributeError): + pytools.reshaped_view(b, -1) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])