Skip to content
Snippets Groups Projects
Commit 97940584 authored by Matthias Diener's avatar Matthias Diener
Browse files

small fixes

parent bb15b4c0
No related branches found
No related tags found
No related merge requests found
......@@ -236,7 +236,6 @@ class BaseFakeNumpyLinalgNamespace:
return self._array_context.np.sum(abs(ary)**ord)**(1/ord)
else:
raise NotImplementedError(f"unsupported value of 'ord': {ord}")
# }}}
......
......@@ -54,6 +54,7 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
def __init__(self, queue, allocator=None):
super().__init__()
self._force_device_scalars = True
self.queue = queue
self.allocator = allocator
self.np = self._get_fake_numpy_namespace()
......
......@@ -51,8 +51,7 @@ def make_loopy_program(domains, statements, kernel_data=None,
statements,
kernel_data=kernel_data,
options=_DEFAULT_LOOPY_OPTIONS,
# FIXME: Restore when https://github.com/inducer/loopy/pull/431 is merged
# default_offset=lp.auto,
default_offset=lp.auto,
name=name,
lang_version=MOST_RECENT_LANGUAGE_VERSION)
......
......@@ -87,10 +87,24 @@ class _DeprecatedPyOpenCLArrayContextFactory(_PyOpenCLArrayContextFactory):
force_device_scalars = False
class _PytatoPyOpenCLArrayContextFactory(PytestPyOpenCLArrayContextFactory):
force_device_scalars = False
def __call__(self):
from arraycontext.impl.pytato import PytatoPyOpenCLArrayContext
return PytatoPyOpenCLArrayContext(self.get_command_queue())
def __str__(self):
return ("<Pytato array context factory for <pyopencl.Device '%s' on '%s'>"
% (self.device.name.strip(),
self.device.platform.name.strip()))
_ARRAY_CONTEXT_FACTORY_REGISTRY: \
Dict[str, Type[PytestPyOpenCLArrayContextFactory]] = {
"pyopencl": _PyOpenCLArrayContextFactory,
"pyopencl-deprecated": _DeprecatedPyOpenCLArrayContextFactory,
"pytato-pyopencl": _PytatoPyOpenCLArrayContextFactory,
}
......
......@@ -41,7 +41,7 @@ logger = logging.getLogger(__name__)
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
"pyopencl", "pyopencl-deprecated",
"pyopencl", "pyopencl-deprecated", "pytato-pyopencl"
])
......
......@@ -24,13 +24,16 @@ THE SOFTWARE.
"""
from arraycontext import ( # noqa: F401
pytest_generate_tests_for_array_contexts
as pytest_generate_tests,
pytest_generate_tests_for_array_contexts,
_acf)
import logging
logger = logging.getLogger(__name__)
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
"pyopencl", "pyopencl-deprecated", "pytato-pyopencl"
])
def test_pt_actx_key_stringification_uniqueness():
from arraycontext.impl.pytato.compile import _ary_container_key_stringifier
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment