github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dlags2.go (about) 1 // Copyright ©2017 The gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "math" 9 "math/rand" 10 "testing" 11 12 "github.com/gonum/blas" 13 "github.com/gonum/blas/blas64" 14 "github.com/gonum/floats" 15 ) 16 17 type Dlags2er interface { 18 Dlags2(upper bool, a1, a2, a3, b1, b2, b3 float64) (csu, snu, csv, snv, csq, snq float64) 19 } 20 21 func Dlags2Test(t *testing.T, impl Dlags2er) { 22 rnd := rand.New(rand.NewSource(1)) 23 for _, upper := range []bool{true, false} { 24 for i := 0; i < 100; i++ { 25 a1 := rnd.Float64() 26 a2 := rnd.Float64() 27 a3 := rnd.Float64() 28 b1 := rnd.Float64() 29 b2 := rnd.Float64() 30 b3 := rnd.Float64() 31 32 csu, snu, csv, snv, csq, snq := impl.Dlags2(upper, a1, a2, a3, b1, b2, b3) 33 34 detU := det2x2(csu, snu, -snu, csu) 35 if !floats.EqualWithinAbsOrRel(math.Abs(detU), 1, 1e-14, 1e-14) { 36 t.Errorf("U not orthogonal: det(U)=%v", detU) 37 } 38 detV := det2x2(csv, snv, -snv, csv) 39 if !floats.EqualWithinAbsOrRel(math.Abs(detV), 1, 1e-14, 1e-14) { 40 t.Errorf("V not orthogonal: det(V)=%v", detV) 41 } 42 detQ := det2x2(csq, snq, -snq, csq) 43 if !floats.EqualWithinAbsOrRel(math.Abs(detQ), 1, 1e-14, 1e-14) { 44 t.Errorf("Q not orthogonal: det(Q)=%v", detQ) 45 } 46 47 u := blas64.General{ 48 Rows: 2, 49 Cols: 2, 50 Stride: 2, 51 Data: []float64{csu, snu, -snu, csu}, 52 } 53 v := blas64.General{ 54 Rows: 2, 55 Cols: 2, 56 Stride: 2, 57 Data: []float64{csv, snv, -snv, csv}, 58 } 59 q := blas64.General{ 60 Rows: 2, 61 Cols: 2, 62 Stride: 2, 63 Data: []float64{csq, snq, -snq, csq}, 64 } 65 66 a := blas64.General{Rows: 2, Cols: 2, Stride: 2} 67 b := blas64.General{Rows: 2, Cols: 2, Stride: 2} 68 if upper { 69 a.Data = []float64{a1, a2, 0, a3} 70 b.Data = []float64{b1, b2, 0, b3} 71 } else { 72 a.Data = []float64{a1, 0, a2, a3} 73 b.Data = []float64{b1, 0, b2, b3} 74 } 75 76 tmp := blas64.General{Rows: 2, Cols: 2, Stride: 2, Data: make([]float64, 4)} 77 blas64.Gemm(blas.Trans, blas.NoTrans, 1, u, a, 0, tmp) 78 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, q, 0, a) 79 blas64.Gemm(blas.Trans, blas.NoTrans, 1, v, b, 0, tmp) 80 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, q, 0, b) 81 82 var gotA, gotB float64 83 if upper { 84 gotA = a.Data[1] 85 gotB = b.Data[1] 86 } else { 87 gotA = a.Data[2] 88 gotB = b.Data[2] 89 } 90 if !floats.EqualWithinAbsOrRel(gotA, 0, 1e-14, 1e-14) { 91 t.Errorf("unexpected non-zero value for zero triangle of U^T*A*Q: %v", gotA) 92 } 93 if !floats.EqualWithinAbsOrRel(gotB, 0, 1e-14, 1e-14) { 94 t.Errorf("unexpected non-zero value for zero triangle of V^T*B*Q: %v", gotB) 95 } 96 } 97 } 98 } 99 100 func det2x2(a, b, c, d float64) float64 { return a*d - b*c }