diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 78eab3b711f89a73d058dff93e02d60cbf48fbae..badb7d6052026dcacc2a6ec2e2f267bf5f58ef2c 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -69,7 +69,7 @@ import numpy as np from arraycontext.context import ArrayContext from arraycontext.container import ( - ContainerT, ArrayOrContainerT, is_array_container_type, + ContainerT, ArrayOrContainerT, serialize_container, deserialize_container) @@ -117,12 +117,21 @@ def _multimap_array_container_impl( specific container classes. By default, the recursion is stopped when a non-:class:`ArrayContainer` class is encountered. """ + + # {{{ recursive traversal + def rec(*_args: Any) -> Any: template_ary = _args[container_indices[0]] - if (type(template_ary) is leaf_cls - or not is_array_container_type(template_ary.__class__)): + if type(template_ary) is leaf_cls: return f(*_args) + try: + iterable_template = serialize_container(template_ary) + except TypeError: + return f(*_args) + else: + pass + assert all( type(_args[i]) is type(template_ary) for i in container_indices[1:] ), f"expected type '{type(template_ary).__name__}'" @@ -130,9 +139,10 @@ def _multimap_array_container_impl( result = [] new_args = list(_args) - for subarys in zip(*[ - serialize_container(_args[i]) for i in container_indices - ]): + for subarys in zip( + iterable_template, + *[serialize_container(_args[i]) for i in container_indices[1:]] + ): key = None for i, (subkey, subary) in zip(container_indices, subarys): if key is None: @@ -146,13 +156,36 @@ def _multimap_array_container_impl( return process_container(template_ary, result) # type: ignore[operator] - container_indices: List[int] = [ - i for i, arg in enumerate(args) - if is_array_container_type(arg.__class__) and type(arg) is not leaf_cls] + # }}} + + # {{{ find all containers in the argument list + + container_indices: List[int] = [] + + for i, arg in enumerate(args): + if type(arg) is leaf_cls: + continue + + try: + # FIXME: this will serialize again once `rec` is called, which is + # not great, but it doesn't seem like there's a good way to avoid it + _ = serialize_container(arg) + except TypeError: + pass + else: + container_indices.append(i) + + # }}} + + # {{{ #containers == 0 => call `f` if not container_indices: return f(*args) + # }}} + + # {{{ #containers == 1 => call `map_array_container` + if len(container_indices) == 1 and reduce_func is None: # NOTE: if we just have one ArrayContainer in args, passing it through # _map_array_container_impl should be faster @@ -167,9 +200,15 @@ def _multimap_array_container_impl( wrapper, template_ary, leaf_cls=leaf_cls, recursive=recursive) + # }}} + + # {{{ #containers > 1 => call `rec` + process_container = deserialize_container if reduce_func is None else reduce_func frec = rec if recursive else f + # }}} + return rec(*args) # }}} diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index a7e74bd3a1628897b2f86c59a968313cfbd5572f..a772a85a5705793dc3313131cf19add769d8dd3f 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -28,10 +28,9 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from arraycontext.container import ArrayContainer +from arraycontext.container import ArrayContainer, is_array_container_type from arraycontext import PytatoPyOpenCLArrayContext -from arraycontext.container.traversal import (rec_keyed_map_array_container, - is_array_container_type) +from arraycontext.container.traversal import rec_keyed_map_array_container import numpy as np from typing import Any, Callable, Tuple, Dict, Mapping @@ -162,8 +161,8 @@ def _get_f_placeholder_args(arg, kw, arg_id_to_name): def _rec_to_placeholder(keys, ary): name = arg_id_to_name[(kw,) + keys] return pt.make_placeholder(name, ary.shape, ary.dtype) - return rec_keyed_map_array_container(_rec_to_placeholder, - arg) + + return rec_keyed_map_array_container(_rec_to_placeholder, arg) else: raise NotImplementedError(type(arg))