diff --git a/arraycontext/fake_numpy.py b/arraycontext/fake_numpy.py index 43ec6b696e4c8aebf715d11e12e91c0bd037d216..2de78d1b6905cc155e665300c7eefd58e3a03247 100644 --- a/arraycontext/fake_numpy.py +++ b/arraycontext/fake_numpy.py @@ -155,7 +155,7 @@ class BaseFakeNumpyNamespace: delta_actx = self._array_context.from_numpy(delta) # sequences with 0 items or 1 item with endpoint=True (i.e. div <= 0) # have an undefined step - step = np.NaN + step = np.nan # Multiply with delta to allow possible override of output class. y = y * delta_actx diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py index 27799cfebcd0e836ebc00f659aa80762b0203cde..d20448a4d4060c8667485ada97b582ff961ed42b 100644 --- a/arraycontext/impl/jax/fake_numpy.py +++ b/arraycontext/impl/jax/fake_numpy.py @@ -130,8 +130,8 @@ class EagerJAXFakeNumpyNamespace(BaseFakeNumpyNamespace): from arraycontext import rec_multimap_reduce_array_container def _rec_vdot(ary1, ary2): - common_dtype = np.find_common_type((ary1.dtype, ary2.dtype), ()) - if dtype not in [None, common_dtype]: + common_dtype = np.result_type(ary1, ary2) + if dtype not in (None, common_dtype): raise NotImplementedError( f"{type(self).__name__} cannot take dtype in vdot.")