github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dlasr.go (about) 1 // Copyright ©2015 The gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "math" 9 "math/rand" 10 "testing" 11 12 "github.com/gonum/blas" 13 "github.com/gonum/blas/blas64" 14 "github.com/gonum/floats" 15 "github.com/gonum/lapack" 16 ) 17 18 type Dlasrer interface { 19 Dlasr(side blas.Side, pivot lapack.Pivot, direct lapack.Direct, m, n int, c, s, a []float64, lda int) 20 } 21 22 func DlasrTest(t *testing.T, impl Dlasrer) { 23 rnd := rand.New(rand.NewSource(1)) 24 for _, side := range []blas.Side{blas.Left, blas.Right} { 25 for _, pivot := range []lapack.Pivot{lapack.Variable, lapack.Top, lapack.Bottom} { 26 for _, direct := range []lapack.Direct{lapack.Forward, lapack.Backward} { 27 for _, test := range []struct { 28 m, n, lda int 29 }{ 30 {5, 5, 0}, 31 {5, 10, 0}, 32 {10, 5, 0}, 33 34 {5, 5, 20}, 35 {5, 10, 20}, 36 {10, 5, 20}, 37 } { 38 m := test.m 39 n := test.n 40 lda := test.lda 41 if lda == 0 { 42 lda = n 43 } 44 a := make([]float64, m*lda) 45 for i := range a { 46 a[i] = rnd.Float64() 47 } 48 var s, c []float64 49 if side == blas.Left { 50 s = make([]float64, m-1) 51 c = make([]float64, m-1) 52 } else { 53 s = make([]float64, n-1) 54 c = make([]float64, n-1) 55 } 56 for k := range s { 57 theta := rnd.Float64() * 2 * math.Pi 58 s[k] = math.Sin(theta) 59 c[k] = math.Cos(theta) 60 } 61 aCopy := make([]float64, len(a)) 62 copy(a, aCopy) 63 impl.Dlasr(side, pivot, direct, m, n, c, s, a, lda) 64 65 pSize := m 66 if side == blas.Right { 67 pSize = n 68 } 69 p := blas64.General{ 70 Rows: pSize, 71 Cols: pSize, 72 Stride: pSize, 73 Data: make([]float64, pSize*pSize), 74 } 75 pk := blas64.General{ 76 Rows: pSize, 77 Cols: pSize, 78 Stride: pSize, 79 Data: make([]float64, pSize*pSize), 80 } 81 ptmp := blas64.General{ 82 Rows: pSize, 83 Cols: pSize, 84 Stride: pSize, 85 Data: make([]float64, pSize*pSize), 86 } 87 for i := 0; i < pSize; i++ { 88 p.Data[i*p.Stride+i] = 1 89 ptmp.Data[i*p.Stride+i] = 1 90 } 91 // Compare to direct computation. 92 for k := range s { 93 for i := range p.Data { 94 pk.Data[i] = 0 95 } 96 for i := 0; i < pSize; i++ { 97 pk.Data[i*p.Stride+i] = 1 98 } 99 if pivot == lapack.Variable { 100 pk.Data[k*p.Stride+k] = c[k] 101 pk.Data[k*p.Stride+k+1] = s[k] 102 pk.Data[(k+1)*p.Stride+k] = -s[k] 103 pk.Data[(k+1)*p.Stride+k+1] = c[k] 104 } else if pivot == lapack.Top { 105 pk.Data[0] = c[k] 106 pk.Data[k+1] = s[k] 107 pk.Data[(k+1)*p.Stride] = -s[k] 108 pk.Data[(k+1)*p.Stride+k+1] = c[k] 109 } else { 110 pk.Data[(pSize-1-k)*p.Stride+pSize-k-1] = c[k] 111 pk.Data[(pSize-1-k)*p.Stride+pSize-1] = s[k] 112 pk.Data[(pSize-1)*p.Stride+pSize-1-k] = -s[k] 113 pk.Data[(pSize-1)*p.Stride+pSize-1] = c[k] 114 } 115 if direct == lapack.Forward { 116 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, pk, ptmp, 0, p) 117 } else { 118 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, ptmp, pk, 0, p) 119 } 120 copy(ptmp.Data, p.Data) 121 } 122 123 aMat := blas64.General{ 124 Rows: m, 125 Cols: n, 126 Stride: lda, 127 Data: make([]float64, m*lda), 128 } 129 copy(a, aCopy) 130 newA := blas64.General{ 131 Rows: m, 132 Cols: n, 133 Stride: lda, 134 Data: make([]float64, m*lda), 135 } 136 if side == blas.Left { 137 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, p, aMat, 0, newA) 138 } else { 139 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aMat, p, 0, newA) 140 } 141 if !floats.EqualApprox(newA.Data, a, 1e-12) { 142 t.Errorf("A update mismatch") 143 } 144 } 145 } 146 } 147 } 148 }