From f4183f8e77a7eb8fe8a1c16db79a7d49c495c3e3 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl <alexfikl@gmail.com> Date: Wed, 20 Oct 2021 17:26:23 -0500 Subject: [PATCH] allow scalars in to_numpy --- arraycontext/impl/pyopencl/__init__.py | 2 +- arraycontext/impl/pytato/__init__.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/arraycontext/impl/pyopencl/__init__.py b/arraycontext/impl/pyopencl/__init__.py index a737428..585a99e 100644 --- a/arraycontext/impl/pyopencl/__init__.py +++ b/arraycontext/impl/pyopencl/__init__.py @@ -161,7 +161,7 @@ class PyOpenCLArrayContext(ArrayContext): return cl_array.to_device(self.queue, array, allocator=self.allocator) def to_numpy(self, array): - if not self._force_device_scalars and np.isscalar(array): + if np.isscalar(array): return array return array.get(queue=self.queue) diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 672a285..a7a0b1b 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -105,6 +105,9 @@ class PytatoPyOpenCLArrayContext(ArrayContext): return pt.make_data_wrapper(cl_array) def to_numpy(self, array): + if np.isscalar(array): + return array + cl_array = self.freeze(array) return cl_array.get(queue=self.queue) -- GitLab