github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dgels.go (about) 1 // Copyright ©2015 The gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "math/rand" 9 "testing" 10 11 "github.com/gonum/blas" 12 "github.com/gonum/blas/blas64" 13 "github.com/gonum/floats" 14 ) 15 16 type Dgelser interface { 17 Dgels(trans blas.Transpose, m, n, nrhs int, a []float64, lda int, b []float64, ldb int, work []float64, lwork int) bool 18 } 19 20 func DgelsTest(t *testing.T, impl Dgelser) { 21 rnd := rand.New(rand.NewSource(1)) 22 for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans} { 23 for _, test := range []struct { 24 m, n, nrhs, lda, ldb int 25 }{ 26 {3, 4, 5, 0, 0}, 27 {3, 5, 4, 0, 0}, 28 {4, 3, 5, 0, 0}, 29 {4, 5, 3, 0, 0}, 30 {5, 3, 4, 0, 0}, 31 {5, 4, 3, 0, 0}, 32 {3, 4, 5, 10, 20}, 33 {3, 5, 4, 10, 20}, 34 {4, 3, 5, 10, 20}, 35 {4, 5, 3, 10, 20}, 36 {5, 3, 4, 10, 20}, 37 {5, 4, 3, 10, 20}, 38 {3, 4, 5, 20, 10}, 39 {3, 5, 4, 20, 10}, 40 {4, 3, 5, 20, 10}, 41 {4, 5, 3, 20, 10}, 42 {5, 3, 4, 20, 10}, 43 {5, 4, 3, 20, 10}, 44 {200, 300, 400, 0, 0}, 45 {200, 400, 300, 0, 0}, 46 {300, 200, 400, 0, 0}, 47 {300, 400, 200, 0, 0}, 48 {400, 200, 300, 0, 0}, 49 {400, 300, 200, 0, 0}, 50 {200, 300, 400, 500, 600}, 51 {200, 400, 300, 500, 600}, 52 {300, 200, 400, 500, 600}, 53 {300, 400, 200, 500, 600}, 54 {400, 200, 300, 500, 600}, 55 {400, 300, 200, 500, 600}, 56 {200, 300, 400, 600, 500}, 57 {200, 400, 300, 600, 500}, 58 {300, 200, 400, 600, 500}, 59 {300, 400, 200, 600, 500}, 60 {400, 200, 300, 600, 500}, 61 {400, 300, 200, 600, 500}, 62 } { 63 m := test.m 64 n := test.n 65 nrhs := test.nrhs 66 67 lda := test.lda 68 if lda == 0 { 69 lda = n 70 } 71 a := make([]float64, m*lda) 72 for i := range a { 73 a[i] = rnd.Float64() 74 } 75 aCopy := make([]float64, len(a)) 76 copy(aCopy, a) 77 78 // Size of b is the same trans or no trans, because the number of rows 79 // has to be the max of (m,n). 80 mb := max(m, n) 81 nb := nrhs 82 ldb := test.ldb 83 if ldb == 0 { 84 ldb = nb 85 } 86 b := make([]float64, mb*ldb) 87 for i := range b { 88 b[i] = rnd.Float64() 89 } 90 bCopy := make([]float64, len(b)) 91 copy(bCopy, b) 92 93 // Find optimal work length. 94 work := make([]float64, 1) 95 impl.Dgels(trans, m, n, nrhs, a, lda, b, ldb, work, -1) 96 97 // Perform linear solve 98 work = make([]float64, int(work[0])) 99 lwork := len(work) 100 for i := range work { 101 work[i] = rnd.Float64() 102 } 103 impl.Dgels(trans, m, n, nrhs, a, lda, b, ldb, work, lwork) 104 105 // Check that the answer is correct by comparing to the normal equations. 106 aMat := blas64.General{ 107 Rows: m, 108 Cols: n, 109 Stride: lda, 110 Data: make([]float64, len(aCopy)), 111 } 112 copy(aMat.Data, aCopy) 113 szAta := n 114 if trans == blas.Trans { 115 szAta = m 116 } 117 aTA := blas64.General{ 118 Rows: szAta, 119 Cols: szAta, 120 Stride: szAta, 121 Data: make([]float64, szAta*szAta), 122 } 123 124 // Compute A^T * A if notrans and A * A^T otherwise. 125 if trans == blas.NoTrans { 126 blas64.Gemm(blas.Trans, blas.NoTrans, 1, aMat, aMat, 0, aTA) 127 } else { 128 blas64.Gemm(blas.NoTrans, blas.Trans, 1, aMat, aMat, 0, aTA) 129 } 130 131 // Multiply by X. 132 X := blas64.General{ 133 Rows: szAta, 134 Cols: nrhs, 135 Stride: ldb, 136 Data: b, 137 } 138 ans := blas64.General{ 139 Rows: aTA.Rows, 140 Cols: X.Cols, 141 Stride: X.Cols, 142 Data: make([]float64, aTA.Rows*X.Cols), 143 } 144 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aTA, X, 0, ans) 145 146 B := blas64.General{ 147 Rows: szAta, 148 Cols: nrhs, 149 Stride: ldb, 150 Data: make([]float64, len(bCopy)), 151 } 152 153 copy(B.Data, bCopy) 154 var ans2 blas64.General 155 if trans == blas.NoTrans { 156 ans2 = blas64.General{ 157 Rows: aMat.Cols, 158 Cols: B.Cols, 159 Stride: B.Cols, 160 Data: make([]float64, aMat.Cols*B.Cols), 161 } 162 } else { 163 ans2 = blas64.General{ 164 Rows: aMat.Rows, 165 Cols: B.Cols, 166 Stride: B.Cols, 167 Data: make([]float64, aMat.Rows*B.Cols), 168 } 169 } 170 171 // Compute A^T B if Trans or A * B otherwise 172 if trans == blas.NoTrans { 173 blas64.Gemm(blas.Trans, blas.NoTrans, 1, aMat, B, 0, ans2) 174 } else { 175 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aMat, B, 0, ans2) 176 } 177 if !floats.EqualApprox(ans.Data, ans2.Data, 1e-12) { 178 t.Errorf("Normal equations not satisfied") 179 } 180 } 181 } 182 }