From d03855d1c03177854af4d4593259b726cf35ecd6 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Tue, 19 Oct 2021 21:52:44 -0500 Subject: [PATCH] convert some is_array_container_type checks --- arraycontext/container/__init__.py | 33 ++++++++--------- arraycontext/container/traversal.py | 45 ++++++++++++++++-------- arraycontext/fake_numpy.py | 35 +++++++++++------- arraycontext/impl/pyopencl/fake_numpy.py | 36 ++++++++++--------- arraycontext/impl/pytato/fake_numpy.py | 41 ++++++++++++--------- test/test_arraycontext.py | 25 +++++++++++++ 6 files changed, 140 insertions(+), 75 deletions(-) diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py index d2f9761..bae0bcd 100644 --- a/arraycontext/container/__init__.py +++ b/arraycontext/container/__init__.py @@ -247,28 +247,29 @@ def get_container_context_recursively(ary: Any) -> Optional[ArrayContext]: If different components that have different array contexts are found at any level, an assertion error is raised. """ - actx = None - if not is_array_container_type(ary.__class__): - return actx - # try getting the array context directly actx = get_container_context(ary) if actx is not None: return actx - for _, subary in serialize_container(ary): - context = get_container_context_recursively(subary) - if context is None: - continue - - if not __debug__: - return context - elif actx is None: - actx = context - else: - assert actx is context + try: + iterable = serialize_container(ary) + except TypeError: + return actx + else: + for _, subary in iterable: + context = get_container_context_recursively(subary) + if context is None: + continue + + if not __debug__: + return context + elif actx is None: + actx = context + else: + assert actx is context - return actx + return actx # }}} diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 85d665f..78eab3b 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -90,12 +90,15 @@ def _map_array_container_impl( def rec(_ary: ArrayOrContainerT) -> ArrayOrContainerT: if type(_ary) is leaf_cls: # type(ary) is never None return f(_ary) - elif is_array_container_type(_ary.__class__): + + try: + iterable = serialize_container(_ary) + except TypeError: + return f(_ary) + else: return deserialize_container(_ary, [ - (key, frec(subary)) for key, subary in serialize_container(_ary) + (key, frec(subary)) for key, subary in iterable ]) - else: - return f(_ary) frec = rec if recursive else f return rec(ary) @@ -457,9 +460,9 @@ def freeze( See :meth:`ArrayContext.thaw`. """ - if is_array_container_type(ary.__class__): - return map_array_container(partial(freeze, actx=actx), ary) - else: + try: + iterable = serialize_container(ary) + except TypeError: if actx is None: raise TypeError( f"cannot freeze arrays of type {type(ary).__name__} " @@ -467,6 +470,10 @@ def freeze( "directly or supplying an array context") else: return actx.freeze(ary) + else: + return deserialize_container(ary, [ + (key, freeze(subary, actx=actx)) for key, subary in iterable + ]) @singledispatch @@ -635,15 +642,15 @@ def from_numpy(ary: Any, actx: ArrayContext) -> Any: The conversion is done using :meth:`arraycontext.ArrayContext.from_numpy`. """ - def _from_numpy(subary: Any) -> Any: - if isinstance(subary, np.ndarray) and subary.dtype != "O": + def _from_numpy_with_check(subary: Any) -> Any: + if np.isscalar(subary): + return subary + elif isinstance(subary, np.ndarray): return actx.from_numpy(subary) - elif is_array_container_type(subary.__class__): - return map_array_container(_from_numpy, subary) else: - raise TypeError(f"unrecognized array type: '{type(subary).__name__}'") + raise TypeError(f"array is not an ndarray: '{type(subary).__name__}'") - return _from_numpy(ary) + return rec_map_array_container(_from_numpy_with_check, ary) def to_numpy(ary: Any, actx: ArrayContext) -> Any: @@ -652,7 +659,17 @@ def to_numpy(ary: Any, actx: ArrayContext) -> Any: The conversion is done using :meth:`arraycontext.ArrayContext.to_numpy`. """ - return rec_map_array_container(actx.to_numpy, ary) + def _to_numpy_with_check(subary: Any) -> Any: + if np.isscalar(subary): + return subary + elif isinstance(subary, actx.array_types): + return actx.to_numpy(subary) + else: + raise TypeError( + f"array of type '{type(subary).__name__}' not in " + f"supported types {actx.array_types}") + + return rec_map_array_container(_to_numpy_with_check, ary) # }}} diff --git a/arraycontext/fake_numpy.py b/arraycontext/fake_numpy.py index 4e26880..da72b7d 100644 --- a/arraycontext/fake_numpy.py +++ b/arraycontext/fake_numpy.py @@ -24,7 +24,7 @@ THE SOFTWARE. import numpy as np -from arraycontext.container import is_array_container_type, serialize_container +from arraycontext.container import serialize_container, deserialize_container from arraycontext.container.traversal import ( rec_map_array_container, multimapped_over_array_containers) from pytools import memoize_in @@ -176,20 +176,28 @@ class BaseFakeNumpyNamespace: def _new_like(self, ary, alloc_like): from numbers import Number + if isinstance(ary, Number): + # NOTE: `np.zeros_like(x)` returns `array(x, shape=())`, which + # is best implemented by concrete array contexts, if at all + raise NotImplementedError("operation not implemented for scalars") if isinstance(ary, np.ndarray) and ary.dtype.char == "O": # NOTE: we don't want to match numpy semantics on object arrays, # e.g. `np.zeros_like(x)` returns `array([0, 0, ...], dtype=object)` # FIXME: what about object arrays nested in an ArrayContainer? raise NotImplementedError("operation not implemented for object arrays") - elif is_array_container_type(ary.__class__): - return rec_map_array_container(alloc_like, ary) - elif isinstance(ary, Number): - # NOTE: `np.zeros_like(x)` returns `array(x, shape=())`, which - # is best implemented by concrete array contexts, if at all - raise NotImplementedError("operation not implemented for scalars") - else: - return alloc_like(ary) + + def _new_like_container(_ary): + try: + iterable = serialize_container(_ary) + except TypeError: + return alloc_like(_ary) + else: + return deserialize_container(_ary, [ + (key, alloc_like(subary)) for key, subary in iterable + ]) + + return _new_like_container(ary) def empty_like(self, ary): return self._new_like(ary, self._array_context.empty_like) @@ -258,10 +266,13 @@ class BaseFakeNumpyLinalgNamespace: return flat_norm(ary, ord=ord) - if is_array_container_type(ary.__class__): + try: + iterable = serialize_container(ary) + except TypeError: + pass + else: return _reduce_norm(actx, [ - self.norm(subary, ord=ord) - for _, subary in serialize_container(ary) + self.norm(subary, ord=ord) for _, subary in iterable ], ord=ord) if ord is None: diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py index c8c1715..83c1b43 100644 --- a/arraycontext/impl/pyopencl/fake_numpy.py +++ b/arraycontext/impl/pyopencl/fake_numpy.py @@ -29,13 +29,14 @@ THE SOFTWARE. from functools import partial, reduce import operator +import numpy as np + from arraycontext.fake_numpy import \ BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace -from arraycontext.container import is_array_container_type +from arraycontext.container import serialize_container from arraycontext.container.traversal import ( rec_map_array_container, rec_multimap_array_container, - multimap_reduce_array_container, rec_map_reduce_array_container, rec_multimap_reduce_array_container, ) @@ -239,32 +240,33 @@ class PyOpenCLFakeNumpyNamespace(BaseFakeNumpyNamespace): return result def array_equal(self, a, b): - def as_device_scalar(bool_value): - import numpy as np - return self._array_context.from_numpy( - np.array(int(bool_value), dtype=np.int8)) + actx = self._array_context + queue = actx.queue + + # NOTE: pyopencl doesn't like `bool` much, so use `int8` instead + false = actx.from_numpy(np.int8(False)) - # Do recursion separately from device-to-host conversion (below) so that - # we don't pass host booleans to cl_array.minimum def rec_equal(x, y): if type(x) != type(y): - return as_device_scalar(False) - elif not is_array_container_type(x.__class__): + return false + + try: + iterable = zip(serialize_container(x), serialize_container(y)) + except TypeError: if x.shape != y.shape: - return as_device_scalar(False) + return false else: return (x == y).all() else: - queue = self._array_context.queue - reduce_func = partial(reduce, partial(cl_array.minimum, queue=queue)) - map_func = rec_equal - return multimap_reduce_array_container( - reduce_func, map_func, x, y) + return reduce( + partial(cl_array.minimum, queue=queue), + [rec_equal(ix, iy)for (_, ix), (_, iy) in iterable] + ) result = rec_equal(a, b) - if not self._array_context._force_device_scalars: result = result.get()[()] + return result # }}} diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index dacf727..389ea07 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -23,14 +23,15 @@ THE SOFTWARE. """ from functools import partial, reduce +import numpy as np + from arraycontext.fake_numpy import ( BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace, ) -from arraycontext.container import is_array_container_type +from arraycontext.container import serialize_container from arraycontext.container.traversal import ( rec_map_array_container, rec_multimap_array_container, - multimap_reduce_array_container, rec_map_reduce_array_container, ) import pytato as pt @@ -174,20 +175,28 @@ class PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace): lambda subary: pt.all(subary), a) def array_equal(self, a, b): - def as_device_scalar(bool_value): - import numpy as np - return self._array_context.from_numpy( - np.array(int(bool_value), dtype=np.int8)) - - if type(a) != type(b): - return as_device_scalar(False) - elif not is_array_container_type(a.__class__): - if a.shape != b.shape: - return as_device_scalar(False) + actx = self._array_context + + # NOTE: not all backends support `bool` properly, so use `int8` instead + false = actx.from_numpy(np.int8(False)) + + def rec_equal(x, y): + if type(x) != type(y): + return false + + try: + iterable = zip(serialize_container(x), serialize_container(y)) + except TypeError: + if x.shape != y.shape: + return false + else: + return pt.all(pt.equal(x, y)) else: - return pt.all(pt.equal(a, b)) - else: - return multimap_reduce_array_container( - partial(reduce, pt.logical_and), self.array_equal, a, b) + return reduce( + pt.logical_and, + [rec_equal(ix, iy) for (_, ix), (_, iy) in iterable] + ) + + return rec_equal(a, b) # }}} diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 07b4c37..632c95b 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -1243,6 +1243,31 @@ def test_outer(actx_factory): # }}} +# {{{ test_array_container_with_numpy + +@with_container_arithmetic(bcast_obj_array=True, rel_comparison=True) +@dataclass_array_container +@dataclass(frozen=True) +class ArrayContainerWithNumpy: + u: np.ndarray + v: DOFArray + + +def test_array_container_with_numpy(actx_factory): + actx = actx_factory() + + mystate = ArrayContainerWithNumpy( + u=np.zeros(10), + v=DOFArray(actx, (actx.from_numpy(np.zeros(42)),)), + ) + + from arraycontext import rec_map_array_container + rec_map_array_container(lambda x: x, mystate) + + +# }}} + + if __name__ == "__main__": import sys if len(sys.argv) > 1: -- GitLab