Skip to content
Snippets Groups Projects
Commit ff1cd0cf authored by Alexandru Fikl's avatar Alexandru Fikl Committed by Andreas Klöckner
Browse files

rearrange jax.fake_numpy to match other contexts

parent be1429c2
No related branches found
No related tags found
No related merge requests found
Pipeline #309657 passed
......@@ -50,40 +50,81 @@ class EagerJAXFakeNumpyNamespace(BaseFakeNumpyNamespace):
def __getattr__(self, name):
return partial(rec_multimap_array_container, getattr(jnp, name))
# NOTE: the order of these follows the order in numpy docs
# NOTE: when adding a function here, also add it to `array_context.rst` docs!
# {{{ array creation routines
def ones_like(self, ary):
return self.full_like(ary, 1)
def full_like(self, ary, fill_value):
def _full_like(subary):
return jnp.full_like(ary, fill_value)
return self._new_like(ary, _full_like)
# }}}
# {{{ array manipulation routies
def reshape(self, a, newshape, order="C"):
return rec_map_array_container(
lambda ary: jnp.reshape(ary, newshape, order=order),
a)
def transpose(self, a, axes=None):
return rec_multimap_array_container(jnp.transpose, a, axes)
def ravel(self, a, order="C"):
"""
.. warning::
def concatenate(self, arrays, axis=0):
return rec_multimap_array_container(jnp.concatenate, arrays, axis)
Since :func:`jax.numpy.reshape` does not support orders `A`` and
``K``, in such cases we fallback to using ``order = C``.
"""
if order in "AK":
from warnings import warn
warn(f"ravel with order='{order}' not supported by JAX,"
" using order=C.")
order = "C"
def where(self, criterion, then, else_):
return rec_multimap_array_container(jnp.where, criterion, then, else_)
return rec_map_array_container(
lambda subary: jnp.ravel(subary, order=order), a)
def sum(self, a, axis=None, dtype=None):
return rec_map_reduce_array_container(sum,
partial(jnp.sum,
axis=axis,
dtype=dtype),
a)
def transpose(self, a, axes=None):
return rec_multimap_array_container(jnp.transpose, a, axes)
def min(self, a, axis=None):
return rec_map_reduce_array_container(
partial(reduce, jnp.minimum), partial(jnp.amin, axis=axis), a)
def broadcast_to(self, array, shape):
return rec_map_array_container(partial(jnp.broadcast_to, shape=shape), array)
def max(self, a, axis=None):
return rec_map_reduce_array_container(
partial(reduce, jnp.maximum), partial(jnp.amax, axis=axis), a)
def concatenate(self, arrays, axis=0):
return rec_multimap_array_container(jnp.concatenate, arrays, axis)
def stack(self, arrays, axis=0):
return rec_multimap_array_container(
lambda *args: jnp.stack(arrays=args, axis=axis),
*arrays)
# }}}
# {{{ linear algebra
def vdot(self, x, y, dtype=None):
from arraycontext import rec_multimap_reduce_array_container
def _rec_vdot(ary1, ary2):
if dtype not in [None, numpy.find_common_type((ary1.dtype,
ary2.dtype),
())]:
raise NotImplementedError(f"{type(self)} cannot take dtype in"
" vdot.")
return jnp.vdot(ary1, ary2)
return rec_multimap_reduce_array_container(sum, _rec_vdot, x, y)
# }}}
# {{{ logic functions
def array_equal(self, a, b):
actx = self._array_context
......@@ -109,35 +150,33 @@ class EagerJAXFakeNumpyNamespace(BaseFakeNumpyNamespace):
return rec_equal(a, b)
def ravel(self, a, order="C"):
"""
.. warning::
# }}}
Since :func:`jax.numpy.reshape` does not support orders `A`` and
``K``, in such cases we fallback to using ``order = C``.
"""
if order in "AK":
from warnings import warn
warn(f"ravel with order='{order}' not supported by JAX,"
" using order=C.")
order = "C"
# {{{ mathematical functions
def sum(self, a, axis=None, dtype=None):
return rec_map_reduce_array_container(
sum,
partial(jnp.sum, axis=axis, dtype=dtype),
a)
return rec_map_array_container(lambda subary: jnp.ravel(subary, order=order),
a)
def amin(self, a, axis=None):
return rec_map_reduce_array_container(
partial(reduce, jnp.minimum), partial(jnp.amin, axis=axis), a)
def vdot(self, x, y, dtype=None):
from arraycontext import rec_multimap_reduce_array_container
min = amin
def _rec_vdot(ary1, ary2):
if dtype not in [None, numpy.find_common_type((ary1.dtype,
ary2.dtype),
())]:
raise NotImplementedError(f"{type(self)} cannot take dtype in"
" vdot.")
def amax(self, a, axis=None):
return rec_map_reduce_array_container(
partial(reduce, jnp.maximum), partial(jnp.amax, axis=axis), a)
return jnp.vdot(ary1, ary2)
max = amax
return rec_multimap_reduce_array_container(sum, _rec_vdot, x, y)
# }}}
def broadcast_to(self, array, shape):
return rec_map_array_container(partial(jnp.broadcast_to, shape=shape), array)
# {{{ sorting, searching and counting
def where(self, criterion, then, else_):
return rec_multimap_array_container(jnp.where, criterion, then, else_)
# }}}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment