github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dlarfb.go (about) 1 // Copyright ©2015 The gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "math/rand" 9 "testing" 10 11 "github.com/gonum/blas" 12 "github.com/gonum/blas/blas64" 13 "github.com/gonum/floats" 14 "github.com/gonum/lapack" 15 ) 16 17 type Dlarfber interface { 18 Dlarfter 19 Dlarfb(side blas.Side, trans blas.Transpose, direct lapack.Direct, 20 store lapack.StoreV, m, n, k int, v []float64, ldv int, t []float64, ldt int, 21 c []float64, ldc int, work []float64, ldwork int) 22 } 23 24 func DlarfbTest(t *testing.T, impl Dlarfber) { 25 rnd := rand.New(rand.NewSource(1)) 26 for _, store := range []lapack.StoreV{lapack.ColumnWise, lapack.RowWise} { 27 for _, direct := range []lapack.Direct{lapack.Forward, lapack.Backward} { 28 for _, side := range []blas.Side{blas.Left, blas.Right} { 29 for _, trans := range []blas.Transpose{blas.Trans, blas.NoTrans} { 30 for cas, test := range []struct { 31 ma, na, cdim, lda, ldt, ldc int 32 }{ 33 {6, 6, 6, 0, 0, 0}, 34 {6, 8, 10, 0, 0, 0}, 35 {6, 10, 8, 0, 0, 0}, 36 {8, 6, 10, 0, 0, 0}, 37 {8, 10, 6, 0, 0, 0}, 38 {10, 6, 8, 0, 0, 0}, 39 {10, 8, 6, 0, 0, 0}, 40 {6, 6, 6, 12, 15, 30}, 41 {6, 8, 10, 12, 15, 30}, 42 {6, 10, 8, 12, 15, 30}, 43 {8, 6, 10, 12, 15, 30}, 44 {8, 10, 6, 12, 15, 30}, 45 {10, 6, 8, 12, 15, 30}, 46 {10, 8, 6, 12, 15, 30}, 47 {6, 6, 6, 15, 12, 30}, 48 {6, 8, 10, 15, 12, 30}, 49 {6, 10, 8, 15, 12, 30}, 50 {8, 6, 10, 15, 12, 30}, 51 {8, 10, 6, 15, 12, 30}, 52 {10, 6, 8, 15, 12, 30}, 53 {10, 8, 6, 15, 12, 30}, 54 } { 55 // Generate a matrix for QR 56 ma := test.ma 57 na := test.na 58 lda := test.lda 59 if lda == 0 { 60 lda = na 61 } 62 a := make([]float64, ma*lda) 63 for i := 0; i < ma; i++ { 64 for j := 0; j < lda; j++ { 65 a[i*lda+j] = rnd.Float64() 66 } 67 } 68 k := min(ma, na) 69 70 // H is always ma x ma 71 var m, n, rowsWork int 72 switch { 73 default: 74 panic("not implemented") 75 case side == blas.Left: 76 m = test.ma 77 n = test.cdim 78 rowsWork = n 79 case side == blas.Right: 80 m = test.cdim 81 n = test.ma 82 rowsWork = m 83 } 84 85 // Use dgeqr2 to find the v vectors 86 tau := make([]float64, na) 87 work := make([]float64, na) 88 impl.Dgeqr2(ma, k, a, lda, tau, work) 89 90 // Correct the v vectors based on the direct and store 91 vMatTmp := extractVMat(ma, na, a, lda, lapack.Forward, lapack.ColumnWise) 92 vMat := constructVMat(vMatTmp, store, direct) 93 v := vMat.Data 94 ldv := vMat.Stride 95 96 // Use dlarft to find the t vector 97 ldt := test.ldt 98 if ldt == 0 { 99 ldt = k 100 } 101 tm := make([]float64, k*ldt) 102 103 impl.Dlarft(direct, store, ma, k, v, ldv, tau, tm, ldt) 104 105 // Generate c matrix 106 ldc := test.ldc 107 if ldc == 0 { 108 ldc = n 109 } 110 c := make([]float64, m*ldc) 111 for i := 0; i < m; i++ { 112 for j := 0; j < ldc; j++ { 113 c[i*ldc+j] = rnd.Float64() 114 } 115 } 116 cCopy := make([]float64, len(c)) 117 copy(cCopy, c) 118 119 ldwork := k 120 work = make([]float64, rowsWork*k) 121 122 // Call Dlarfb with this information 123 impl.Dlarfb(side, trans, direct, store, m, n, k, v, ldv, tm, ldt, c, ldc, work, ldwork) 124 125 h := constructH(tau, vMat, store, direct) 126 127 cMat := blas64.General{ 128 Rows: m, 129 Cols: n, 130 Stride: ldc, 131 Data: make([]float64, m*ldc), 132 } 133 copy(cMat.Data, cCopy) 134 ans := blas64.General{ 135 Rows: m, 136 Cols: n, 137 Stride: ldc, 138 Data: make([]float64, m*ldc), 139 } 140 copy(ans.Data, cMat.Data) 141 switch { 142 default: 143 panic("not implemented") 144 case side == blas.Left && trans == blas.NoTrans: 145 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, h, cMat, 0, ans) 146 case side == blas.Left && trans == blas.Trans: 147 blas64.Gemm(blas.Trans, blas.NoTrans, 1, h, cMat, 0, ans) 148 case side == blas.Right && trans == blas.NoTrans: 149 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, cMat, h, 0, ans) 150 case side == blas.Right && trans == blas.Trans: 151 blas64.Gemm(blas.NoTrans, blas.Trans, 1, cMat, h, 0, ans) 152 } 153 if !floats.EqualApprox(ans.Data, c, 1e-14) { 154 t.Errorf("Cas %v mismatch. Want %v, got %v.", cas, ans.Data, c) 155 } 156 } 157 } 158 } 159 } 160 } 161 }