github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dorgtr.go (about) 1 // Copyright ©2016 The gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "math/rand" 9 "testing" 10 11 "github.com/gonum/blas" 12 "github.com/gonum/blas/blas64" 13 "github.com/gonum/floats" 14 ) 15 16 type Dorgtrer interface { 17 Dorgtr(uplo blas.Uplo, n int, a []float64, lda int, tau, work []float64, lwork int) 18 Dsytrder 19 } 20 21 func DorgtrTest(t *testing.T, impl Dorgtrer) { 22 rnd := rand.New(rand.NewSource(1)) 23 for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} { 24 for _, wl := range []worklen{minimumWork, mediumWork, optimumWork} { 25 for _, test := range []struct { 26 n, lda int 27 }{ 28 {1, 0}, 29 {2, 0}, 30 {3, 0}, 31 {6, 0}, 32 {33, 0}, 33 {100, 0}, 34 35 {1, 3}, 36 {2, 5}, 37 {3, 7}, 38 {6, 10}, 39 {33, 50}, 40 {100, 120}, 41 } { 42 n := test.n 43 lda := test.lda 44 if lda == 0 { 45 lda = n 46 } 47 a := make([]float64, n*lda) 48 for i := range a { 49 a[i] = rnd.NormFloat64() 50 } 51 aCopy := make([]float64, len(a)) 52 copy(aCopy, a) 53 54 d := make([]float64, n) 55 e := make([]float64, n-1) 56 tau := make([]float64, n-1) 57 work := make([]float64, 1) 58 impl.Dsytrd(uplo, n, a, lda, d, e, tau, work, -1) 59 work = make([]float64, int(work[0])) 60 impl.Dsytrd(uplo, n, a, lda, d, e, tau, work, len(work)) 61 62 var lwork int 63 switch wl { 64 case minimumWork: 65 lwork = max(1, n-1) 66 case mediumWork: 67 work := make([]float64, 1) 68 impl.Dorgtr(uplo, n, a, lda, tau, work, -1) 69 lwork = (int(work[0]) + n - 1) / 2 70 lwork = max(1, lwork) 71 case optimumWork: 72 work := make([]float64, 1) 73 impl.Dorgtr(uplo, n, a, lda, tau, work, -1) 74 lwork = int(work[0]) 75 } 76 work = nanSlice(lwork) 77 78 impl.Dorgtr(uplo, n, a, lda, tau, work, len(work)) 79 80 q := blas64.General{ 81 Rows: n, 82 Cols: n, 83 Stride: lda, 84 Data: a, 85 } 86 tri := blas64.General{ 87 Rows: n, 88 Cols: n, 89 Stride: n, 90 Data: make([]float64, n*n), 91 } 92 for i := 0; i < n; i++ { 93 tri.Data[i*tri.Stride+i] = d[i] 94 if i != n-1 { 95 tri.Data[i*tri.Stride+i+1] = e[i] 96 tri.Data[(i+1)*tri.Stride+i] = e[i] 97 } 98 } 99 100 aMat := blas64.General{ 101 Rows: n, 102 Cols: n, 103 Stride: n, 104 Data: make([]float64, n*n), 105 } 106 if uplo == blas.Upper { 107 for i := 0; i < n; i++ { 108 for j := i; j < n; j++ { 109 v := aCopy[i*lda+j] 110 aMat.Data[i*aMat.Stride+j] = v 111 aMat.Data[j*aMat.Stride+i] = v 112 } 113 } 114 } else { 115 for i := 0; i < n; i++ { 116 for j := 0; j <= i; j++ { 117 v := aCopy[i*lda+j] 118 aMat.Data[i*aMat.Stride+j] = v 119 aMat.Data[j*aMat.Stride+i] = v 120 } 121 } 122 } 123 124 tmp := blas64.General{Rows: n, Cols: n, Stride: n, Data: make([]float64, n*n)} 125 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aMat, q, 0, tmp) 126 127 ans := blas64.General{Rows: n, Cols: n, Stride: n, Data: make([]float64, n*n)} 128 blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, tmp, 0, ans) 129 130 if !floats.EqualApprox(ans.Data, tri.Data, 1e-13) { 131 t.Errorf("Recombination mismatch. n = %v, isUpper = %v", n, uplo == blas.Upper) 132 } 133 } 134 } 135 } 136 }