From 6eaf518648ccb6a393d6a1c499b1a1c24955d155 Mon Sep 17 00:00:00 2001 From: Matt Smith <mjsmith6@illinois.edu> Date: Wed, 20 Oct 2021 15:13:49 -0500 Subject: [PATCH] Add alternate outer product (#46) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Make broadcast settings visible on users of with_container_arithmetic * Remove stray debug print Co-authored-by: Alex Fikl <alexfikl@gmail.com> * Fix _outer_bcast_types attribute name * add alternate outer product * add empty line Co-authored-by: Andreas Klöckner <inform@tiker.net> * add comment explaining _outer_bcast_types Co-authored-by: Andreas Klöckner <inform@tiker.net> * add more detail to docstring first line Co-authored-by: Andreas Klöckner <inform@tiker.net> * rename is_scalar -> treat_as_scalar * treat non-object numpy arrays as scalars in outer * remove use of deprecated is_array_container * clarify usage of isinstance(..., ndarray) for object array detection Co-authored-by: Andreas Klöckner <inform@tiker.net> Co-authored-by: Andreas Kloeckner <inform@tiker.net> Co-authored-by: Alex Fikl <alexfikl@gmail.com> --- arraycontext/__init__.py | 4 +- arraycontext/container/arithmetic.py | 5 ++ arraycontext/container/traversal.py | 50 ++++++++++++++++++ test/test_arraycontext.py | 77 ++++++++++++++++++++++++++++ 4 files changed, 135 insertions(+), 1 deletion(-) diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index c0bca77..0147e62 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -59,7 +59,8 @@ from .container.traversal import ( rec_multimap_reduce_array_container, thaw, freeze, flatten, unflatten, - from_numpy, to_numpy) + from_numpy, to_numpy, + outer) from .impl.pyopencl import PyOpenCLArrayContext from .impl.pytato import PytatoPyOpenCLArrayContext @@ -95,6 +96,7 @@ __all__ = ( "thaw", "freeze", "flatten", "unflatten", "from_numpy", "to_numpy", + "outer", "PyOpenCLArrayContext", "PytatoPyOpenCLArrayContext", diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index 3ade1b3..df5b662 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -303,6 +303,11 @@ def with_container_arithmetic( else: return "(%s,)" % ", ".join(t) + gen(f"cls._outer_bcast_types = {tup_str(outer_bcast_type_names)}") + gen(f"cls._bcast_numpy_array = {bcast_numpy_array}") + gen(f"cls._bcast_obj_array = {bcast_obj_array}") + gen("") + # {{{ unary operators for dunder_name, op_str, op_cls in _UNARY_OP_AND_DUNDER: diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 8d0f9f3..85d665f 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -32,6 +32,10 @@ Numpy conversion ~~~~~~~~~~~~~~~~ .. autofunction:: from_numpy .. autofunction:: to_numpy + +Algebraic operations +~~~~~~~~~~~~~~~~~~~~ +.. autofunction:: outer """ __copyright__ = """ @@ -652,4 +656,50 @@ def to_numpy(ary: Any, actx: ArrayContext) -> Any: # }}} + +# {{{ algebraic operations + +def outer(a: Any, b: Any) -> Any: + """ + Compute the outer product of *a* and *b* while allowing either of them + to be an :class:`ArrayContainer`. + + Tweaks the behavior of :func:`numpy.outer` to return a lower-dimensional + object if either/both of *a* and *b* are scalars (whereas :func:`numpy.outer` + always returns a matrix). Here the definition of "scalar" includes + all non-array-container types and any scalar-like array container types + (including non-object numpy arrays). + + If *a* and *b* are both array containers, the result will have the same type + as *a*. If both are array containers and neither is an object array, they must + have the same type. + """ + + def treat_as_scalar(x: Any) -> bool: + try: + serialize_container(x) + except TypeError: + return True + else: + return ( + not isinstance(x, np.ndarray) + # This condition is whether "ndarrays should broadcast inside x". + and np.ndarray not in x.__class__._outer_bcast_types) + + if treat_as_scalar(a) or treat_as_scalar(b): + return a*b + # After this point, "isinstance(o, ndarray)" means o is an object array. + elif isinstance(a, np.ndarray) and isinstance(b, np.ndarray): + return np.outer(a, b) + elif isinstance(a, np.ndarray) or isinstance(b, np.ndarray): + return map_array_container(lambda x: outer(x, b), a) + else: + if type(a) != type(b): + raise TypeError( + "both arguments must have the same type if they are both " + "non-object-array array containers.") + return multimap_array_container(lambda x, y: outer(x, y), a, b) + +# }}} + # vim: foldmethod=marker diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 0dcb1a6..07b4c37 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -1166,6 +1166,83 @@ def test_leaf_array_type_broadcasting(actx_factory): # }}} +# {{{ test outer product + +def test_outer(actx_factory): + actx = actx_factory() + + a_dof, a_ary_of_dofs, _, _, a_bcast_dc_of_dofs = _get_test_containers(actx) + + b_dof = a_dof + 1 + b_ary_of_dofs = a_ary_of_dofs + 1 + b_bcast_dc_of_dofs = a_bcast_dc_of_dofs + 1 + + from arraycontext import outer + + def equal(a, b): + return actx.to_numpy(actx.np.array_equal(a, b)) + + # Two scalars + assert equal(outer(a_dof, b_dof), a_dof*b_dof) + + # Scalar and vector + assert equal(outer(a_dof, b_ary_of_dofs), a_dof*b_ary_of_dofs) + + # Vector and scalar + assert equal(outer(a_ary_of_dofs, b_dof), a_ary_of_dofs*b_dof) + + # Two vectors + assert equal( + outer(a_ary_of_dofs, b_ary_of_dofs), + np.outer(a_ary_of_dofs, b_ary_of_dofs)) + + # Scalar and array container + assert equal( + outer(a_dof, b_bcast_dc_of_dofs), + a_dof*b_bcast_dc_of_dofs) + + # Array container and scalar + assert equal( + outer(a_bcast_dc_of_dofs, b_dof), + a_bcast_dc_of_dofs*b_dof) + + # Vector and array container + assert equal( + outer(a_ary_of_dofs, b_bcast_dc_of_dofs), + make_obj_array([a_i*b_bcast_dc_of_dofs for a_i in a_ary_of_dofs])) + + # Array container and vector + assert equal( + outer(a_bcast_dc_of_dofs, b_ary_of_dofs), + MyContainerDOFBcast( + name="container", + mass=a_bcast_dc_of_dofs.mass*b_ary_of_dofs, + momentum=np.outer(a_bcast_dc_of_dofs.momentum, b_ary_of_dofs), + enthalpy=a_bcast_dc_of_dofs.enthalpy*b_ary_of_dofs)) + + # Two array containers + assert equal( + outer(a_bcast_dc_of_dofs, b_bcast_dc_of_dofs), + MyContainerDOFBcast( + name="container", + mass=a_bcast_dc_of_dofs.mass*b_bcast_dc_of_dofs.mass, + momentum=np.outer( + a_bcast_dc_of_dofs.momentum, + b_bcast_dc_of_dofs.momentum), + enthalpy=a_bcast_dc_of_dofs.enthalpy*b_bcast_dc_of_dofs.enthalpy)) + + # Non-object numpy arrays should be treated as scalars + ary_of_floats = np.ones(len(b_bcast_dc_of_dofs.mass)) + assert equal( + outer(ary_of_floats, b_bcast_dc_of_dofs), + ary_of_floats*b_bcast_dc_of_dofs) + assert equal( + outer(a_bcast_dc_of_dofs, ary_of_floats), + a_bcast_dc_of_dofs*ary_of_floats) + +# }}} + + if __name__ == "__main__": import sys if len(sys.argv) > 1: -- GitLab