diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 2c1e9d699e03c8d7e59f24a5cf49ef04e659aab6..aa91d342ff95adc7bad80142859bc33c8d5234ab 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -179,12 +179,13 @@ def map_array_container( or an instance of a base array type. """ try: - ser_ctr = serialize_container(ary) - except TypeError: + iterable = serialize_container(ary) + except NotImplementedError: return f(ary) else: return deserialize_container(ary, [ - (key, f(subary)) for key, subary in ser_ctr]) + (key, f(subary)) for key, subary in iterable + ]) def multimap_array_container(f: Callable[..., Any], *args: Any) -> Any: @@ -262,12 +263,15 @@ def keyed_map_array_container(f: Callable[[Any, Any], Any], :param ary: a (potentially nested) structure of :class:`ArrayContainer`\ s, or an instance of a base array type. """ - if is_array_container(ary): - return deserialize_container(ary, [ - (key, f(key, subary)) for key, subary in serialize_container(ary) - ]) + try: + iterable = serialize_container(ary) + except NotImplementedError: + raise ValueError( + f"Non-array container type has no key: {type(ary).__name__}") else: - raise ValueError("Not an array-container, i.e. unknown key to pass.") + return deserialize_container(ary, [ + (key, f(key, subary)) for key, subary in iterable + ]) def rec_keyed_map_array_container(f: Callable[[Tuple[Any, ...], Any], Any], @@ -281,13 +285,14 @@ def rec_keyed_map_array_container(f: Callable[[Tuple[Any, ...], Any], Any], def rec(keys: Tuple[Union[str, int], ...], _ary: ArrayOrContainerT) -> ArrayOrContainerT: - if is_array_container(_ary): - return deserialize_container(_ary, [ - (key, rec(keys + (key,), subary)) - for key, subary in serialize_container(_ary) - ]) - else: + try: + iterable = serialize_container(_ary) + except NotImplementedError: return f(keys, _ary) + else: + return deserialize_container(_ary, [ + (key, rec(keys + (key,), subary)) for key, subary in iterable + ]) return rec((), ary) @@ -309,12 +314,14 @@ def map_reduce_array_container( :class:`arraycontext.ArrayContext.array_types`. Returns an array of the same type or a scalar. """ - if is_array_container(ary): + try: + iterable = serialize_container(ary) + except NotImplementedError: + return map_func(ary) + else: return reduce_func([ - map_func(subary) for _, subary in serialize_container(ary) + map_func(subary) for _, subary in iterable ]) - else: - return map_func(ary) def multimap_reduce_array_container( @@ -382,12 +389,14 @@ def rec_map_reduce_array_container( or any other such traversal. """ def rec(_ary: ArrayOrContainerT) -> ArrayOrContainerT: - if is_array_container(_ary): + try: + iterable = serialize_container(_ary) + except NotImplementedError: + return map_func(_ary) + else: return reduce_func([ - rec(subary) for _, subary in serialize_container(_ary) + rec(subary) for _, subary in iterable ]) - else: - return map_func(_ary) return rec(ary) @@ -472,13 +481,14 @@ def thaw(ary: ArrayOrContainerT, actx: ArrayContext) -> ArrayOrContainerT: in :mod:`meshmode`. This was necessary because :func:`~functools.singledispatch` only dispatches on the first argument. """ - if is_array_container(ary): + try: + iterable = serialize_container(ary) + except NotImplementedError: + return actx.thaw(ary) + else: return deserialize_container(ary, [ - (key, thaw(subary, actx)) - for key, subary in serialize_container(ary) + (key, thaw(subary, actx)) for key, subary in iterable ]) - else: - return actx.thaw(ary) # }}} diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 268eb3e0677b8db995a1fe6cc9069038123c4164..ed7c1b1ef49964d2651666ef318838a42288d476 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -134,6 +134,9 @@ class DOFArray: def __getitem__(self, i): return self.data[i] + def __repr__(self): + return f"DOFArray({repr(self.data)})" + @classmethod def _serialize_init_arrays_code(cls, instance_name): return {"_": @@ -669,26 +672,29 @@ class MyContainerDOFBcast: return self.mass.array_context -def _get_test_containers(actx, ambient_dim=2): - x = DOFArray(actx, (actx.from_numpy(np.random.randn(50_000)),)) +def _get_test_containers(actx, ambient_dim=2, size=50_000): + if size == 0: + x = DOFArray(actx, (actx.from_numpy(np.array(np.random.randn())),)) + else: + x = DOFArray(actx, (actx.from_numpy(np.random.randn(size)),)) # pylint: disable=unexpected-keyword-arg, no-value-for-parameter dataclass_of_dofs = MyContainer( name="container", mass=x, - momentum=make_obj_array([x, x]), + momentum=make_obj_array([x] * ambient_dim), enthalpy=x) # pylint: disable=unexpected-keyword-arg, no-value-for-parameter bcast_dataclass_of_dofs = MyContainerDOFBcast( name="container", mass=x, - momentum=make_obj_array([x, x]), + momentum=make_obj_array([x] * ambient_dim), enthalpy=x) ary_dof = x - ary_of_dofs = make_obj_array([x, x, x]) - mat_of_dofs = np.empty((3, 3), dtype=object) + ary_of_dofs = make_obj_array([x] * ambient_dim) + mat_of_dofs = np.empty((ambient_dim, ambient_dim), dtype=object) for i in np.ndindex(mat_of_dofs.shape): mat_of_dofs[i] = x @@ -696,6 +702,29 @@ def _get_test_containers(actx, ambient_dim=2): bcast_dataclass_of_dofs) +def test_container_scalar_map(actx_factory): + actx = actx_factory() + + arys = _get_test_containers(actx, size=0) + arys += (np.pi,) + + from arraycontext import ( + map_array_container, rec_map_array_container, + map_reduce_array_container, rec_map_reduce_array_container, + ) + + for ary in arys: + result = map_array_container(lambda x: x, ary) + assert result is not None + result = rec_map_array_container(lambda x: x, ary) + assert result is not None + + result = map_reduce_array_container(np.shape, lambda x: x, ary) + assert result is not None + result = rec_map_reduce_array_container(np.shape, lambda x: x, ary) + assert result is not None + + def test_container_multimap(actx_factory): actx = actx_factory() ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs, bcast_dc_of_dofs = \ @@ -783,10 +812,11 @@ def test_container_arithmetic(actx_factory): grad_matvec_result = mock_gradient @ ary_of_dofs assert isinstance(grad_matvec_result.mass, DOFArray) - assert grad_matvec_result.momentum.shape == (3,) + assert grad_matvec_result.momentum.shape == ary_of_dofs.shape - assert actx.to_numpy(actx.np.linalg.norm(grad_matvec_result.mass - - 3*ary_of_dofs**2)) < 1e-8 + assert actx.to_numpy(actx.np.linalg.norm( + grad_matvec_result.mass - sum(ary_of_dofs**2) + )) < 1e-8 # }}}