gonum.org/v1/gonum@v0.14.0/lapack/testlapack/dlasr.go (about) 1 // Copyright ©2015 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "math" 9 "testing" 10 11 "golang.org/x/exp/rand" 12 13 "gonum.org/v1/gonum/blas" 14 "gonum.org/v1/gonum/blas/blas64" 15 "gonum.org/v1/gonum/floats" 16 "gonum.org/v1/gonum/lapack" 17 ) 18 19 type Dlasrer interface { 20 Dlasr(side blas.Side, pivot lapack.Pivot, direct lapack.Direct, m, n int, c, s, a []float64, lda int) 21 } 22 23 func DlasrTest(t *testing.T, impl Dlasrer) { 24 rnd := rand.New(rand.NewSource(1)) 25 for _, side := range []blas.Side{blas.Left, blas.Right} { 26 for _, pivot := range []lapack.Pivot{lapack.Variable, lapack.Top, lapack.Bottom} { 27 for _, direct := range []lapack.Direct{lapack.Forward, lapack.Backward} { 28 for _, test := range []struct { 29 m, n, lda int 30 }{ 31 {5, 5, 0}, 32 {5, 10, 0}, 33 {10, 5, 0}, 34 35 {5, 5, 20}, 36 {5, 10, 20}, 37 {10, 5, 20}, 38 } { 39 m := test.m 40 n := test.n 41 lda := test.lda 42 if lda == 0 { 43 lda = n 44 } 45 // Allocate n×n matrix A and fill it with random numbers. 46 a := make([]float64, m*lda) 47 for i := range a { 48 a[i] = rnd.Float64() 49 } 50 51 // Allocate slices for implicitly 52 // represented rotation matrices. 53 var s, c []float64 54 if side == blas.Left { 55 s = make([]float64, m-1) 56 c = make([]float64, m-1) 57 } else { 58 s = make([]float64, n-1) 59 c = make([]float64, n-1) 60 } 61 for k := range s { 62 // Generate a random number in [0,2*pi). 63 theta := rnd.Float64() * 2 * math.Pi 64 s[k] = math.Sin(theta) 65 c[k] = math.Cos(theta) 66 } 67 aCopy := make([]float64, len(a)) 68 copy(a, aCopy) 69 70 // Apply plane a sequence of plane 71 // rotation in s and c to the matrix A. 72 impl.Dlasr(side, pivot, direct, m, n, c, s, a, lda) 73 74 // Compute a reference solution by multiplying A 75 // by explicitly formed rotation matrix P. 76 pSize := m 77 if side == blas.Right { 78 pSize = n 79 } 80 // Allocate matrix P. 81 p := blas64.General{ 82 Rows: pSize, 83 Cols: pSize, 84 Stride: pSize, 85 Data: make([]float64, pSize*pSize), 86 } 87 // Allocate matrix P_k. 88 pk := blas64.General{ 89 Rows: pSize, 90 Cols: pSize, 91 Stride: pSize, 92 Data: make([]float64, pSize*pSize), 93 } 94 ptmp := blas64.General{ 95 Rows: pSize, 96 Cols: pSize, 97 Stride: pSize, 98 Data: make([]float64, pSize*pSize), 99 } 100 // Initialize P to the identity matrix. 101 for i := 0; i < pSize; i++ { 102 p.Data[i*p.Stride+i] = 1 103 ptmp.Data[i*p.Stride+i] = 1 104 } 105 // Iterate over the sequence of plane rotations. 106 for k := range s { 107 // Set P_k to the identity matrix. 108 for i := range p.Data { 109 pk.Data[i] = 0 110 } 111 for i := 0; i < pSize; i++ { 112 pk.Data[i*p.Stride+i] = 1 113 } 114 // Set the corresponding elements of P_k. 115 switch pivot { 116 case lapack.Variable: 117 pk.Data[k*p.Stride+k] = c[k] 118 pk.Data[k*p.Stride+k+1] = s[k] 119 pk.Data[(k+1)*p.Stride+k] = -s[k] 120 pk.Data[(k+1)*p.Stride+k+1] = c[k] 121 case lapack.Top: 122 pk.Data[0] = c[k] 123 pk.Data[k+1] = s[k] 124 pk.Data[(k+1)*p.Stride] = -s[k] 125 pk.Data[(k+1)*p.Stride+k+1] = c[k] 126 case lapack.Bottom: 127 pk.Data[(pSize-1-k)*p.Stride+pSize-k-1] = c[k] 128 pk.Data[(pSize-1-k)*p.Stride+pSize-1] = s[k] 129 pk.Data[(pSize-1)*p.Stride+pSize-1-k] = -s[k] 130 pk.Data[(pSize-1)*p.Stride+pSize-1] = c[k] 131 } 132 // Compute P <- P_k * P or P <- P * P_k. 133 if direct == lapack.Forward { 134 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, pk, ptmp, 0, p) 135 } else { 136 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, ptmp, pk, 0, p) 137 } 138 copy(ptmp.Data, p.Data) 139 } 140 141 aMat := blas64.General{ 142 Rows: m, 143 Cols: n, 144 Stride: lda, 145 Data: make([]float64, m*lda), 146 } 147 copy(a, aCopy) 148 newA := blas64.General{ 149 Rows: m, 150 Cols: n, 151 Stride: lda, 152 Data: make([]float64, m*lda), 153 } 154 // Compute P * A or A * P. 155 if side == blas.Left { 156 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, p, aMat, 0, newA) 157 } else { 158 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aMat, p, 0, newA) 159 } 160 // Compare the result from Dlasr with the reference solution. 161 if !floats.EqualApprox(newA.Data, a, 1e-12) { 162 t.Errorf("A update mismatch") 163 } 164 } 165 } 166 } 167 } 168 }