From dbe50aefce0cbac2865e241d9b0e10db7e482aa8 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Sun, 26 Sep 2021 02:41:25 -0500 Subject: [PATCH] test NumpyArrayContext --- arraycontext/pytest.py | 22 ++++++++++++++++++++++ test/test_arraycontext.py | 2 ++ 2 files changed, 24 insertions(+) diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index 8a1e027..e74a6ae 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -34,6 +34,7 @@ THE SOFTWARE. from typing import Any, Callable, Dict, Sequence, Type, Union +from arraycontext import NumpyArrayContext from arraycontext.context import ArrayContext @@ -221,6 +222,26 @@ class _PytestPytatoJaxArrayContextFactory(PytestArrayContextFactory): return "<PytatoJAXArrayContext>" +# {{{ _PytestArrayContextFactory + +class _NumpyArrayContextForTests(NumpyArrayContext): + def transform_loopy_program(self, t_unit): + return t_unit + + +class _PytestNumpyArrayContextFactory(PytestArrayContextFactory): + def __init__(self, *args, **kwargs): + super().__init__() + + def __call__(self): + return _NumpyArrayContextForTests() + + def __str__(self): + return "<NumpyArrayContext>" + +# }}} + + _ARRAY_CONTEXT_FACTORY_REGISTRY: \ Dict[str, Type[PytestArrayContextFactory]] = { "pyopencl": _PytestPyOpenCLArrayContextFactoryWithClass, @@ -229,6 +250,7 @@ _ARRAY_CONTEXT_FACTORY_REGISTRY: \ "pytato:pyopencl": _PytestPytatoPyOpenCLArrayContextFactory, "pytato:jax": _PytestPytatoJaxArrayContextFactory, "eagerjax": _PytestEagerJaxArrayContextFactory, + "numpy": _PytestNumpyArrayContextFactory, } diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 3f06156..ffd7553 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -46,6 +46,7 @@ from arraycontext import ( ) from arraycontext.pytest import ( _PytestEagerJaxArrayContextFactory, + _PytestNumpyArrayContextFactory, _PytestPyOpenCLArrayContextFactoryWithClass, _PytestPytatoJaxArrayContextFactory, _PytestPytatoPyOpenCLArrayContextFactory, @@ -97,6 +98,7 @@ pytest_generate_tests = pytest_generate_tests_for_array_contexts([ _PytatoPyOpenCLArrayContextForTestsFactory, _PytestEagerJaxArrayContextFactory, _PytestPytatoJaxArrayContextFactory, + _PytestNumpyArrayContextFactory, ]) -- GitLab