gonum.org/v1/gonum@v0.14.0/lapack/testlapack/dormbr.go (about) 1 // Copyright ©2015 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "testing" 9 10 "golang.org/x/exp/rand" 11 12 "gonum.org/v1/gonum/blas" 13 "gonum.org/v1/gonum/blas/blas64" 14 "gonum.org/v1/gonum/floats" 15 "gonum.org/v1/gonum/lapack" 16 ) 17 18 type Dormbrer interface { 19 Dormbr(vect lapack.ApplyOrtho, side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int) 20 Dgebrder 21 } 22 23 func DormbrTest(t *testing.T, impl Dormbrer) { 24 rnd := rand.New(rand.NewSource(1)) 25 bi := blas64.Implementation() 26 for _, vect := range []lapack.ApplyOrtho{lapack.ApplyQ, lapack.ApplyP} { 27 for _, side := range []blas.Side{blas.Left, blas.Right} { 28 for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans} { 29 for _, wl := range []worklen{minimumWork, mediumWork, optimumWork} { 30 for _, test := range []struct { 31 m, n, k, lda, ldc int 32 }{ 33 {3, 4, 5, 0, 0}, 34 {3, 5, 4, 0, 0}, 35 {4, 3, 5, 0, 0}, 36 {4, 5, 3, 0, 0}, 37 {5, 3, 4, 0, 0}, 38 {5, 4, 3, 0, 0}, 39 40 {3, 4, 5, 10, 12}, 41 {3, 5, 4, 10, 12}, 42 {4, 3, 5, 10, 12}, 43 {4, 5, 3, 10, 12}, 44 {5, 3, 4, 10, 12}, 45 {5, 4, 3, 10, 12}, 46 47 {150, 140, 130, 0, 0}, 48 } { 49 m := test.m 50 n := test.n 51 k := test.k 52 ldc := test.ldc 53 if ldc == 0 { 54 ldc = n 55 } 56 nq := n 57 nw := m 58 if side == blas.Left { 59 nq = m 60 nw = n 61 } 62 63 // Compute a decomposition. 64 var ma, na int 65 var a []float64 66 if vect == lapack.ApplyQ { 67 ma = nq 68 na = k 69 } else { 70 ma = k 71 na = nq 72 } 73 lda := test.lda 74 if lda == 0 { 75 lda = na 76 } 77 a = make([]float64, ma*lda) 78 for i := range a { 79 a[i] = rnd.NormFloat64() 80 } 81 nTau := min(nq, k) 82 tauP := make([]float64, nTau) 83 tauQ := make([]float64, nTau) 84 d := make([]float64, nTau) 85 e := make([]float64, nTau) 86 87 work := make([]float64, 1) 88 impl.Dgebrd(ma, na, a, lda, d, e, tauQ, tauP, work, -1) 89 work = make([]float64, int(work[0])) 90 impl.Dgebrd(ma, na, a, lda, d, e, tauQ, tauP, work, len(work)) 91 92 // Apply and compare update. 93 c := make([]float64, m*ldc) 94 for i := range c { 95 c[i] = rnd.NormFloat64() 96 } 97 cCopy := make([]float64, len(c)) 98 copy(cCopy, c) 99 100 var lwork int 101 switch wl { 102 case minimumWork: 103 lwork = nw 104 case optimumWork: 105 impl.Dormbr(vect, side, trans, m, n, k, a, lda, tauQ, c, ldc, work, -1) 106 lwork = int(work[0]) 107 case mediumWork: 108 work := make([]float64, 1) 109 impl.Dormbr(vect, side, trans, m, n, k, a, lda, tauQ, c, ldc, work, -1) 110 lwork = (int(work[0]) + nw) / 2 111 } 112 lwork = max(1, lwork) 113 work = make([]float64, lwork) 114 115 if vect == lapack.ApplyQ { 116 impl.Dormbr(vect, side, trans, m, n, k, a, lda, tauQ, c, ldc, work, lwork) 117 } else { 118 impl.Dormbr(vect, side, trans, m, n, k, a, lda, tauP, c, ldc, work, lwork) 119 } 120 121 // Check that the multiplication was correct. 122 cOrig := blas64.General{ 123 Rows: m, 124 Cols: n, 125 Stride: ldc, 126 Data: make([]float64, len(cCopy)), 127 } 128 copy(cOrig.Data, cCopy) 129 cAns := blas64.General{ 130 Rows: m, 131 Cols: n, 132 Stride: ldc, 133 Data: make([]float64, len(cCopy)), 134 } 135 copy(cAns.Data, cCopy) 136 nb := min(ma, na) 137 var mulMat blas64.General 138 if vect == lapack.ApplyQ { 139 mulMat = constructQPBidiagonal(lapack.ApplyQ, ma, na, nb, a, lda, tauQ) 140 } else { 141 mulMat = constructQPBidiagonal(lapack.ApplyP, ma, na, nb, a, lda, tauP) 142 } 143 144 mulTrans := trans 145 146 if side == blas.Left { 147 bi.Dgemm(mulTrans, blas.NoTrans, m, n, m, 1, mulMat.Data, mulMat.Stride, cOrig.Data, cOrig.Stride, 0, cAns.Data, cAns.Stride) 148 } else { 149 bi.Dgemm(blas.NoTrans, mulTrans, m, n, n, 1, cOrig.Data, cOrig.Stride, mulMat.Data, mulMat.Stride, 0, cAns.Data, cAns.Stride) 150 } 151 152 if !floats.EqualApprox(cAns.Data, c, 1e-13) { 153 isApplyQ := vect == lapack.ApplyQ 154 isLeft := side == blas.Left 155 isTrans := trans == blas.Trans 156 157 t.Errorf("C mismatch. isApplyQ: %v, isLeft: %v, isTrans: %v, m = %v, n = %v, k = %v, lda = %v, ldc = %v", 158 isApplyQ, isLeft, isTrans, m, n, k, lda, ldc) 159 } 160 } 161 } 162 } 163 } 164 } 165 }