github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/lapack/testlapack/dlarft.go (about) 1 // Copyright ©2015 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "testing" 9 10 "golang.org/x/exp/rand" 11 12 "github.com/jingcheng-WU/gonum/blas" 13 "github.com/jingcheng-WU/gonum/blas/blas64" 14 "github.com/jingcheng-WU/gonum/floats" 15 "github.com/jingcheng-WU/gonum/lapack" 16 ) 17 18 type Dlarfter interface { 19 Dgeqr2er 20 Dlarft(direct lapack.Direct, store lapack.StoreV, n, k int, v []float64, ldv int, tau []float64, t []float64, ldt int) 21 } 22 23 func DlarftTest(t *testing.T, impl Dlarfter) { 24 rnd := rand.New(rand.NewSource(1)) 25 for _, store := range []lapack.StoreV{lapack.ColumnWise, lapack.RowWise} { 26 for _, direct := range []lapack.Direct{lapack.Forward, lapack.Backward} { 27 for _, test := range []struct { 28 m, n, ldv, ldt int 29 }{ 30 {6, 6, 0, 0}, 31 {8, 6, 0, 0}, 32 {6, 8, 0, 0}, 33 {6, 6, 10, 15}, 34 {8, 6, 10, 15}, 35 {6, 8, 10, 15}, 36 {6, 6, 15, 10}, 37 {8, 6, 15, 10}, 38 {6, 8, 15, 10}, 39 } { 40 // Generate a matrix 41 m := test.m 42 n := test.n 43 lda := n 44 if lda == 0 { 45 lda = n 46 } 47 48 a := make([]float64, m*lda) 49 for i := 0; i < m; i++ { 50 for j := 0; j < lda; j++ { 51 a[i*lda+j] = rnd.Float64() 52 } 53 } 54 // Use dgeqr2 to find the v vectors 55 tau := make([]float64, n) 56 work := make([]float64, n) 57 impl.Dgeqr2(m, n, a, lda, tau, work) 58 59 // Construct H using these answers 60 vMatTmp := extractVMat(m, n, a, lda, lapack.Forward, lapack.ColumnWise) 61 vMat := constructVMat(vMatTmp, store, direct) 62 v := vMat.Data 63 ldv := vMat.Stride 64 65 h := constructH(tau, vMat, store, direct) 66 67 k := min(m, n) 68 ldt := test.ldt 69 if ldt == 0 { 70 ldt = k 71 } 72 // Find T from the actual function 73 tm := make([]float64, k*ldt) 74 for i := range tm { 75 tm[i] = 100 + rnd.Float64() 76 } 77 // The v data has been put into a. 78 impl.Dlarft(direct, store, m, k, v, ldv, tau, tm, ldt) 79 80 tData := make([]float64, len(tm)) 81 copy(tData, tm) 82 if direct == lapack.Forward { 83 // Zero out the lower traingular portion. 84 for i := 0; i < k; i++ { 85 for j := 0; j < i; j++ { 86 tData[i*ldt+j] = 0 87 } 88 } 89 } else { 90 // Zero out the upper traingular portion. 91 for i := 0; i < k; i++ { 92 for j := i + 1; j < k; j++ { 93 tData[i*ldt+j] = 0 94 } 95 } 96 } 97 98 T := blas64.General{ 99 Rows: k, 100 Cols: k, 101 Stride: ldt, 102 Data: tData, 103 } 104 105 vMatT := blas64.General{ 106 Rows: vMat.Cols, 107 Cols: vMat.Rows, 108 Stride: vMat.Rows, 109 Data: make([]float64, vMat.Cols*vMat.Rows), 110 } 111 for i := 0; i < vMat.Rows; i++ { 112 for j := 0; j < vMat.Cols; j++ { 113 vMatT.Data[j*vMatT.Stride+i] = vMat.Data[i*vMat.Stride+j] 114 } 115 } 116 var comp blas64.General 117 if store == lapack.ColumnWise { 118 // H = I - V * T * Vᵀ 119 tmp := blas64.General{ 120 Rows: T.Rows, 121 Cols: vMatT.Cols, 122 Stride: vMatT.Cols, 123 Data: make([]float64, T.Rows*vMatT.Cols), 124 } 125 // T * Vᵀ 126 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, T, vMatT, 0, tmp) 127 comp = blas64.General{ 128 Rows: vMat.Rows, 129 Cols: tmp.Cols, 130 Stride: tmp.Cols, 131 Data: make([]float64, vMat.Rows*tmp.Cols), 132 } 133 // V * (T * Vᵀ) 134 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, vMat, tmp, 0, comp) 135 } else { 136 // H = I - Vᵀ * T * V 137 tmp := blas64.General{ 138 Rows: T.Rows, 139 Cols: vMat.Cols, 140 Stride: vMat.Cols, 141 Data: make([]float64, T.Rows*vMat.Cols), 142 } 143 // T * V 144 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, T, vMat, 0, tmp) 145 comp = blas64.General{ 146 Rows: vMatT.Rows, 147 Cols: tmp.Cols, 148 Stride: tmp.Cols, 149 Data: make([]float64, vMatT.Rows*tmp.Cols), 150 } 151 // Vᵀ * (T * V) 152 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, vMatT, tmp, 0, comp) 153 } 154 // I - Vᵀ * T * V 155 for i := 0; i < comp.Rows; i++ { 156 for j := 0; j < comp.Cols; j++ { 157 comp.Data[i*m+j] *= -1 158 if i == j { 159 comp.Data[i*m+j] += 1 160 } 161 } 162 } 163 if !floats.EqualApprox(comp.Data, h.Data, 1e-14) { 164 t.Errorf("T does not construct proper H. Store = %v, Direct = %v.\nWant %v\ngot %v.", string(store), string(direct), h.Data, comp.Data) 165 } 166 } 167 } 168 } 169 }