gonum.org/v1/gonum@v0.14.0/lapack/testlapack/dgels.go (about) 1 // Copyright ©2015 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "testing" 9 10 "golang.org/x/exp/rand" 11 12 "gonum.org/v1/gonum/blas" 13 "gonum.org/v1/gonum/blas/blas64" 14 "gonum.org/v1/gonum/floats" 15 ) 16 17 type Dgelser interface { 18 Dgels(trans blas.Transpose, m, n, nrhs int, a []float64, lda int, b []float64, ldb int, work []float64, lwork int) bool 19 } 20 21 func DgelsTest(t *testing.T, impl Dgelser) { 22 rnd := rand.New(rand.NewSource(1)) 23 for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans} { 24 for _, test := range []struct { 25 m, n, nrhs, lda, ldb int 26 }{ 27 {3, 4, 5, 0, 0}, 28 {3, 5, 4, 0, 0}, 29 {4, 3, 5, 0, 0}, 30 {4, 5, 3, 0, 0}, 31 {5, 3, 4, 0, 0}, 32 {5, 4, 3, 0, 0}, 33 {3, 4, 5, 10, 20}, 34 {3, 5, 4, 10, 20}, 35 {4, 3, 5, 10, 20}, 36 {4, 5, 3, 10, 20}, 37 {5, 3, 4, 10, 20}, 38 {5, 4, 3, 10, 20}, 39 {3, 4, 5, 20, 10}, 40 {3, 5, 4, 20, 10}, 41 {4, 3, 5, 20, 10}, 42 {4, 5, 3, 20, 10}, 43 {5, 3, 4, 20, 10}, 44 {5, 4, 3, 20, 10}, 45 {200, 300, 400, 0, 0}, 46 {200, 400, 300, 0, 0}, 47 {300, 200, 400, 0, 0}, 48 {300, 400, 200, 0, 0}, 49 {400, 200, 300, 0, 0}, 50 {400, 300, 200, 0, 0}, 51 {200, 300, 400, 500, 600}, 52 {200, 400, 300, 500, 600}, 53 {300, 200, 400, 500, 600}, 54 {300, 400, 200, 500, 600}, 55 {400, 200, 300, 500, 600}, 56 {400, 300, 200, 500, 600}, 57 {200, 300, 400, 600, 500}, 58 {200, 400, 300, 600, 500}, 59 {300, 200, 400, 600, 500}, 60 {300, 400, 200, 600, 500}, 61 {400, 200, 300, 600, 500}, 62 {400, 300, 200, 600, 500}, 63 } { 64 m := test.m 65 n := test.n 66 nrhs := test.nrhs 67 68 lda := test.lda 69 if lda == 0 { 70 lda = n 71 } 72 a := make([]float64, m*lda) 73 for i := range a { 74 a[i] = rnd.Float64() 75 } 76 aCopy := make([]float64, len(a)) 77 copy(aCopy, a) 78 79 // Size of b is the same trans or no trans, because the number of rows 80 // has to be the max of (m,n). 81 mb := max(m, n) 82 nb := nrhs 83 ldb := test.ldb 84 if ldb == 0 { 85 ldb = nb 86 } 87 b := make([]float64, mb*ldb) 88 for i := range b { 89 b[i] = rnd.Float64() 90 } 91 bCopy := make([]float64, len(b)) 92 copy(bCopy, b) 93 94 // Find optimal work length. 95 work := make([]float64, 1) 96 impl.Dgels(trans, m, n, nrhs, a, lda, b, ldb, work, -1) 97 98 // Perform linear solve 99 work = make([]float64, int(work[0])) 100 lwork := len(work) 101 for i := range work { 102 work[i] = rnd.Float64() 103 } 104 impl.Dgels(trans, m, n, nrhs, a, lda, b, ldb, work, lwork) 105 106 // Check that the answer is correct by comparing to the normal equations. 107 aMat := blas64.General{ 108 Rows: m, 109 Cols: n, 110 Stride: lda, 111 Data: make([]float64, len(aCopy)), 112 } 113 copy(aMat.Data, aCopy) 114 szAta := n 115 if trans == blas.Trans { 116 szAta = m 117 } 118 aTA := blas64.General{ 119 Rows: szAta, 120 Cols: szAta, 121 Stride: szAta, 122 Data: make([]float64, szAta*szAta), 123 } 124 125 // Compute Aᵀ * A if notrans and A * Aᵀ otherwise. 126 if trans == blas.NoTrans { 127 blas64.Gemm(blas.Trans, blas.NoTrans, 1, aMat, aMat, 0, aTA) 128 } else { 129 blas64.Gemm(blas.NoTrans, blas.Trans, 1, aMat, aMat, 0, aTA) 130 } 131 132 // Multiply by X. 133 X := blas64.General{ 134 Rows: szAta, 135 Cols: nrhs, 136 Stride: ldb, 137 Data: b, 138 } 139 ans := blas64.General{ 140 Rows: aTA.Rows, 141 Cols: X.Cols, 142 Stride: X.Cols, 143 Data: make([]float64, aTA.Rows*X.Cols), 144 } 145 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aTA, X, 0, ans) 146 147 B := blas64.General{ 148 Rows: szAta, 149 Cols: nrhs, 150 Stride: ldb, 151 Data: make([]float64, len(bCopy)), 152 } 153 154 copy(B.Data, bCopy) 155 var ans2 blas64.General 156 if trans == blas.NoTrans { 157 ans2 = blas64.General{ 158 Rows: aMat.Cols, 159 Cols: B.Cols, 160 Stride: B.Cols, 161 Data: make([]float64, aMat.Cols*B.Cols), 162 } 163 } else { 164 ans2 = blas64.General{ 165 Rows: aMat.Rows, 166 Cols: B.Cols, 167 Stride: B.Cols, 168 Data: make([]float64, aMat.Rows*B.Cols), 169 } 170 } 171 172 // Compute Aᵀ B if Trans or A * B otherwise 173 if trans == blas.NoTrans { 174 blas64.Gemm(blas.Trans, blas.NoTrans, 1, aMat, B, 0, ans2) 175 } else { 176 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aMat, B, 0, ans2) 177 } 178 if !floats.EqualApprox(ans.Data, ans2.Data, 1e-12) { 179 t.Errorf("Normal equations not satisfied") 180 } 181 } 182 } 183 }