From 4c8b4314e5cb58312aaf8a84d73b46a647d7fbfe Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 15 Dec 2022 19:56:38 +0100 Subject: [PATCH] jax: change DeviceArray to jnp.ndarray (#211) * limit JAX version due to CI errors * rename DeviceArray -> Array * import as JAXArray * Revert "limit JAX version due to CI errors" This reverts commit 9b89d14c3ee8eee1f5ded8b9fc8eee31f4c3fb85. * change to jax.numpy.ndarray --- arraycontext/impl/jax/__init__.py | 4 ++-- arraycontext/impl/pytato/__init__.py | 34 +++++++++++++--------------- 2 files changed, 18 insertions(+), 20 deletions(-) diff --git a/arraycontext/impl/jax/__init__.py b/arraycontext/impl/jax/__init__.py index f4794e4..4aa30c2 100644 --- a/arraycontext/impl/jax/__init__.py +++ b/arraycontext/impl/jax/__init__.py @@ -54,8 +54,8 @@ class EagerJAXArrayContext(ArrayContext): def __init__(self) -> None: super().__init__() - from jax.numpy import DeviceArray - self.array_types = (DeviceArray, ) + import jax.numpy as jnp + self.array_types = (jnp.ndarray, ) def _get_fake_numpy_namespace(self): from .fake_numpy import EagerJAXFakeNumpyNamespace diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 8ccc768..c3e4462 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -686,14 +686,14 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext): unstable. """ import pytato as pt - from jax.numpy import DeviceArray + import jax.numpy as jnp super().__init__(compile_trace_callback=compile_trace_callback) - self.array_types = (pt.Array, DeviceArray) + self.array_types = (pt.Array, jnp.ndarray) @property def _frozen_array_types(self) -> Tuple[Type, ...]: - from jax.numpy import DeviceArray - return (DeviceArray, ) + import jax.numpy as jnp + return (jnp.ndarray, ) def _rec_map_container( self, func: Callable[[Array], Array], array: ArrayOrContainer, @@ -756,16 +756,16 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext): import pytato as pt - from jax.numpy import DeviceArray + import jax.numpy as jnp from arraycontext.container.traversal import rec_keyed_map_array_container from arraycontext.impl.pytato.compile import _ary_container_key_stringifier - array_as_dict: Dict[str, Union[DeviceArray, pt.Array]] = {} - key_to_frozen_subary: Dict[str, DeviceArray] = {} + array_as_dict: Dict[str, Union[jnp.ndarray, pt.Array]] = {} + key_to_frozen_subary: Dict[str, jnp.ndarray] = {} key_to_pt_arrays: Dict[str, pt.Array] = {} def _record_leaf_ary_in_dict(key: Tuple[Any, ...], - ary: Union[DeviceArray, pt.Array]) -> None: + ary: Union[jnp.ndarray, pt.Array]) -> None: key_str = "_ary" + _ary_container_key_stringifier(key) array_as_dict[key_str] = ary @@ -774,7 +774,7 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext): # {{{ remove any non pytato arrays from array_as_dict for key, subary in array_as_dict.items(): - if isinstance(subary, DeviceArray): + if isinstance(subary, jnp.ndarray): key_to_frozen_subary[key] = subary.block_until_ready() elif isinstance(subary, pt.DataWrapper): # trivial freeze. @@ -801,7 +801,7 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext): for k, v in out_dict.items()} } - def _to_frozen(key: Tuple[Any, ...], ary) -> DeviceArray: + def _to_frozen(key: Tuple[Any, ...], ary) -> jnp.ndarray: key_str = "_ary" + _ary_container_key_stringifier(key) return key_to_frozen_subary[key_str] @@ -824,10 +824,9 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext): return LazilyJAXCompilingFunctionCaller(self, f) def tag(self, tags: ToTagSetConvertible, array): - from jax.numpy import DeviceArray - def _tag(ary): - if isinstance(ary, DeviceArray): + import jax.numpy as jnp + if isinstance(ary, jnp.ndarray): return ary else: return ary.tagged(_preprocess_array_tags(tags)) @@ -835,10 +834,9 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext): return self._rec_map_container(_tag, array) def tag_axis(self, iaxis, tags: ToTagSetConvertible, array): - from jax.numpy import DeviceArray - def _tag_axis(ary): - if isinstance(ary, DeviceArray): + import jax.numpy as jnp + if isinstance(ary, jnp.ndarray): return ary else: return ary.with_tagged_axis(iaxis, tags) @@ -857,12 +855,12 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext): def einsum(self, spec, *args, arg_names=None, tagged=()): import pytato as pt - from jax.numpy import DeviceArray if arg_names is None: arg_names = (None,) * len(args) def preprocess_arg(name, arg): - if isinstance(arg, DeviceArray): + import jax.numpy as jnp + if isinstance(arg, jnp.ndarray): ary = self.thaw(arg) elif isinstance(arg, pt.Array): ary = arg -- GitLab