jax: change DeviceArray to jnp.ndarray (#211)
* limit JAX version due to CI errors * rename DeviceArray -> Array * import as JAXArray * Revert "limit JAX version due to CI errors" This reverts commit 9b89d14c3ee8eee1f5ded8b9fc8eee31f4c3fb85. * change to jax.numpy.ndarray