diff --git a/arraycontext/context.py b/arraycontext/context.py index 36a7acee77792e83a7d3c63ca1bfdef491d9c79a..e1528387d801330cf4887b7e42740b4a6eef8b74 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -299,9 +299,19 @@ class ArrayContext(ABC): pass def empty_like(self, ary: Array) -> Array: + from warnings import warn + warn(f"{type(self).__name__}.empty_like is deprecated and will stop " + "working in 2023. Prefer actx.np.zeros_like instead.", + DeprecationWarning, stacklevel=2) + return self.empty(shape=ary.shape, dtype=ary.dtype) def zeros_like(self, ary: Array) -> Array: + from warnings import warn + warn(f"{type(self).__name__}.zeros_like is deprecated and will stop " + "working in 2023. Use actx.np.zeros_like instead.", + DeprecationWarning, stacklevel=2) + return self.zeros(shape=ary.shape, dtype=ary.dtype) @abstractmethod diff --git a/arraycontext/fake_numpy.py b/arraycontext/fake_numpy.py index d5c8fce9b4e0e884d20bb18e675c0c813741bbcc..c3e37f824bdb19ff14996cb51333cf83b857bc4c 100644 --- a/arraycontext/fake_numpy.py +++ b/arraycontext/fake_numpy.py @@ -91,12 +91,6 @@ class BaseFakeNumpyNamespace: # "interp", }) - def empty_like(self, ary): - return self._array_context.empty_like(ary) - - def zeros_like(self, ary): - return self._array_context.zeros_like(ary) - def conjugate(self, x): # NOTE: conjugate distributes over object arrays, but it looks for a # `conjugate` ufunc, while some implementations only have the shorter diff --git a/arraycontext/impl/jax/__init__.py b/arraycontext/impl/jax/__init__.py index dfb89c45b29bdd3cc90c915fa213b269522f8985..f4794e467b69b236d7d1768d8cbb6b773ac6c2a9 100644 --- a/arraycontext/impl/jax/__init__.py +++ b/arraycontext/impl/jax/__init__.py @@ -88,6 +88,11 @@ class EagerJAXArrayContext(ArrayContext): # {{{ ArrayContext interface def empty(self, shape, dtype): + from warnings import warn + warn(f"{type(self).__name__}.empty is deprecated and will stop " + "working in 2023. Prefer actx.zeros instead.", + DeprecationWarning, stacklevel=2) + import jax.numpy as jnp return jnp.empty(shape=shape, dtype=dtype) @@ -96,16 +101,23 @@ class EagerJAXArrayContext(ArrayContext): return jnp.zeros(shape=shape, dtype=dtype) def empty_like(self, ary): + from warnings import warn + warn(f"{type(self).__name__}.empty_like is deprecated and will stop " + "working in 2023. Prefer actx.np.zeros_like instead.", + DeprecationWarning, stacklevel=2) + def _empty_like(array): return self.empty(array.shape, array.dtype) return self._rec_map_container(_empty_like, ary) def zeros_like(self, ary): - def _zeros_like(array): - return self.zeros(array.shape, array.dtype) + from warnings import warn + warn(f"{type(self).__name__}.zeros_like is deprecated and will stop " + "working in 2023. Use actx.np.zeros_like instead.", + DeprecationWarning, stacklevel=2) - return self._rec_map_container(_zeros_like, ary, default_scalar=0) + return self.np.zeros_like(ary) def from_numpy(self, array): def _from_numpy(ary): diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py index 37c99b4ab6bacebc10738471fd08cf78593b9a87..daaf8803d9ffacae99a358fa49b28cc9e28d2b70 100644 --- a/arraycontext/impl/jax/fake_numpy.py +++ b/arraycontext/impl/jax/fake_numpy.py @@ -56,6 +56,25 @@ class EagerJAXFakeNumpyNamespace(BaseFakeNumpyNamespace): # {{{ array creation routines + def empty_like(self, ary): + from warnings import warn + warn(f"{type(self._array_context).__name__}.np.empty_like is " + "deprecated and will stop working in 2023. Prefer actx.np.zeros_like " + "instead.", + DeprecationWarning, stacklevel=2) + + def _empty_like(array): + return self._array_context.empty(array.shape, array.dtype) + + return self._array_context._rec_map_container(_empty_like, ary) + + def zeros_like(self, ary): + def _zeros_like(array): + return self._array_context.zeros(array.shape, array.dtype) + + return self._array_context._rec_map_container( + _zeros_like, ary, default_scalar=0) + def ones_like(self, ary): return self.full_like(ary, 1) diff --git a/arraycontext/impl/pyopencl/__init__.py b/arraycontext/impl/pyopencl/__init__.py index 71b04c721552547a2389bdd95d22bea6bee9c8d2..aced309c2d5959f7ecde5538d79c4d4c8f4e6063 100644 --- a/arraycontext/impl/pyopencl/__init__.py +++ b/arraycontext/impl/pyopencl/__init__.py @@ -189,6 +189,11 @@ class PyOpenCLArrayContext(ArrayContext): # {{{ ArrayContext interface def empty(self, shape, dtype): + from warnings import warn + warn(f"{type(self).__name__}.empty is deprecated and will stop " + "working in 2023. Prefer actx.zeros instead.", + DeprecationWarning, stacklevel=2) + import arraycontext.impl.pyopencl.taggable_cl_array as tga return tga.empty(self.queue, shape, dtype, allocator=self.allocator) @@ -197,6 +202,11 @@ class PyOpenCLArrayContext(ArrayContext): return tga.zeros(self.queue, shape, dtype, allocator=self.allocator) def empty_like(self, ary): + from warnings import warn + warn(f"{type(self).__name__}.empty_like is deprecated and will stop " + "working in 2023. Prefer actx.np.zeros_like instead.", + DeprecationWarning, stacklevel=2) + import arraycontext.impl.pyopencl.taggable_cl_array as tga def _empty_like(array): @@ -206,13 +216,12 @@ class PyOpenCLArrayContext(ArrayContext): return self._rec_map_container(_empty_like, ary) def zeros_like(self, ary): - import arraycontext.impl.pyopencl.taggable_cl_array as tga - - def _zeros_like(array): - return tga.zeros(self.queue, array.shape, array.dtype, - allocator=self.allocator, axes=array.axes, tags=array.tags) + from warnings import warn + warn(f"{type(self).__name__}.zeros_like is deprecated and will stop " + "working in 2023. Use actx.np.zeros_like instead.", + DeprecationWarning, stacklevel=2) - return self._rec_map_container(_zeros_like, ary, default_scalar=0) + return self.np.zeros_like(ary) def from_numpy(self, array): import arraycontext.impl.pyopencl.taggable_cl_array as tga diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py index 2e206a8bb9aded3cc11687e2bd0903d77d40b4b4..a0180e73615f86fc57d59c1d051bd33e51588bc8 100644 --- a/arraycontext/impl/pyopencl/fake_numpy.py +++ b/arraycontext/impl/pyopencl/fake_numpy.py @@ -63,22 +63,49 @@ class PyOpenCLFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace): # {{{ array creation routines + def empty_like(self, ary): + from warnings import warn + warn(f"{type(self._array_context).__name__}.np.empty_like is " + "deprecated and will stop working in 2023. Prefer actx.np.zeros_like " + "instead.", + DeprecationWarning, stacklevel=2) + + import arraycontext.impl.pyopencl.taggable_cl_array as tga + actx = self._array_context + + def _empty_like(array): + return tga.empty(actx.queue, array.shape, array.dtype, + allocator=actx.allocator, axes=array.axes, tags=array.tags) + + return actx._rec_map_container(_empty_like, ary) + + def zeros_like(self, ary): + import arraycontext.impl.pyopencl.taggable_cl_array as tga + actx = self._array_context + + def _zeros_like(array): + return tga.zeros( + actx.queue, array.shape, array.dtype, + allocator=actx.allocator, axes=array.axes, tags=array.tags) + + return actx._rec_map_container(_zeros_like, ary, default_scalar=0) + def ones_like(self, ary): return self.full_like(ary, 1) def full_like(self, ary, fill_value): import arraycontext.impl.pyopencl.taggable_cl_array as tga + actx = self._array_context def _full_like(subary): filled = tga.empty( - self._array_context.queue, subary.shape, subary.dtype, - allocator=self._array_context.allocator, - axes=subary.axes, tags=subary.tags) + actx.queue, subary.shape, subary.dtype, + allocator=actx.allocator, axes=subary.axes, tags=subary.tags) filled.fill(fill_value) + return filled - return self._array_context._rec_map_container( - _full_like, ary, default_scalar=fill_value) + return actx._rec_map_container(_full_like, ary, default_scalar=fill_value) def copy(self, ary): def _copy(subary): diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 6b9ac6b7ee2f0a8a83506ddbd02c0030552494cd..8ccc768917aeb589fae069b8569b78695967f2ae 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -348,10 +348,12 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): # {{{ ArrayContext interface def zeros_like(self, ary): - def _zeros_like(array): - return self.zeros(array.shape, array.dtype) + from warnings import warn + warn(f"{type(self).__name__}.zeros_like is deprecated and will stop " + "working in 2023. Use actx.np.zeros_like instead.", + DeprecationWarning, stacklevel=2) - return self._rec_map_container(_zeros_like, ary, default_scalar=0) + return self.np.zeros_like(ary) def from_numpy(self, array): import pytato as pt @@ -720,10 +722,12 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext): # {{{ ArrayContext interface def zeros_like(self, ary): - def _zeros_like(array): - return self.zeros(array.shape, array.dtype) + from warnings import warn + warn(f"{type(self).__name__}.zeros_like is deprecated and will stop " + "working in 2023. Use actx.np.zeros_like instead.", + DeprecationWarning, stacklevel=2) - return self._rec_map_container(_zeros_like, ary, default_scalar=0) + return self.np.zeros_like(ary) def from_numpy(self, array): import jax diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index d1890f29d7a514600bc7f8c3ecc990d03e53726a..e17f8ee7788c77adc503ddaa5f497507459a81ff 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -76,6 +76,13 @@ class PytatoFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace): # {{{ array creation routines + def zeros_like(self, ary): + def _zeros_like(array): + return self._array_context.zeros(array.shape, array.dtype) + + return self._array_context._rec_map_container( + _zeros_like, ary, default_scalar=0) + def ones_like(self, ary): return self.full_like(ary, 1)