Skip to content
Snippets Groups Projects
Unverified Commit e5ff3c6e authored by Andreas Klöckner's avatar Andreas Klöckner Committed by GitHub
Browse files

Implement, test container-over-container broadcasting (#15)


* Implement, test container-over-container broadcasting

* Tweak bcast_container_types arg doc

Co-authored-by: default avatarAlex Fikl <alexfikl@gmail.com>

* Improvements to with_container_arithmetic

- Support matmul with a separate operator class
- Join container and number broadcasting code paths
- Make args keyword-only

Co-authored-by: default avatarAlex Fikl <alexfikl@gmail.com>
parent 4e81f3a7
No related branches found
No related tags found
No related merge requests found
......@@ -30,10 +30,14 @@ THE SOFTWARE.
"""
import numpy as np
# {{{ with_container_arithmetic
class _OpClass(enum.Enum):
ARITHMETIC = enum.auto
MATMUL = enum.auto
BITWISE = enum.auto
SHIFT = enum.auto
EQ_COMPARISON = enum.auto
......@@ -56,6 +60,8 @@ _BINARY_OP_AND_DUNDER = [
("mod", "{} % {}", True, _OpClass.ARITHMETIC),
("divmod", "divmod({}, {})", True, _OpClass.ARITHMETIC),
("matmul", "{} @ {}", True, _OpClass.MATMUL),
("and", "{} & {}", True, _OpClass.BITWISE),
("or", "{} | {}", True, _OpClass.BITWISE),
("xor", "{} ^ {}", True, _OpClass.BITWISE),
......@@ -109,9 +115,10 @@ def _format_binary_op_str(op_str, arg1, arg2):
return op_str.format(arg1, arg2)
def with_container_arithmetic(
def with_container_arithmetic(*,
bcast_number=True, bcast_obj_array=None, bcast_numpy_array=False,
arithmetic=True, bitwise=False, shift=False,
bcast_container_types=None,
arithmetic=True, matmul=False, bitwise=False, shift=False,
eq_comparison=None, rel_comparison=None):
"""A class decorator that implements built-in operators for array containers
by propagating the operations to the elements of the container.
......@@ -122,6 +129,13 @@ def with_container_arithmetic(
the container. (with the container as the 'inner' structure)
:arg bcast_numpy_array: If *True*, any :class:`numpy.ndarray` will broadcast
over the container. (with the container as the 'inner' structure)
If this is set to *True*, *bcast_obj_array* must also be *True*.
:arg bcast_container_types: A sequence of container types that will broadcast
over this container (with this container as the 'outer' structure).
:class:`numpy.ndarray` is permitted to be part of this sequence to
indicate that, in such broadcasting situations, this container should
be the 'outer' structure. In this case, *bcast_obj_array*
(and consequently *bcast_numpy_array*) must be *False*.
:arg arithmetic: Implement the conventional arithmetic operators, including
``**``, :func:`divmod`, and ``//``. Also includes ``+`` and ``-`` as well as
:func:`abs`.
......@@ -183,9 +197,18 @@ def with_container_arithmetic(
def numpy_pred(name):
return "False" # optimized away
if bcast_container_types is None:
bcast_container_types = ()
if np.ndarray in bcast_container_types and bcast_obj_array:
raise ValueError("If numpy.ndarray is part of bcast_container_types, "
"bcast_obj_array must be False.")
desired_op_classes = set()
if arithmetic:
desired_op_classes.add(_OpClass.ARITHMETIC)
if matmul:
desired_op_classes.add(_OpClass.MATMUL)
if bitwise:
desired_op_classes.add(_OpClass.BITWISE)
if shift:
......@@ -215,10 +238,25 @@ def with_container_arithmetic(
""")
gen("")
if bcast_container_types:
for i, bct in enumerate(bcast_container_types):
gen(f"from {bct.__module__} import {bct.__qualname__} as _bctype{i}")
gen("")
outer_bcast_type_names = [
f"_bctype{i}" for i in range(len(bcast_container_types))]
if bcast_number:
outer_bcast_type_names.append("Number")
def same_key(k1, k2):
assert k1 == k2
return k1
def tup_str(t):
if not t:
return "()"
else:
return "(%s,)" % ", ".join(t)
# {{{ unary operators
for dunder_name, op_str, op_cls in _UNARY_OP_AND_DUNDER:
......@@ -266,10 +304,10 @@ def with_container_arithmetic(
def {fname}(arg1, arg2):
if arg2.__class__ is cls:
return cls({zip_init_args})
if {bcast_number}: # optimized away
if isinstance(arg2, Number):
if {bool(outer_bcast_type_names)}: # optimized away
if isinstance(arg2, {tup_str(outer_bcast_type_names)}):
return cls({bcast_init_args})
if {numpy_pred("arg2")}:
if {numpy_pred("arg2")}: # optimized away
result = np.empty_like(arg2, dtype=object)
for i in np.ndindex(arg2.shape):
result[i] = {op_str.format("arg1", "arg2[i]")}
......@@ -294,10 +332,10 @@ def with_container_arithmetic(
def {fname}(arg2, arg1):
# assert other.__cls__ is not cls
if {bcast_number}: # optimized away
if isinstance(arg1, Number):
if {bool(outer_bcast_type_names)}: # optimized away
if isinstance(arg1, {tup_str(outer_bcast_type_names)}):
return cls({bcast_init_args})
if {numpy_pred("arg1")}:
if {numpy_pred("arg1")}: # optimized away
result = np.empty_like(arg1, dtype=object)
for i in np.ndindex(arg1.shape):
result[i] = {op_str.format("arg1[i]", "arg2")}
......
......@@ -461,6 +461,24 @@ class MyContainer:
return self.mass.array_context
@with_container_arithmetic(
bcast_obj_array=False,
bcast_container_types=(DOFArray, np.ndarray),
matmul=True,
rel_comparison=True,)
@dataclass_array_container
@dataclass(frozen=True)
class MyContainerDOFBcast:
name: str
mass: DOFArray
momentum: np.ndarray
enthalpy: DOFArray
@property
def array_context(self):
return self.mass.array_context
def _get_test_containers(actx, ambient_dim=2):
x = DOFArray(actx, (actx.from_numpy(np.random.randn(50_000)),))
......@@ -471,18 +489,27 @@ def _get_test_containers(actx, ambient_dim=2):
momentum=make_obj_array([x, x]),
enthalpy=x)
# pylint: disable=unexpected-keyword-arg, no-value-for-parameter
bcast_dataclass_of_dofs = MyContainerDOFBcast(
name="container",
mass=x,
momentum=make_obj_array([x, x]),
enthalpy=x)
ary_dof = x
ary_of_dofs = make_obj_array([x, x, x])
mat_of_dofs = np.empty((2, 2), dtype=object)
mat_of_dofs = np.empty((3, 3), dtype=object)
for i in np.ndindex(mat_of_dofs.shape):
mat_of_dofs[i] = x
return ary_dof, ary_of_dofs, mat_of_dofs, dataclass_of_dofs
return (ary_dof, ary_of_dofs, mat_of_dofs, dataclass_of_dofs,
bcast_dataclass_of_dofs)
def test_container_multimap(actx_factory):
actx = actx_factory()
ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs = _get_test_containers(actx)
ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs, bcast_dc_of_dofs = \
_get_test_containers(actx)
# {{{ check
......@@ -522,7 +549,8 @@ def test_container_multimap(actx_factory):
def test_container_arithmetic(actx_factory):
actx = actx_factory()
ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs = _get_test_containers(actx)
ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs, bcast_dc_of_dofs = \
_get_test_containers(actx)
# {{{ check
......@@ -542,12 +570,38 @@ def test_container_arithmetic(actx_factory):
with pytest.raises(TypeError):
ary_of_dofs + dc_of_dofs
with pytest.raises(TypeError):
dc_of_dofs + ary_of_dofs
with pytest.raises(TypeError):
ary_dof + dc_of_dofs
with pytest.raises(TypeError):
dc_of_dofs + ary_dof
bcast_result = ary_dof + bcast_dc_of_dofs
bcast_dc_of_dofs + ary_dof
assert actx.np.linalg.norm(bcast_result.mass - 2*ary_of_dofs) < 1e-8
mock_gradient = MyContainerDOFBcast(
name="yo",
mass=ary_of_dofs,
momentum=mat_of_dofs,
enthalpy=ary_of_dofs)
grad_matvec_result = mock_gradient @ ary_of_dofs
assert isinstance(grad_matvec_result.mass, DOFArray)
assert grad_matvec_result.momentum.shape == (3,)
assert actx.np.linalg.norm(grad_matvec_result.mass - 3*ary_of_dofs**2) < 1e-8
# }}}
def test_container_freeze_thaw(actx_factory):
actx = actx_factory()
ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs = _get_test_containers(actx)
ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs, bcast_dc_of_dofs = \
_get_test_containers(actx)
# {{{ check
......@@ -577,7 +631,8 @@ def test_container_freeze_thaw(actx_factory):
def test_container_norm(actx_factory, ord):
actx = actx_factory()
ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs = _get_test_containers(actx)
ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs, bcast_dc_of_dofs = \
_get_test_containers(actx)
from pytools.obj_array import make_obj_array
c = MyContainer(name="hey", mass=1, momentum=make_obj_array([2, 3]), enthalpy=5)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment