From 93a012bc30a9c95fd199eb2521543e8b5c44e154 Mon Sep 17 00:00:00 2001 From: Vincent Favre-Nicolin Date: Tue, 13 Jun 2023 17:32:06 +0200 Subject: [PATCH 1/2] Add support for array sizes >=2**31 for random number generation --- pycuda/curandom.py | 130 +++++++++++++++--------------- test/test_gpuarray.py | 179 +++++++++++++++++++++++++++++++++++++++--- 2 files changed, 234 insertions(+), 75 deletions(-) diff --git a/pycuda/curandom.py b/pycuda/curandom.py index 160776a4..6dbbeede 100644 --- a/pycuda/curandom.py +++ b/pycuda/curandom.py @@ -276,41 +276,41 @@ if get_curand_version() >= (4, 0, 0): # {{{ Base class gen_template = """ -__global__ void %(name)s(%(state_type)s *s, %(out_type)s *d, const int n) +__global__ void %(name)s(%(state_type)s *s, %(out_type)s *d, const size_t n) { - const int tidx = blockIdx.x*blockDim.x+threadIdx.x; - const int delta = blockDim.x*gridDim.x; - for (int idx = tidx; idx < n; idx += delta) + const size_t tidx = blockIdx.x*blockDim.x+threadIdx.x; + const size_t delta = blockDim.x*gridDim.x; + for (size_t idx = tidx; idx < n; idx += delta) d[idx] = curand%(suffix)s(&s[tidx]); } """ gen_log_template = """ -__global__ void %(name)s(%(state_type)s *s, %(out_type)s *d, %(in_type)s mean, %(in_type)s stddev, const int n) +__global__ void %(name)s(%(state_type)s *s, %(out_type)s *d, %(in_type)s mean, %(in_type)s stddev, const size_t n) { - const int tidx = blockIdx.x*blockDim.x+threadIdx.x; - const int delta = blockDim.x*gridDim.x; - for (int idx = tidx; idx < n; idx += delta) + const size_t tidx = blockIdx.x*blockDim.x+threadIdx.x; + const size_t delta = blockDim.x*gridDim.x; + for (size_t idx = tidx; idx < n; idx += delta) d[idx] = curand_log%(suffix)s(&s[tidx], mean, stddev); } """ gen_poisson_template = """ -__global__ void %(name)s(%(state_type)s *s, %(out_type)s *d, double lambda, const int n) +__global__ void %(name)s(%(state_type)s *s, %(out_type)s *d, double lambda, const size_t n) { - const int tidx = 
blockIdx.x*blockDim.x+threadIdx.x; - const int delta = blockDim.x*gridDim.x; - for (int idx = tidx; idx < n; idx += delta) + const size_t tidx = blockIdx.x*blockDim.x+threadIdx.x; + const size_t delta = blockDim.x*gridDim.x; + for (size_t idx = tidx; idx < n; idx += delta) d[idx] = curand_poisson%(suffix)s(&s[tidx], lambda); } """ gen_poisson_inplace_template = """ -__global__ void %(name)s(%(state_type)s *s, %(inout_type)s *d, const int n) +__global__ void %(name)s(%(state_type)s *s, %(inout_type)s *d, const size_t n) { - const int tidx = blockIdx.x*blockDim.x+threadIdx.x; - const int delta = blockDim.x*gridDim.x; - for (int idx = tidx; idx < n; idx += delta) + const size_t tidx = blockIdx.x*blockDim.x+threadIdx.x; + const size_t delta = blockDim.x*gridDim.x; + for (size_t idx = tidx; idx < n; idx += delta) d[idx] = (%(inout_type)s)(curand_poisson%(suffix)s(&s[tidx], double(d[idx]))); } """ @@ -330,16 +330,16 @@ extern "C" random_skip_ahead32_source = """ extern "C" { -__global__ void skip_ahead(%(state_type)s *s, const int n, const unsigned int skip) +__global__ void skip_ahead(%(state_type)s *s, const size_t n, const unsigned int skip) { - const int idx = blockIdx.x*blockDim.x+threadIdx.x; + const size_t idx = blockIdx.x*blockDim.x+threadIdx.x; if (idx < n) skipahead(skip, &s[idx]); } -__global__ void skip_ahead_array(%(state_type)s *s, const int n, const unsigned int *skip) +__global__ void skip_ahead_array(%(state_type)s *s, const size_t n, const unsigned int *skip) { - const int idx = blockIdx.x*blockDim.x+threadIdx.x; + const size_t idx = blockIdx.x*blockDim.x+threadIdx.x; if (idx < n) skipahead(skip[idx], &s[idx]); } @@ -348,16 +348,16 @@ __global__ void skip_ahead_array(%(state_type)s *s, const int n, const unsigned random_skip_ahead64_source = """ extern "C" { -__global__ void skip_ahead(%(state_type)s *s, const int n, const unsigned long long skip) +__global__ void skip_ahead(%(state_type)s *s, const size_t n, const unsigned long long skip) { - const int 
idx = blockIdx.x*blockDim.x+threadIdx.x; + const size_t idx = blockIdx.x*blockDim.x+threadIdx.x; if (idx < n) skipahead(skip, &s[idx]); } -__global__ void skip_ahead_array(%(state_type)s *s, const int n, const unsigned long long *skip) +__global__ void skip_ahead_array(%(state_type)s *s, const size_t n, const unsigned long long *skip) { - const int idx = blockIdx.x*blockDim.x+threadIdx.x; + const size_t idx = blockIdx.x*blockDim.x+threadIdx.x; if (idx < n) skipahead(skip[idx], &s[idx]); } @@ -517,24 +517,24 @@ class _RandomNumberGeneratorBase: self.generators = {} for name, out_type, suffix in my_generators: gen_func = module.get_function(name) - gen_func.prepare("PPi") + gen_func.prepare("PPn") self.generators[name] = gen_func if get_curand_version() >= (4, 0, 0): for name, in_type, out_type, suffix in my_log_generators: gen_func = module.get_function(name) if in_type == "float": - gen_func.prepare("PPffi") + gen_func.prepare("PPffn") if in_type == "double": - gen_func.prepare("PPddi") + gen_func.prepare("PPddn") self.generators[name] = gen_func if get_curand_version() >= (5, 0, 0): for name, out_type, suffix in my_poisson_generators: gen_func = module.get_function(name) - gen_func.prepare("PPdi") + gen_func.prepare("PPdn") self.generators[name] = gen_func for name, inout_type, suffix in my_poisson_inplace_generators: gen_func = module.get_function(name) - gen_func.prepare("PPi") + gen_func.prepare("PPn") self.generators[name] = gen_func self.generator_bits = generator_bits @@ -546,11 +546,11 @@ class _RandomNumberGeneratorBase: def _prepare_skipahead(self): self.skip_ahead = self.module.get_function("skip_ahead") if self.generator_bits == 32: - self.skip_ahead.prepare("PiI") + self.skip_ahead.prepare("PnI") if self.generator_bits == 64: - self.skip_ahead.prepare("PiQ") + self.skip_ahead.prepare("PnQ") self.skip_ahead_array = self.module.get_function("skip_ahead_array") - self.skip_ahead_array.prepare("PiP") + self.skip_ahead_array.prepare("PnP") def 
_kernels(self): return list(self.generators.values()) + [ @@ -769,7 +769,7 @@ class _PseudoRandomNumberGeneratorBase(_RandomNumberGeneratorBase): raise TypeError("seed must be GPUArray of integers of right length") p = self.module.get_function("prepare") - p.prepare("PiPi") + p.prepare("PnPn") from pycuda.characterize import has_stack @@ -799,15 +799,15 @@ class _PseudoRandomNumberGeneratorBase(_RandomNumberGeneratorBase): def _prepare_skipahead(self): self.skip_ahead = self.module.get_function("skip_ahead") - self.skip_ahead.prepare("PiQ") + self.skip_ahead.prepare("PnQ") self.skip_ahead_array = self.module.get_function("skip_ahead_array") - self.skip_ahead_array.prepare("PiP") + self.skip_ahead_array.prepare("PnP") self.skip_ahead_sequence = self.module.get_function("skip_ahead_sequence") - self.skip_ahead_sequence.prepare("PiQ") + self.skip_ahead_sequence.prepare("PnQ") self.skip_ahead_sequence_array = self.module.get_function( "skip_ahead_sequence_array" ) - self.skip_ahead_sequence_array.prepare("PiP") + self.skip_ahead_sequence_array.prepare("PnP") def call_skip_ahead_sequence(self, i, stream=None): self.skip_ahead_sequence.prepared_async_call( @@ -855,10 +855,10 @@ def seed_getter_unique(n): xorwow_random_source = """ extern "C" { -__global__ void prepare(%(state_type)s *s, const int n, - %(vector_type)s *v, const unsigned int o) +__global__ void prepare(%(state_type)s *s, const size_t n, + %(vector_type)s *v, const size_t o) { - const int id = blockIdx.x*blockDim.x+threadIdx.x; + const size_t id = blockIdx.x*blockDim.x+threadIdx.x; if (id < n) curand_init(v[id], id, o, &s[id]); } @@ -867,16 +867,16 @@ __global__ void prepare(%(state_type)s *s, const int n, xorwow_skip_ahead_sequence_source = """ extern "C" { -__global__ void skip_ahead_sequence(%(state_type)s *s, const int n, const unsigned long long skip) +__global__ void skip_ahead_sequence(%(state_type)s *s, const size_t n, const unsigned long long skip) { - const int idx = 
blockIdx.x*blockDim.x+threadIdx.x; + const size_t idx = blockIdx.x*blockDim.x+threadIdx.x; if (idx < n) skipahead_sequence(skip, &s[idx]); } -__global__ void skip_ahead_sequence_array(%(state_type)s *s, const int n, const unsigned long long *skip) +__global__ void skip_ahead_sequence_array(%(state_type)s *s, const size_t n, const unsigned long long *skip) { - const int idx = blockIdx.x*blockDim.x+threadIdx.x; + const size_t idx = blockIdx.x*blockDim.x+threadIdx.x; if (idx < n) skipahead_sequence(skip[idx], &s[idx]); } @@ -912,10 +912,10 @@ if get_curand_version() >= (3, 2, 0): mrg32k3a_random_source = """ extern "C" { -__global__ void prepare(%(state_type)s *s, const int n, - %(vector_type)s *v, const unsigned int o) +__global__ void prepare(%(state_type)s *s, const size_t n, + %(vector_type)s *v, const size_t o) { - const int id = blockIdx.x*blockDim.x+threadIdx.x; + const size_t id = blockIdx.x*blockDim.x+threadIdx.x; if (id < n) curand_init(v[id], id, o, &s[id]); } @@ -924,30 +924,30 @@ __global__ void prepare(%(state_type)s *s, const int n, mrg32k3a_skip_ahead_sequence_source = """ extern "C" { -__global__ void skip_ahead_sequence(%(state_type)s *s, const int n, const unsigned long long skip) +__global__ void skip_ahead_sequence(%(state_type)s *s, const size_t n, const unsigned long long skip) { - const int idx = blockIdx.x*blockDim.x+threadIdx.x; + const size_t idx = blockIdx.x*blockDim.x+threadIdx.x; if (idx < n) skipahead_sequence(skip, &s[idx]); } -__global__ void skip_ahead_sequence_array(%(state_type)s *s, const int n, const unsigned long long *skip) +__global__ void skip_ahead_sequence_array(%(state_type)s *s, const size_t n, const unsigned long long *skip) { - const int idx = blockIdx.x*blockDim.x+threadIdx.x; + const size_t idx = blockIdx.x*blockDim.x+threadIdx.x; if (idx < n) skipahead_sequence(skip[idx], &s[idx]); } -__global__ void skip_ahead_subsequence(%(state_type)s *s, const int n, const unsigned long long skip) +__global__ void 
skip_ahead_subsequence(%(state_type)s *s, const size_t n, const unsigned long long skip) { - const int idx = blockIdx.x*blockDim.x+threadIdx.x; + const size_t idx = blockIdx.x*blockDim.x+threadIdx.x; if (idx < n) skipahead_subsequence(skip, &s[idx]); } -__global__ void skip_ahead_subsequence_array(%(state_type)s *s, const int n, const unsigned long long *skip) +__global__ void skip_ahead_subsequence_array(%(state_type)s *s, const size_t n, const unsigned long long *skip) { - const int idx = blockIdx.x*blockDim.x+threadIdx.x; + const size_t idx = blockIdx.x*blockDim.x+threadIdx.x; if (idx < n) skipahead_subsequence(skip[idx], &s[idx]); } @@ -981,11 +981,11 @@ if get_curand_version() >= (4, 1, 0): self.skip_ahead_subsequence = self.module.get_function( "skip_ahead_subsequence" ) - self.skip_ahead_subsequence.prepare("PiQ") + self.skip_ahead_subsequence.prepare("PnQ") self.skip_ahead_subsequence_array = self.module.get_function( "skip_ahead_subsequence_array" ) - self.skip_ahead_subsequence_array.prepare("PiP") + self.skip_ahead_subsequence_array.prepare("PnP") def call_skip_ahead_subsequence(self, i, stream=None): self.skip_ahead_subsequence.prepared_async_call( @@ -1049,10 +1049,10 @@ if get_curand_version() >= (4, 0, 0): sobol_random_source = """ extern "C" { -__global__ void prepare(%(state_type)s *s, const int n, - %(vector_type)s *v, const unsigned int o) +__global__ void prepare(%(state_type)s *s, const size_t n, + %(vector_type)s *v, const size_t o) { - const int id = blockIdx.x*blockDim.x+threadIdx.x; + const size_t id = blockIdx.x*blockDim.x+threadIdx.x; if (id < n) curand_init(v[id], o, &s[id]); } @@ -1099,7 +1099,7 @@ class _SobolRandomNumberGeneratorBase(_RandomNumberGeneratorBase): raise TypeError("seed must be GPUArray of integers of right length") p = self.module.get_function("prepare") - p.prepare("PiPi") + p.prepare("PnPn") from pycuda.characterize import has_stack @@ -1135,10 +1135,10 @@ class 
_SobolRandomNumberGeneratorBase(_RandomNumberGeneratorBase): scrambledsobol_random_source = """ extern "C" { -__global__ void prepare( %(state_type)s *s, const int n, - %(vector_type)s *v, %(scramble_type)s *scramble, const unsigned int o) +__global__ void prepare( %(state_type)s *s, const size_t n, + %(vector_type)s *v, %(scramble_type)s *scramble, const size_t o) { - const int id = blockIdx.x*blockDim.x+threadIdx.x; + const size_t id = blockIdx.x*blockDim.x+threadIdx.x; if (id < n) curand_init(v[id], scramble[id], o, &s[id]); } @@ -1200,7 +1200,7 @@ class _ScrambledSobolRandomNumberGeneratorBase(_RandomNumberGeneratorBase): raise TypeError("scramble must be GPUArray of integers of right length") p = self.module.get_function("prepare") - p.prepare("PiPPi") + p.prepare("PnPPn") from pycuda.characterize import has_stack diff --git a/test/test_gpuarray.py b/test/test_gpuarray.py index ea641d09..9b0d9fe2 100644 --- a/test/test_gpuarray.py +++ b/test/test_gpuarray.py @@ -1,4 +1,6 @@ #! /usr/bin/env python +import unittest +import os import numpy as np import numpy.linalg as la @@ -6,7 +8,6 @@ import sys from pycuda.tools import init_cuda_context_fixture from pycuda.characterize import has_double_support - import pycuda.gpuarray as gpuarray import pycuda.driver as drv from pycuda.compiler import SourceModule @@ -37,7 +38,7 @@ def get_random_array(rng, shape, dtype): def skip_if_not_enough_gpu_memory(required_mem_gigabytes): - device_mem_GB = drv.Context.get_device().total_memory() / 1e9 + device_mem_GB = drv.Context.get_device().total_memory() / 1024 ** 3 if device_mem_GB < required_mem_gigabytes: pytest.skip("Need at least %.1f GB memory" % required_mem_gigabytes) @@ -378,6 +379,165 @@ class TestGPUArray: # # Compare with scipy.stats.poisson.pmf(v - 1, v) # assert np.isclose(0.12511, tmp, atol=0.002) + def test_curand_wrappers_8gb(self): + """ Test random number generation with array sizes of 2**31 with 4-byte types + to test sizes beyond the signed int range.""" + 
skip_if_not_enough_gpu_memory(9) + from pycuda.curandom import get_curand_version + + if get_curand_version() is None: + from pytest import skip + + skip("curand not installed") + + generator_types = [] + if get_curand_version() >= (3, 2, 0): + from pycuda.curandom import ( + XORWOWRandomNumberGenerator, + Sobol32RandomNumberGenerator, + ) + + generator_types.extend( + [XORWOWRandomNumberGenerator, Sobol32RandomNumberGenerator] + ) + if get_curand_version() >= (4, 0, 0): + from pycuda.curandom import ( + ScrambledSobol32RandomNumberGenerator, + Sobol64RandomNumberGenerator, + ScrambledSobol64RandomNumberGenerator, + ) + + generator_types.extend( + [ + ScrambledSobol32RandomNumberGenerator, + Sobol64RandomNumberGenerator, + ScrambledSobol64RandomNumberGenerator, + ] + ) + if get_curand_version() >= (4, 1, 0): + from pycuda.curandom import MRG32k3aRandomNumberGenerator + + generator_types.extend([MRG32k3aRandomNumberGenerator]) + + # Test ~2**31 elements for 32-bit types to not exceed ~8GB + dtypes = [np.float32, np.int32, np.uint32] + + for gen_type in generator_types: + gen = gen_type() + + for dtype in dtypes: + if dtype in [np.float32, np.float64]: + gen.gen_normal(2 ** 31, dtype) + # test non-Box-Muller version, if available + gen.gen_normal(2 ** 31 + 1, dtype) + + if get_curand_version() >= (4, 0, 0): + gen.gen_log_normal(2 ** 31, dtype, 10.0, 3.0) + # test non-Box-Muller version, if available + gen.gen_log_normal(2 ** 31 + 1, dtype, 10.0, 3.0) + + x = gen.gen_uniform(2 ** 31, dtype) + if dtype in [np.float32, np.float64]: + x_host = x.get() + assert (-1 <= x_host).all() + assert (x_host <= 1).all() + del x + + if get_curand_version() >= (5, 0, 0): + gen.gen_poisson(2 ** 31, np.uint32, 13.0) + for dtype in [np.float32, np.uint32]: + a = gpuarray.empty(2 ** 31, dtype=dtype) + v = 10 + a.fill(v) + gen.fill_poisson(a) + tmp = (a.get() == (v - 1)).sum() / a.size # noqa: F841 + # Check Poisson statistics (need 1e6 values) + # Compare with scipy.stats.poisson.pmf(v - 
1, v) + assert np.isclose(0.12511, tmp, atol=0.005) + del a + + def test_curand_wrappers_16gb(self): + """ Test random number generation with array sizes of 2**31 (for 8-byte types) + or 2**32 (for 4-byte types) to test sizes beyond the unsigned int range.""" + skip_if_not_enough_gpu_memory(17) + from pycuda.curandom import get_curand_version + + if get_curand_version() is None: + from pytest import skip + + skip("curand not installed") + + generator_types = [] + if get_curand_version() >= (3, 2, 0): + from pycuda.curandom import ( + XORWOWRandomNumberGenerator, + Sobol32RandomNumberGenerator, + ) + + generator_types.extend( + [XORWOWRandomNumberGenerator, Sobol32RandomNumberGenerator] + ) + if get_curand_version() >= (4, 0, 0): + from pycuda.curandom import ( + ScrambledSobol32RandomNumberGenerator, + Sobol64RandomNumberGenerator, + ScrambledSobol64RandomNumberGenerator, + ) + + generator_types.extend( + [ + ScrambledSobol32RandomNumberGenerator, + Sobol64RandomNumberGenerator, + ScrambledSobol64RandomNumberGenerator, + ] + ) + if get_curand_version() >= (4, 1, 0): + from pycuda.curandom import MRG32k3aRandomNumberGenerator + + generator_types.extend([MRG32k3aRandomNumberGenerator]) + + if has_double_support(): + dtypes = [np.float32, np.float64, np.int32, np.uint32] + else: + dtypes = [np.float32, np.int32, np.uint32] + + for gen_type in generator_types: + gen = gen_type() + + for dtype in dtypes: + # test 2**31 for double precision and 2**32 for single + s = 2 ** 31 if dtype == np.float64 else 2 ** 32 + if dtype in [np.float32, np.float64]: + gen.gen_normal(s, dtype) + # test non-Box-Muller version, if available + gen.gen_normal(s + 1, dtype) + + if get_curand_version() >= (4, 0, 0): + gen.gen_log_normal(s, dtype, 10.0, 3.0) + # test non-Box-Muller version, if available + gen.gen_log_normal(s + 1, dtype, 10.0, 3.0) + + x = gen.gen_uniform(s, dtype) + if dtype in [np.float32, np.float64]: + x_host = x.get() + assert (-1 <= x_host).all() + assert (x_host <= 
1).all() + del x + + if get_curand_version() >= (5, 0, 0): + gen.gen_poisson(2 ** 32, np.uint32, 13.0) + for dtype in [np.float32, np.float64, np.uint32]: + s = 2 ** 31 if dtype == np.float64 else 2 ** 32 + a = gpuarray.empty(s, dtype=dtype) + v = 10 + a.fill(v) + gen.fill_poisson(a) + tmp = (a.get() == (v - 1)).sum() / a.size # noqa: F841 + # Check Poisson statistics (need 1e6 values) + # Compare with scipy.stats.poisson.pmf(v - 1, v) + assert np.isclose(0.12511, tmp, atol=0.005) + del a + def test_array_gt(self): """Test whether array contents are > the other array's contents""" @@ -508,7 +668,6 @@ class TestGPUArray: slice(1000, -1), ] ): - a_gpu = gpuarray.zeros((50000,), dtype=np.float32) a_cpu = np.zeros(a_gpu.shape, a_gpu.dtype) @@ -1429,7 +1588,7 @@ class TestGPUArray: def test_truth_value(self): for i in range(5): - shape = (1,)*i + shape = (1,) * i zeros = gpuarray.zeros(shape, dtype="float32") ones = gpuarray.ones(shape, dtype="float32") assert bool(ones) @@ -1453,7 +1612,7 @@ class TestGPUArray: from pycuda import cumath rng = np.random.default_rng(seed=0) - x_np = rng.random((10, 4)) + dtype(1j)*rng.random((10, 4)) + x_np = rng.random((10, 4)) + dtype(1j) * rng.random((10, 4)) x_cu = gpuarray.to_gpu(x_np) np.testing.assert_allclose(cumath.log10(x_cu).get(), np.log10(x_np), rtol=rtol) @@ -1489,7 +1648,7 @@ class TestGPUArray: skip_if_not_enough_gpu_memory(4.5) from pycuda.elementwise import ElementwiseKernel - n_items = 2**32 + n_items = 2 ** 32 eltwise = ElementwiseKernel( "unsigned char* d_arr", @@ -1508,7 +1667,7 @@ class TestGPUArray: skip_if_not_enough_gpu_memory(4.5) from pycuda.reduction import ReductionKernel - n_items = 2**32 + 11 + n_items = 2 ** 32 + 11 reduction = ReductionKernel( np.uint8, neutral="0", @@ -1524,7 +1683,7 @@ class TestGPUArray: def test_big_array_scan(self): skip_if_not_enough_gpu_memory(4.5) - n_items = 2**32 + 12 + n_items = 2 ** 32 + 12 from pycuda.scan import InclusiveScanKernel cumsum = InclusiveScanKernel(np.uint8, 
"(a+b) & 0b11111111") @@ -1533,11 +1692,11 @@ class TestGPUArray: result = cumsum(d_arr).get()[()] # Needs 8.6 GB on host. numpy.allclose() is way too slow otherwise. reference = np.tile( - np.roll(np.arange(256, dtype=np.int16), -1), n_items//256 + np.roll(np.arange(256, dtype=np.int16), -1), n_items // 256 ) reference -= result[:reference.size] assert np.max(reference) == 0 - assert np.allclose(result[2**32:], np.arange(1, 12+1)) + assert np.allclose(result[2 ** 32:], np.arange(1, 12 + 1)) def test_noncontig_transpose(self): # https://github.com/inducer/pycuda/issues/385 -- GitLab From 76b691c9b2393fe92ee542c237b70e7a8b078991 Mon Sep 17 00:00:00 2001 From: Vincent Favre-Nicolin Date: Tue, 13 Jun 2023 18:02:32 +0200 Subject: [PATCH 2/2] Remove unused imports --- test/test_gpuarray.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/test_gpuarray.py b/test/test_gpuarray.py index 9b0d9fe2..68ed63ba 100644 --- a/test/test_gpuarray.py +++ b/test/test_gpuarray.py @@ -1,6 +1,4 @@ #! /usr/bin/env python -import unittest -import os import numpy as np import numpy.linalg as la -- GitLab