gonum.org/v1/gonum@v0.14.0/lapack/testlapack/dgetrs.go (about) 1 // Copyright ©2015 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "testing" 9 10 "golang.org/x/exp/rand" 11 12 "gonum.org/v1/gonum/blas" 13 "gonum.org/v1/gonum/blas/blas64" 14 "gonum.org/v1/gonum/floats" 15 ) 16 17 type Dgetrser interface { 18 Dgetrfer 19 Dgetrs(trans blas.Transpose, n, nrhs int, a []float64, lda int, ipiv []int, b []float64, ldb int) 20 } 21 22 func DgetrsTest(t *testing.T, impl Dgetrser) { 23 rnd := rand.New(rand.NewSource(1)) 24 // TODO(btracey): Put more thought into creating more regularized matrices 25 // and what correct tolerances should be. Consider also seeding the random 26 // number in this test to make it more robust to code changes in other 27 // parts of the suite. 28 for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans} { 29 for _, test := range []struct { 30 n, nrhs, lda, ldb int 31 tol float64 32 }{ 33 {3, 3, 0, 0, 1e-12}, 34 {3, 5, 0, 0, 1e-12}, 35 {5, 3, 0, 0, 1e-12}, 36 37 {3, 3, 8, 10, 1e-12}, 38 {3, 5, 8, 10, 1e-12}, 39 {5, 3, 8, 10, 1e-12}, 40 41 {300, 300, 0, 0, 1e-8}, 42 {300, 500, 0, 0, 1e-8}, 43 {500, 300, 0, 0, 1e-6}, 44 45 {300, 300, 700, 600, 1e-8}, 46 {300, 500, 700, 600, 1e-8}, 47 {500, 300, 700, 600, 1e-6}, 48 } { 49 n := test.n 50 nrhs := test.nrhs 51 lda := test.lda 52 if lda == 0 { 53 lda = n 54 } 55 ldb := test.ldb 56 if ldb == 0 { 57 ldb = nrhs 58 } 59 a := make([]float64, n*lda) 60 for i := range a { 61 a[i] = rnd.Float64() 62 } 63 b := make([]float64, n*ldb) 64 for i := range b { 65 b[i] = rnd.Float64() 66 } 67 aCopy := make([]float64, len(a)) 68 copy(aCopy, a) 69 bCopy := make([]float64, len(b)) 70 copy(bCopy, b) 71 72 ipiv := make([]int, n) 73 for i := range ipiv { 74 ipiv[i] = rnd.Int() 75 } 76 77 // Compute the LU factorization. 78 impl.Dgetrf(n, n, a, lda, ipiv) 79 // Solve the system of equations given the result. 80 impl.Dgetrs(trans, n, nrhs, a, lda, ipiv, b, ldb) 81 82 // Check that the system of equations holds. 83 A := blas64.General{ 84 Rows: n, 85 Cols: n, 86 Stride: lda, 87 Data: aCopy, 88 } 89 B := blas64.General{ 90 Rows: n, 91 Cols: nrhs, 92 Stride: ldb, 93 Data: bCopy, 94 } 95 X := blas64.General{ 96 Rows: n, 97 Cols: nrhs, 98 Stride: ldb, 99 Data: b, 100 } 101 tmp := blas64.General{ 102 Rows: n, 103 Cols: nrhs, 104 Stride: ldb, 105 Data: make([]float64, n*ldb), 106 } 107 copy(tmp.Data, bCopy) 108 blas64.Gemm(trans, blas.NoTrans, 1, A, X, 0, B) 109 if !floats.EqualApprox(tmp.Data, bCopy, test.tol) { 110 t.Errorf("Linear solve mismatch. trans = %v, n = %v, nrhs = %v, lda = %v, ldb = %v", trans, n, nrhs, lda, ldb) 111 } 112 } 113 } 114 }