diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index c339609c98cde40599bb90555339f3937cc45a65..3eb3664f4c9be355424d2fb4ef0963d9d4b91337 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -603,29 +603,51 @@ def flatten(ary: ArrayOrContainerT, actx: ArrayContext) -> Any: def unflatten( template: ArrayOrContainerT, ary: Any, - actx: ArrayContext) -> ArrayOrContainerT: + actx: ArrayContext, *, + strict: bool = True) -> ArrayOrContainerT: """Unflatten an array *ary* produced by :func:`flatten` back into an :class:`~arraycontext.ArrayContainer`. The order and sizes of each slice into *ary* are determined by the array container *template*. + + :arg strict: if *True* additional :class:`~numpy.dtype` and stride + checking is performed on the unflattened array. Otherwise, these + checks are skipped. """ # NOTE: https://github.com/python/mypy/issues/7057 offset = 0 + common_dtype = None def _unflatten(template_subary: ArrayOrContainerT) -> ArrayOrContainerT: - nonlocal offset + nonlocal offset, common_dtype try: iterable = serialize_container(template_subary) except NotAnArrayContainerError: + # {{{ validate subary + if (offset + template_subary.size) > ary.size: raise ValueError("'template' and 'ary' sizes do not match: " "'template' is too large") - if template_subary.dtype != ary.dtype: - raise ValueError("'template' dtype does not match 'ary': " - f"got {template_subary.dtype}, expected {ary.dtype}") + if strict: + if template_subary.dtype != ary.dtype: + raise ValueError("'template' dtype does not match 'ary': " + f"got {template_subary.dtype}, expected {ary.dtype}") + else: + # NOTE: still require that *template* has a uniform dtype + if common_dtype is None: + common_dtype = template_subary.dtype + else: + if common_dtype != template_subary.dtype: + raise ValueError("arrays in 'template' have different " + f"dtypes: got {template_subary.dtype}, but " + f"expected {common_dtype}.") + + # }}} + + # {{{ reshape flat_subary = ary[offset:offset + template_subary.size] try: @@ -640,12 +662,18 @@ def unflatten( "This functionality needs to be implemented by the " "array context.") from exc - if hasattr(template_subary, "strides"): + # }}} + + # {{{ check strides + + if strict and hasattr(template_subary, "strides"): if template_subary.strides != subary.strides: raise ValueError( f"strides do not match template: got {subary.strides}, " f"expected {template_subary.strides}") + # }}} + offset += template_subary.size return subary else: diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 632c95bd27c692bb6ac6c56f67f1d6d021b6b5aa..d97f00f0f1dd1a72c997109e28914376c917652d 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -926,6 +926,28 @@ def test_flatten_array_container(actx_factory, shapes): actx.np.linalg.norm(ary - ary_roundtrip) ) < 1.0e-15 + # {{{ complex to real + + if isinstance(shapes, (int, tuple)): + shapes = [shapes] + + ary = DOFArray(actx, tuple([ + actx.from_numpy(randn(shape, np.float64)) + for shape in shapes])) + + template = DOFArray(actx, tuple([ + actx.from_numpy(randn(shape, np.complex128)) + for shape in shapes])) + + flat = flatten(ary, actx) + ary_roundtrip = unflatten(template, flat, actx, strict=False) + + assert actx.to_numpy( + actx.np.linalg.norm(ary - ary_roundtrip) + ) < 1.0e-15 + + # }}} + def test_flatten_array_container_failure(actx_factory): actx = actx_factory()