From f4183f8e77a7eb8fe8a1c16db79a7d49c495c3e3 Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Wed, 20 Oct 2021 17:26:23 -0500
Subject: [PATCH] allow scalars in to_numpy

---
 arraycontext/impl/pyopencl/__init__.py | 2 +-
 arraycontext/impl/pytato/__init__.py   | 3 +++
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/arraycontext/impl/pyopencl/__init__.py b/arraycontext/impl/pyopencl/__init__.py
index a737428..585a99e 100644
--- a/arraycontext/impl/pyopencl/__init__.py
+++ b/arraycontext/impl/pyopencl/__init__.py
@@ -161,7 +161,7 @@ class PyOpenCLArrayContext(ArrayContext):
         return cl_array.to_device(self.queue, array, allocator=self.allocator)
 
     def to_numpy(self, array):
-        if not self._force_device_scalars and np.isscalar(array):
+        if np.isscalar(array):
             return array
 
         return array.get(queue=self.queue)
diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py
index 672a285..a7a0b1b 100644
--- a/arraycontext/impl/pytato/__init__.py
+++ b/arraycontext/impl/pytato/__init__.py
@@ -105,6 +105,9 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
         return pt.make_data_wrapper(cl_array)
 
     def to_numpy(self, array):
+        if np.isscalar(array):
+            return array
+
         cl_array = self.freeze(array)
         return cl_array.get(queue=self.queue)
 
-- 
GitLab