github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dgetrs.go (about) 1 // Copyright ©2015 The gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "math/rand" 9 "testing" 10 11 "github.com/gonum/blas" 12 "github.com/gonum/blas/blas64" 13 "github.com/gonum/floats" 14 ) 15 16 type Dgetrser interface { 17 Dgetrfer 18 Dgetrs(trans blas.Transpose, n, nrhs int, a []float64, lda int, ipiv []int, b []float64, ldb int) 19 } 20 21 func DgetrsTest(t *testing.T, impl Dgetrser) { 22 rnd := rand.New(rand.NewSource(1)) 23 // TODO(btracey): Put more thought into creating more regularized matrices 24 // and what correct tolerances should be. Consider also seeding the random 25 // number in this test to make it more robust to code changes in other 26 // parts of the suite. 27 for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans} { 28 for _, test := range []struct { 29 n, nrhs, lda, ldb int 30 tol float64 31 }{ 32 {3, 3, 0, 0, 1e-12}, 33 {3, 5, 0, 0, 1e-12}, 34 {5, 3, 0, 0, 1e-12}, 35 36 {3, 3, 8, 10, 1e-12}, 37 {3, 5, 8, 10, 1e-12}, 38 {5, 3, 8, 10, 1e-12}, 39 40 {300, 300, 0, 0, 1e-8}, 41 {300, 500, 0, 0, 1e-8}, 42 {500, 300, 0, 0, 1e-6}, 43 44 {300, 300, 700, 600, 1e-8}, 45 {300, 500, 700, 600, 1e-8}, 46 {500, 300, 700, 600, 1e-6}, 47 } { 48 n := test.n 49 nrhs := test.nrhs 50 lda := test.lda 51 if lda == 0 { 52 lda = n 53 } 54 ldb := test.ldb 55 if ldb == 0 { 56 ldb = nrhs 57 } 58 a := make([]float64, n*lda) 59 for i := range a { 60 a[i] = rnd.Float64() 61 } 62 b := make([]float64, n*ldb) 63 for i := range b { 64 b[i] = rnd.Float64() 65 } 66 aCopy := make([]float64, len(a)) 67 copy(aCopy, a) 68 bCopy := make([]float64, len(b)) 69 copy(bCopy, b) 70 71 ipiv := make([]int, n) 72 for i := range ipiv { 73 ipiv[i] = rnd.Int() 74 } 75 76 // Compute the LU factorization. 77 impl.Dgetrf(n, n, a, lda, ipiv) 78 // Solve the system of equations given the result. 79 impl.Dgetrs(trans, n, nrhs, a, lda, ipiv, b, ldb) 80 81 // Check that the system of equations holds. 82 A := blas64.General{ 83 Rows: n, 84 Cols: n, 85 Stride: lda, 86 Data: aCopy, 87 } 88 B := blas64.General{ 89 Rows: n, 90 Cols: nrhs, 91 Stride: ldb, 92 Data: bCopy, 93 } 94 X := blas64.General{ 95 Rows: n, 96 Cols: nrhs, 97 Stride: ldb, 98 Data: b, 99 } 100 tmp := blas64.General{ 101 Rows: n, 102 Cols: nrhs, 103 Stride: ldb, 104 Data: make([]float64, n*ldb), 105 } 106 copy(tmp.Data, bCopy) 107 blas64.Gemm(trans, blas.NoTrans, 1, A, X, 0, B) 108 if !floats.EqualApprox(tmp.Data, bCopy, test.tol) { 109 t.Errorf("Linear solve mismatch. trans = %v, n = %v, nrhs = %v, lda = %v, ldb = %v", trans, n, nrhs, lda, ldb) 110 } 111 } 112 } 113 }