Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
A
arraycontext
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Andreas Klöckner
arraycontext
Commits
ff1cd0cf
Commit
ff1cd0cf
authored
2 years ago
by
Alexandru Fikl
Committed by
Andreas Klöckner
2 years ago
Browse files
Options
Downloads
Patches
Plain Diff
rearrange jax.fake_numpy to match other contexts
parent
be1429c2
No related branches found
No related tags found
No related merge requests found
Pipeline
#309657
passed
2 years ago
Stage: test
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
arraycontext/impl/jax/fake_numpy.py
+82
-43
82 additions, 43 deletions
arraycontext/impl/jax/fake_numpy.py
with
82 additions
and
43 deletions
arraycontext/impl/jax/fake_numpy.py
+
82
−
43
View file @
ff1cd0cf
...
...
@@ -50,40 +50,81 @@ class EagerJAXFakeNumpyNamespace(BaseFakeNumpyNamespace):
def
__getattr__
(
self
,
name
):
return
partial
(
rec_multimap_array_container
,
getattr
(
jnp
,
name
))
# NOTE: the order of these follows the order in numpy docs
# NOTE: when adding a function here, also add it to `array_context.rst` docs!
# {{{ array creation routines
def
ones_like
(
self
,
ary
):
return
self
.
full_like
(
ary
,
1
)
def
full_like
(
self
,
ary
,
fill_value
):
def
_full_like
(
subary
):
return
jnp
.
full_like
(
ary
,
fill_value
)
return
self
.
_new_like
(
ary
,
_full_like
)
# }}}
# {{{ array manipulation routies
def
reshape
(
self
,
a
,
newshape
,
order
=
"
C
"
):
return
rec_map_array_container
(
lambda
ary
:
jnp
.
reshape
(
ary
,
newshape
,
order
=
order
),
a
)
def
transpose
(
self
,
a
,
axes
=
None
):
return
rec_multimap_array_container
(
jnp
.
transpose
,
a
,
axes
)
def
ravel
(
self
,
a
,
order
=
"
C
"
):
"""
.. warning::
def
concatenate
(
self
,
arrays
,
axis
=
0
):
return
rec_multimap_array_container
(
jnp
.
concatenate
,
arrays
,
axis
)
Since :func:`jax.numpy.reshape` does not support orders `A`` and
``K``, in such cases we fallback to using ``order = C``.
"""
if
order
in
"
AK
"
:
from
warnings
import
warn
warn
(
f
"
ravel with order=
'
{
order
}
'
not supported by JAX,
"
"
using order=C.
"
)
order
=
"
C
"
def
where
(
self
,
criterion
,
then
,
else_
):
return
rec_multimap_array_container
(
jnp
.
where
,
criterion
,
then
,
else_
)
return
rec_map_array_container
(
lambda
subary
:
jnp
.
ravel
(
subary
,
order
=
order
),
a
)
def
sum
(
self
,
a
,
axis
=
None
,
dtype
=
None
):
return
rec_map_reduce_array_container
(
sum
,
partial
(
jnp
.
sum
,
axis
=
axis
,
dtype
=
dtype
),
a
)
def
transpose
(
self
,
a
,
axes
=
None
):
return
rec_multimap_array_container
(
jnp
.
transpose
,
a
,
axes
)
def
min
(
self
,
a
,
axis
=
None
):
return
rec_map_reduce_array_container
(
partial
(
reduce
,
jnp
.
minimum
),
partial
(
jnp
.
amin
,
axis
=
axis
),
a
)
def
broadcast_to
(
self
,
array
,
shape
):
return
rec_map_array_container
(
partial
(
jnp
.
broadcast_to
,
shape
=
shape
),
array
)
def
max
(
self
,
a
,
axis
=
None
):
return
rec_map_reduce_array_container
(
partial
(
reduce
,
jnp
.
maximum
),
partial
(
jnp
.
amax
,
axis
=
axis
),
a
)
def
concatenate
(
self
,
arrays
,
axis
=
0
):
return
rec_multimap_array_container
(
jnp
.
concatenate
,
arrays
,
axis
)
def
stack
(
self
,
arrays
,
axis
=
0
):
return
rec_multimap_array_container
(
lambda
*
args
:
jnp
.
stack
(
arrays
=
args
,
axis
=
axis
),
*
arrays
)
# }}}
# {{{ linear algebra
def
vdot
(
self
,
x
,
y
,
dtype
=
None
):
from
arraycontext
import
rec_multimap_reduce_array_container
def
_rec_vdot
(
ary1
,
ary2
):
if
dtype
not
in
[
None
,
numpy
.
find_common_type
((
ary1
.
dtype
,
ary2
.
dtype
),
())]:
raise
NotImplementedError
(
f
"
{
type
(
self
)
}
cannot take dtype in
"
"
vdot.
"
)
return
jnp
.
vdot
(
ary1
,
ary2
)
return
rec_multimap_reduce_array_container
(
sum
,
_rec_vdot
,
x
,
y
)
# }}}
# {{{ logic functions
def
array_equal
(
self
,
a
,
b
):
actx
=
self
.
_array_context
...
...
@@ -109,35 +150,33 @@ class EagerJAXFakeNumpyNamespace(BaseFakeNumpyNamespace):
return
rec_equal
(
a
,
b
)
def
ravel
(
self
,
a
,
order
=
"
C
"
):
"""
.. warning::
# }}}
Since :func:`jax.numpy.reshape` does not support orders `A`` and
``K``, in such cases we fallback to using ``order = C``.
"""
if
order
in
"
AK
"
:
from
warnings
import
warn
warn
(
f
"
ravel with order=
'
{
order
}
'
not supported by JAX,
"
"
using order=C.
"
)
order
=
"
C
"
# {{{ mathematical functions
def
sum
(
self
,
a
,
axis
=
None
,
dtype
=
None
):
return
rec_map_reduce_array_container
(
sum
,
partial
(
jnp
.
sum
,
axis
=
axis
,
dtype
=
dtype
),
a
)
return
rec_map_array_container
(
lambda
subary
:
jnp
.
ravel
(
subary
,
order
=
order
),
a
)
def
amin
(
self
,
a
,
axis
=
None
):
return
rec_map_reduce_array_container
(
partial
(
reduce
,
jnp
.
minimum
),
partial
(
jnp
.
amin
,
axis
=
axis
),
a
)
def
vdot
(
self
,
x
,
y
,
dtype
=
None
):
from
arraycontext
import
rec_multimap_reduce_array_container
min
=
amin
def
_rec_vdot
(
ary1
,
ary2
):
if
dtype
not
in
[
None
,
numpy
.
find_common_type
((
ary1
.
dtype
,
ary2
.
dtype
),
())]:
raise
NotImplementedError
(
f
"
{
type
(
self
)
}
cannot take dtype in
"
"
vdot.
"
)
def
amax
(
self
,
a
,
axis
=
None
):
return
rec_map_reduce_array_container
(
partial
(
reduce
,
jnp
.
maximum
),
partial
(
jnp
.
amax
,
axis
=
axis
),
a
)
return
jnp
.
vdot
(
ary1
,
ary2
)
max
=
amax
return
rec_multimap_reduce_array_container
(
sum
,
_rec_vdot
,
x
,
y
)
# }}}
def
broadcast_to
(
self
,
array
,
shape
):
return
rec_map_array_container
(
partial
(
jnp
.
broadcast_to
,
shape
=
shape
),
array
)
# {{{ sorting, searching and counting
def
where
(
self
,
criterion
,
then
,
else_
):
return
rec_multimap_array_container
(
jnp
.
where
,
criterion
,
then
,
else_
)
# }}}
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment