From 087351c62eb4ccd68f90990ed8de8317a68c6d29 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl <alexfikl@gmail.com> Date: Wed, 29 Sep 2021 20:01:31 -0500 Subject: [PATCH] add tests for flatten edge cases --- arraycontext/container/traversal.py | 6 +++--- test/test_arraycontext.py | 21 +++++++++++++++++---- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index db3633d..309571b 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -569,8 +569,7 @@ def unflatten( try: iterable = serialize_container(template_subary) except TypeError: - # NOTE: the max is needed to handle device scalars with size == 0 - offset += max(1, template_subary.size) + offset += template_subary.size if offset > ary.size: raise ValueError("'template' and 'ary' sizes do not match") @@ -592,7 +591,8 @@ def unflatten( "array context.") from exc if hasattr(template_subary, "strides"): - if template_subary.strides != subary.strides: + if template_subary.size != 0 \ + and template_subary.strides != subary.strides: raise ValueError( f"strides do not match template: got {subary.strides}, " f"expected {template_subary.strides}") diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 9d27eaa..0ed8625 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -892,13 +892,20 @@ def test_container_norm(actx_factory, ord): # {{{ test flatten and unflatten -def test_flatten_array_container(actx_factory): +@pytest.mark.parametrize("shapes", [ + 0, # tests device scalars when flattening + 512, + [(128, 67)], + [(127, 67), (18, 0)], # tests 0-sized arrays + [(64, 7), (154, 12)] + ]) +def test_flatten_array_container(actx_factory, shapes): actx = actx_factory() + if np.prod(shapes) == 0 and isinstance(actx, PytatoPyOpenCLArrayContext): + pytest.skip("operation not supported on PytatoPyOpenCLArrayContext") from arraycontext import flatten, unflatten - arys = _get_test_containers(actx, shapes=512) \ - + _get_test_containers(actx, shapes=(128, 67)) \ - + _get_test_containers(actx, shapes=[(64, 7), (154, 12)]) + arys = _get_test_containers(actx, shapes=shapes) for ary in arys: flat = flatten(ary, actx) @@ -909,6 +916,12 @@ def test_flatten_array_container(actx_factory): actx.np.linalg.norm(ary - ary_roundtrip) ) < 1.0e-15 + from arraycontext import rec_multimap_reduce_array_container + assert rec_multimap_reduce_array_container( + np.prod, + lambda x, y: x.shape == y.shape, + ary, ary_roundtrip) + # }}} -- GitLab