diff --git a/test/test_clrandom.py b/test/test_clrandom.py index 2be2e46f8356ec7a59658df1c1edcfec2ebc41b9..2846e24c97e1c4c0356420dfa35a1278e615a77f 100644 --- a/test/test_clrandom.py +++ b/test/test_clrandom.py @@ -40,21 +40,13 @@ else: faulthandler.enable() -class RanluxGeneratorShim(object): - - def __init__(self, cl_ctx): - self.queue = cl.CommandQueue(cl_ctx) - self.gen = clrandom.RanluxGenerator(self.queue) - - def uniform(self, *args, **kwargs): - return self.gen.uniform(*args, **kwargs) - - def normal(self, *args, **kwargs): - return self.gen.normal(*args, **kwargs) +def make_ranlux_generator(cl_ctx): + queue = cl.CommandQueue(cl_ctx) + return clrandom.RanluxGenerator(queue) @pytest.mark.parametrize("rng_class", [ - RanluxGeneratorShim, + make_ranlux_generator, clrandom.PhiloxGenerator, clrandom.ThreefryGenerator]) @pytest.mark.parametrize("dtype", [