From cc03864eeb3f644d0dc7d2b52c8ab0d1e892f95f Mon Sep 17 00:00:00 2001 From: Matthew Smith <mjsmith6@illinois.edu> Date: Wed, 27 Apr 2022 14:09:43 -0500 Subject: [PATCH] deduce result dtype --- grudge/tools.py | 44 +++++++++++++++++++++++++++++--------------- 1 file changed, 29 insertions(+), 15 deletions(-) diff --git a/grudge/tools.py b/grudge/tools.py index 246c21d8..3427ba1c 100644 --- a/grudge/tools.py +++ b/grudge/tools.py @@ -27,8 +27,8 @@ THE SOFTWARE. """ import numpy as np -from pytools import levi_civita -from typing import Tuple, Callable, Optional, Any +from pytools import levi_civita, product +from typing import Tuple, Callable, Optional, Union, Any from functools import partial from arraycontext.container import ArrayOrContainerT @@ -247,9 +247,9 @@ def map_subarrays( """ Apply a function *f* to subarrays of shape *in_shape* of an :class:`numpy.ndarray`, typically (but not necessarily) of dtype - :class:`object`. Return an :class:`numpy.ndarray` of the same dtype, - with the corresponding subarrays replaced by the return values of *f*, - and with the shape adapted to reflect *out_shape*. + :class:`object`. Return an :class:`numpy.ndarray` with the corresponding + subarrays replaced by the return values of *f*, and with the shape adapted + to reflect *out_shape*. Similar to :class:`numpy.vectorize`. @@ -300,19 +300,33 @@ def map_subarrays( base_shape = ary.shape[:ary.ndim-len(in_shape)] if len(base_shape) == 0: return f(ary) + elif product(base_shape) == 0: + if return_nested: + return np.empty(base_shape, dtype=object) + else: + return np.empty(base_shape + out_shape, dtype=object) else: in_slice = tuple(slice(0, n) for n in in_shape) - out_slice = tuple(slice(0, n) for n in out_shape) - if return_nested: - result = np.empty(base_shape, dtype=object) - for idx in np.ndindex(base_shape): - result[idx] = f(ary[idx + in_slice]) - return result + result_entries = np.empty(base_shape, dtype=object) + for idx in np.ndindex(base_shape): + result_entries[idx] = f(ary[idx + in_slice]) + if len(out_shape) == 0: + out_entry_template = result_entries.flat[0] + if np.isscalar(out_entry_template): + return result_entries.astype(type(out_entry_template)) + else: + return result_entries else: - result = np.empty(base_shape + out_shape, dtype=ary.dtype) - for idx in np.ndindex(base_shape): - result[idx + out_slice] = f(ary[idx + in_slice]) - return result + if return_nested: + return result_entries + else: + out_slice = tuple(slice(0, n) for n in out_shape) + out_entry_template = result_entries.flat[0] + result = np.empty( + base_shape + out_shape, dtype=out_entry_template.dtype) + for idx in np.ndindex(base_shape): + result[idx + out_slice] = result_entries[idx] + return result def rec_map_subarrays( -- GitLab