diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index db3633d17d994fb861a03b5f9c54574aeb966656..309571b3090a0237f67671d8d964e7e6e74b5ae5 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -569,8 +569,7 @@ def unflatten( try: iterable = serialize_container(template_subary) except TypeError: - # NOTE: the max is needed to handle device scalars with size == 0 - offset += max(1, template_subary.size) + offset += template_subary.size if offset > ary.size: raise ValueError("'template' and 'ary' sizes do not match") @@ -592,7 +591,8 @@ def unflatten( "array context.") from exc if hasattr(template_subary, "strides"): - if template_subary.strides != subary.strides: + if template_subary.size != 0 \ + and template_subary.strides != subary.strides: raise ValueError( f"strides do not match template: got {subary.strides}, " f"expected {template_subary.strides}") diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 9d27eaae5f4bf3b86a8312c60789f2373ddde716..0ed86251980db4ac4b2fa7d79178ab02fbf21d7b 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -892,13 +892,20 @@ def test_container_norm(actx_factory, ord): # {{{ test flatten and unflatten -def test_flatten_array_container(actx_factory): +@pytest.mark.parametrize("shapes", [ + 0, # tests device scalars when flattening + 512, + [(128, 67)], + [(127, 67), (18, 0)], # tests 0-sized arrays + [(64, 7), (154, 12)] + ]) +def test_flatten_array_container(actx_factory, shapes): actx = actx_factory() + if np.prod(shapes) == 0 and isinstance(actx, PytatoPyOpenCLArrayContext): + pytest.skip("operation not supported on PytatoPyOpenCLArrayContext") from arraycontext import flatten, unflatten - arys = _get_test_containers(actx, shapes=512) \ - + _get_test_containers(actx, shapes=(128, 67)) \ - + _get_test_containers(actx, shapes=[(64, 7), (154, 12)]) + arys = _get_test_containers(actx, shapes=shapes) for ary in arys: flat = flatten(ary, actx) @@ -909,6 +916,12 @@ def test_flatten_array_container(actx_factory): actx.np.linalg.norm(ary - ary_roundtrip) ) < 1.0e-15 + from arraycontext import rec_multimap_reduce_array_container + assert rec_multimap_reduce_array_container( + np.prod, + lambda x, y: x.shape == y.shape, + ary, ary_roundtrip) + # }}}