github.com/gopherd/gonum@v0.0.4/lapack/testlapack/dlarfb.go (about) 1 // Copyright ©2015 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "testing" 9 10 "math/rand" 11 12 "github.com/gopherd/gonum/blas" 13 "github.com/gopherd/gonum/blas/blas64" 14 "github.com/gopherd/gonum/floats" 15 "github.com/gopherd/gonum/lapack" 16 ) 17 18 type Dlarfber interface { 19 Dlarfter 20 Dlarfb(side blas.Side, trans blas.Transpose, direct lapack.Direct, 21 store lapack.StoreV, m, n, k int, v []float64, ldv int, t []float64, ldt int, 22 c []float64, ldc int, work []float64, ldwork int) 23 } 24 25 func DlarfbTest(t *testing.T, impl Dlarfber) { 26 rnd := rand.New(rand.NewSource(1)) 27 for _, store := range []lapack.StoreV{lapack.ColumnWise, lapack.RowWise} { 28 for _, direct := range []lapack.Direct{lapack.Forward, lapack.Backward} { 29 for _, side := range []blas.Side{blas.Left, blas.Right} { 30 for _, trans := range []blas.Transpose{blas.Trans, blas.NoTrans} { 31 for cas, test := range []struct { 32 ma, na, cdim, lda, ldt, ldc int 33 }{ 34 {6, 6, 6, 0, 0, 0}, 35 {6, 8, 10, 0, 0, 0}, 36 {6, 10, 8, 0, 0, 0}, 37 {8, 6, 10, 0, 0, 0}, 38 {8, 10, 6, 0, 0, 0}, 39 {10, 6, 8, 0, 0, 0}, 40 {10, 8, 6, 0, 0, 0}, 41 {6, 6, 6, 12, 15, 30}, 42 {6, 8, 10, 12, 15, 30}, 43 {6, 10, 8, 12, 15, 30}, 44 {8, 6, 10, 12, 15, 30}, 45 {8, 10, 6, 12, 15, 30}, 46 {10, 6, 8, 12, 15, 30}, 47 {10, 8, 6, 12, 15, 30}, 48 {6, 6, 6, 15, 12, 30}, 49 {6, 8, 10, 15, 12, 30}, 50 {6, 10, 8, 15, 12, 30}, 51 {8, 6, 10, 15, 12, 30}, 52 {8, 10, 6, 15, 12, 30}, 53 {10, 6, 8, 15, 12, 30}, 54 {10, 8, 6, 15, 12, 30}, 55 } { 56 // Generate a matrix for QR 57 ma := test.ma 58 na := test.na 59 lda := test.lda 60 if lda == 0 { 61 lda = na 62 } 63 a := make([]float64, ma*lda) 64 for i := 0; i < ma; i++ { 65 for j := 0; j < lda; j++ { 66 a[i*lda+j] = rnd.Float64() 67 } 68 } 69 k := min(ma, na) 70 71 // H is always ma x ma 72 var m, n, rowsWork int 73 switch { 74 default: 75 panic("not implemented") 76 case side == blas.Left: 77 m = test.ma 78 n = test.cdim 79 rowsWork = n 80 case side == blas.Right: 81 m = test.cdim 82 n = test.ma 83 rowsWork = m 84 } 85 86 // Use dgeqr2 to find the v vectors 87 tau := make([]float64, na) 88 work := make([]float64, na) 89 impl.Dgeqr2(ma, k, a, lda, tau, work) 90 91 // Correct the v vectors based on the direct and store 92 vMatTmp := extractVMat(ma, na, a, lda, lapack.Forward, lapack.ColumnWise) 93 vMat := constructVMat(vMatTmp, store, direct) 94 v := vMat.Data 95 ldv := vMat.Stride 96 97 // Use dlarft to find the t vector 98 ldt := test.ldt 99 if ldt == 0 { 100 ldt = k 101 } 102 tm := make([]float64, k*ldt) 103 104 impl.Dlarft(direct, store, ma, k, v, ldv, tau, tm, ldt) 105 106 // Generate c matrix 107 ldc := test.ldc 108 if ldc == 0 { 109 ldc = n 110 } 111 c := make([]float64, m*ldc) 112 for i := 0; i < m; i++ { 113 for j := 0; j < ldc; j++ { 114 c[i*ldc+j] = rnd.Float64() 115 } 116 } 117 cCopy := make([]float64, len(c)) 118 copy(cCopy, c) 119 120 ldwork := k 121 work = make([]float64, rowsWork*k) 122 123 // Call Dlarfb with this information 124 impl.Dlarfb(side, trans, direct, store, m, n, k, v, ldv, tm, ldt, c, ldc, work, ldwork) 125 126 h := constructH(tau, vMat, store, direct) 127 128 cMat := blas64.General{ 129 Rows: m, 130 Cols: n, 131 Stride: ldc, 132 Data: make([]float64, m*ldc), 133 } 134 copy(cMat.Data, cCopy) 135 ans := blas64.General{ 136 Rows: m, 137 Cols: n, 138 Stride: ldc, 139 Data: make([]float64, m*ldc), 140 } 141 copy(ans.Data, cMat.Data) 142 switch { 143 default: 144 panic("not implemented") 145 case side == blas.Left && trans == blas.NoTrans: 146 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, h, cMat, 0, ans) 147 case side == blas.Left && trans == blas.Trans: 148 blas64.Gemm(blas.Trans, blas.NoTrans, 1, h, cMat, 0, ans) 149 case side == blas.Right && trans == blas.NoTrans: 150 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, cMat, h, 0, ans) 151 case side == blas.Right && trans == blas.Trans: 152 blas64.Gemm(blas.NoTrans, blas.Trans, 1, cMat, h, 0, ans) 153 } 154 if !floats.EqualApprox(ans.Data, c, 1e-14) { 155 t.Errorf("Cas %v mismatch. Want %v, got %v.", cas, ans.Data, c) 156 } 157 } 158 } 159 } 160 } 161 } 162 }