import pytest
import pyopencl as cl  # noqa: F401
from pyopencl.tools import (
        get_test_platforms_and_devices,
        clear_first_arg_caches
        )

import utilities as u

# {{{ mark for slow tests

# setup to mark slow tests with @pytest.mark.slow, so that they don't run by
# default, but can be forced to run with the command-line option --runslow
# taken from
# https://docs.pytest.org/en/latest/example/simple.html#control-skipping-of-tests-according-to-command-line-option

def pytest_addoption(parser):
    """Register the ``--runslow`` command-line flag.

    The flag is a plain boolean switch (off unless given). Tests marked
    ``@pytest.mark.slow`` are only executed when it is passed.
    """
    parser.addoption(
        "--runslow",
        default=False,
        action="store_true",
        help="run slow tests",
    )


def pytest_collection_modifyitems(config, items):
    """Skip every test carrying the ``slow`` keyword unless ``--runslow`` is set.

    When ``--runslow`` is given on the command line, the collected items are
    left untouched. Otherwise a single shared skip marker is attached to each
    item whose keywords include ``slow``.
    """
    run_slow = config.getoption("--runslow")
    if not run_slow:
        marker = pytest.mark.skip(reason="need --runslow option to run")
        for slow_item in (it for it in items if "slow" in it.keywords):
            slow_item.add_marker(marker)

# }}}

# {{{ pytest_generate_tests

def pytest_generate_tests(metafunc):
    """Parametrize the ``ctx_factory`` fixture over all test devices.

    Every test function that requests ``ctx_factory`` is run once per device
    yielded by ``get_test_platforms_and_devices()``. It receives a callable
    that builds a fresh single-device ``cl.Context``. Test functions that do
    not request ``ctx_factory`` are left alone.
    """
    if "ctx_factory" not in metafunc.fixturenames:
        # Avoid enumerating OpenCL platforms/devices for tests that never
        # use a context; this hook runs for every collected test function.
        return

    class ContextFactory:
        """Callable producing a new ``cl.Context`` for a single device."""

        def __init__(self, device):
            self.device = device

        def __call__(self):
            # Get rid of leftovers from past tests.
            # CL implementations are surprisingly limited in how many
            # simultaneous contexts they allow...

            clear_first_arg_caches()

            from gc import collect
            collect()

            return cl.Context([self.device])

        def __str__(self):
            # Don't show address, so that parallel test collection works
            # (this string is used as the pytest parameter id).
            return ("<context factory for <pyopencl.Device '%s' on '%s'>" %
                    (self.device.name.strip(),
                     self.device.platform.name.strip()))

    # One factory per device, across all test platforms; the platform object
    # itself is not needed since each device knows its own platform.
    factories = [
        ContextFactory(device)
        for _platform, plat_devs in get_test_platforms_and_devices()
        for device in plat_devs
    ]

    metafunc.parametrize("ctx_factory", factories, ids=str, scope="session")

# }}}

@pytest.fixture(scope="session")
def queue(ctx_factory):
    """Session-scoped fixture returning the queue built by ``u.get_queue``.

    Because ``ctx_factory`` is parametrized per device, pytest creates one
    such queue for each ``ctx_factory`` parameter and reuses it for the
    whole session.
    """
    command_queue = u.get_queue(ctx_factory)
    return command_queue


# vim: foldmethod=marker
