From 6eaf518648ccb6a393d6a1c499b1a1c24955d155 Mon Sep 17 00:00:00 2001
From: Matt Smith <mjsmith6@illinois.edu>
Date: Wed, 20 Oct 2021 15:13:49 -0500
Subject: [PATCH] Add alternate outer product (#46)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Make broadcast settings visible on users of with_container_arithmetic

* Remove stray debug print

Co-authored-by: Alex Fikl <alexfikl@gmail.com>

* Fix _outer_bcast_types attribute name

* add alternate outer product

* add empty line

Co-authored-by: Andreas Klöckner <inform@tiker.net>

* add comment explaining _outer_bcast_types

Co-authored-by: Andreas Klöckner <inform@tiker.net>

* add more detail to docstring first line

Co-authored-by: Andreas Klöckner <inform@tiker.net>

* rename is_scalar -> treat_as_scalar

* treat non-object numpy arrays as scalars in outer

* remove use of deprecated is_array_container

* clarify usage of isinstance(..., ndarray) for object array detection

Co-authored-by: Andreas Klöckner <inform@tiker.net>

Co-authored-by: Andreas Kloeckner <inform@tiker.net>
Co-authored-by: Alex Fikl <alexfikl@gmail.com>
---
 arraycontext/__init__.py             |  4 +-
 arraycontext/container/arithmetic.py |  5 ++
 arraycontext/container/traversal.py  | 50 ++++++++++++++++++
 test/test_arraycontext.py            | 77 ++++++++++++++++++++++++++++
 4 files changed, 135 insertions(+), 1 deletion(-)

diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py
index c0bca77..0147e62 100644
--- a/arraycontext/__init__.py
+++ b/arraycontext/__init__.py
@@ -59,7 +59,8 @@ from .container.traversal import (
         rec_multimap_reduce_array_container,
         thaw, freeze,
         flatten, unflatten,
-        from_numpy, to_numpy)
+        from_numpy, to_numpy,
+        outer)
 
 from .impl.pyopencl import PyOpenCLArrayContext
 from .impl.pytato import PytatoPyOpenCLArrayContext
@@ -95,6 +96,7 @@ __all__ = (
         "thaw", "freeze",
         "flatten", "unflatten",
         "from_numpy", "to_numpy",
+        "outer",
 
         "PyOpenCLArrayContext", "PytatoPyOpenCLArrayContext",
 
diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py
index 3ade1b3..df5b662 100644
--- a/arraycontext/container/arithmetic.py
+++ b/arraycontext/container/arithmetic.py
@@ -303,6 +303,11 @@ def with_container_arithmetic(
             else:
                 return "(%s,)" % ", ".join(t)
 
+        gen(f"cls._outer_bcast_types = {tup_str(outer_bcast_type_names)}")
+        gen(f"cls._bcast_numpy_array = {bcast_numpy_array}")
+        gen(f"cls._bcast_obj_array = {bcast_obj_array}")
+        gen("")
+
         # {{{ unary operators
 
         for dunder_name, op_str, op_cls in _UNARY_OP_AND_DUNDER:
diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py
index 8d0f9f3..85d665f 100644
--- a/arraycontext/container/traversal.py
+++ b/arraycontext/container/traversal.py
@@ -32,6 +32,10 @@ Numpy conversion
 ~~~~~~~~~~~~~~~~
 .. autofunction:: from_numpy
 .. autofunction:: to_numpy
+
+Algebraic operations
+~~~~~~~~~~~~~~~~~~~~
+.. autofunction:: outer
 """
 
 __copyright__ = """
@@ -652,4 +656,50 @@ def to_numpy(ary: Any, actx: ArrayContext) -> Any:
 
 # }}}
 
+
+# {{{ algebraic operations
+
+def outer(a: Any, b: Any) -> Any:
+    """
+    Compute the outer product of *a* and *b* while allowing either of them
+    to be an :class:`ArrayContainer`.
+
+    Tweaks the behavior of :func:`numpy.outer` to return a lower-dimensional
+    object if either/both of *a* and *b* are scalars (whereas :func:`numpy.outer`
+    always returns a matrix). Here the definition of "scalar" includes
+    all non-array-container types and any scalar-like array container types
+    (including non-object numpy arrays).
+
+    If *a* and *b* are both array containers, the result will have the same type
+    as *a*. If both are array containers and neither is an object array, they must
+    have the same type.
+    """
+
+    def treat_as_scalar(x: Any) -> bool:
+        try:
+            serialize_container(x)
+        except TypeError:
+            return True
+        else:
+            return (
+                not isinstance(x, np.ndarray)
+                # This condition is whether "ndarrays should broadcast inside x".
+                and np.ndarray not in x.__class__._outer_bcast_types)
+
+    if treat_as_scalar(a) or treat_as_scalar(b):
+        return a*b
+    # After this point, "isinstance(o, ndarray)" means o is an object array.
+    elif isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
+        return np.outer(a, b)
+    elif isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
+        return map_array_container(lambda x: outer(x, b), a)
+    else:
+        if type(a) != type(b):
+            raise TypeError(
+                "both arguments must have the same type if they are both "
+                "non-object-array array containers.")
+        return multimap_array_container(lambda x, y: outer(x, y), a, b)
+
+# }}}
+
 # vim: foldmethod=marker
diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index 0dcb1a6..07b4c37 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -1166,6 +1166,83 @@ def test_leaf_array_type_broadcasting(actx_factory):
 # }}}
 
 
+# {{{ test outer product
+
+def test_outer(actx_factory):
+    actx = actx_factory()
+
+    a_dof, a_ary_of_dofs, _, _, a_bcast_dc_of_dofs = _get_test_containers(actx)
+
+    b_dof = a_dof + 1
+    b_ary_of_dofs = a_ary_of_dofs + 1
+    b_bcast_dc_of_dofs = a_bcast_dc_of_dofs + 1
+
+    from arraycontext import outer
+
+    def equal(a, b):
+        return actx.to_numpy(actx.np.array_equal(a, b))
+
+    # Two scalars
+    assert equal(outer(a_dof, b_dof), a_dof*b_dof)
+
+    # Scalar and vector
+    assert equal(outer(a_dof, b_ary_of_dofs), a_dof*b_ary_of_dofs)
+
+    # Vector and scalar
+    assert equal(outer(a_ary_of_dofs, b_dof), a_ary_of_dofs*b_dof)
+
+    # Two vectors
+    assert equal(
+        outer(a_ary_of_dofs, b_ary_of_dofs),
+        np.outer(a_ary_of_dofs, b_ary_of_dofs))
+
+    # Scalar and array container
+    assert equal(
+        outer(a_dof, b_bcast_dc_of_dofs),
+        a_dof*b_bcast_dc_of_dofs)
+
+    # Array container and scalar
+    assert equal(
+        outer(a_bcast_dc_of_dofs, b_dof),
+        a_bcast_dc_of_dofs*b_dof)
+
+    # Vector and array container
+    assert equal(
+        outer(a_ary_of_dofs, b_bcast_dc_of_dofs),
+        make_obj_array([a_i*b_bcast_dc_of_dofs for a_i in a_ary_of_dofs]))
+
+    # Array container and vector
+    assert equal(
+        outer(a_bcast_dc_of_dofs, b_ary_of_dofs),
+        MyContainerDOFBcast(
+            name="container",
+            mass=a_bcast_dc_of_dofs.mass*b_ary_of_dofs,
+            momentum=np.outer(a_bcast_dc_of_dofs.momentum, b_ary_of_dofs),
+            enthalpy=a_bcast_dc_of_dofs.enthalpy*b_ary_of_dofs))
+
+    # Two array containers
+    assert equal(
+        outer(a_bcast_dc_of_dofs, b_bcast_dc_of_dofs),
+        MyContainerDOFBcast(
+            name="container",
+            mass=a_bcast_dc_of_dofs.mass*b_bcast_dc_of_dofs.mass,
+            momentum=np.outer(
+                a_bcast_dc_of_dofs.momentum,
+                b_bcast_dc_of_dofs.momentum),
+            enthalpy=a_bcast_dc_of_dofs.enthalpy*b_bcast_dc_of_dofs.enthalpy))
+
+    # Non-object numpy arrays should be treated as scalars
+    ary_of_floats = np.ones(len(b_bcast_dc_of_dofs.mass))
+    assert equal(
+        outer(ary_of_floats, b_bcast_dc_of_dofs),
+        ary_of_floats*b_bcast_dc_of_dofs)
+    assert equal(
+        outer(a_bcast_dc_of_dofs, ary_of_floats),
+        a_bcast_dc_of_dofs*ary_of_floats)
+
+# }}}
+
+
 if __name__ == "__main__":
     import sys
     if len(sys.argv) > 1:
-- 
GitLab