github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dlags2.go (about)

     1  // Copyright ©2017 The gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"math"
     9  	"math/rand"
    10  	"testing"
    11  
    12  	"github.com/gonum/blas"
    13  	"github.com/gonum/blas/blas64"
    14  	"github.com/gonum/floats"
    15  )
    16  
    17  type Dlags2er interface {
    18  	Dlags2(upper bool, a1, a2, a3, b1, b2, b3 float64) (csu, snu, csv, snv, csq, snq float64)
    19  }
    20  
    21  func Dlags2Test(t *testing.T, impl Dlags2er) {
    22  	rnd := rand.New(rand.NewSource(1))
    23  	for _, upper := range []bool{true, false} {
    24  		for i := 0; i < 100; i++ {
    25  			a1 := rnd.Float64()
    26  			a2 := rnd.Float64()
    27  			a3 := rnd.Float64()
    28  			b1 := rnd.Float64()
    29  			b2 := rnd.Float64()
    30  			b3 := rnd.Float64()
    31  
    32  			csu, snu, csv, snv, csq, snq := impl.Dlags2(upper, a1, a2, a3, b1, b2, b3)
    33  
    34  			detU := det2x2(csu, snu, -snu, csu)
    35  			if !floats.EqualWithinAbsOrRel(math.Abs(detU), 1, 1e-14, 1e-14) {
    36  				t.Errorf("U not orthogonal: det(U)=%v", detU)
    37  			}
    38  			detV := det2x2(csv, snv, -snv, csv)
    39  			if !floats.EqualWithinAbsOrRel(math.Abs(detV), 1, 1e-14, 1e-14) {
    40  				t.Errorf("V not orthogonal: det(V)=%v", detV)
    41  			}
    42  			detQ := det2x2(csq, snq, -snq, csq)
    43  			if !floats.EqualWithinAbsOrRel(math.Abs(detQ), 1, 1e-14, 1e-14) {
    44  				t.Errorf("Q not orthogonal: det(Q)=%v", detQ)
    45  			}
    46  
    47  			u := blas64.General{
    48  				Rows:   2,
    49  				Cols:   2,
    50  				Stride: 2,
    51  				Data:   []float64{csu, snu, -snu, csu},
    52  			}
    53  			v := blas64.General{
    54  				Rows:   2,
    55  				Cols:   2,
    56  				Stride: 2,
    57  				Data:   []float64{csv, snv, -snv, csv},
    58  			}
    59  			q := blas64.General{
    60  				Rows:   2,
    61  				Cols:   2,
    62  				Stride: 2,
    63  				Data:   []float64{csq, snq, -snq, csq},
    64  			}
    65  
    66  			a := blas64.General{Rows: 2, Cols: 2, Stride: 2}
    67  			b := blas64.General{Rows: 2, Cols: 2, Stride: 2}
    68  			if upper {
    69  				a.Data = []float64{a1, a2, 0, a3}
    70  				b.Data = []float64{b1, b2, 0, b3}
    71  			} else {
    72  				a.Data = []float64{a1, 0, a2, a3}
    73  				b.Data = []float64{b1, 0, b2, b3}
    74  			}
    75  
    76  			tmp := blas64.General{Rows: 2, Cols: 2, Stride: 2, Data: make([]float64, 4)}
    77  			blas64.Gemm(blas.Trans, blas.NoTrans, 1, u, a, 0, tmp)
    78  			blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, q, 0, a)
    79  			blas64.Gemm(blas.Trans, blas.NoTrans, 1, v, b, 0, tmp)
    80  			blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, q, 0, b)
    81  
    82  			var gotA, gotB float64
    83  			if upper {
    84  				gotA = a.Data[1]
    85  				gotB = b.Data[1]
    86  			} else {
    87  				gotA = a.Data[2]
    88  				gotB = b.Data[2]
    89  			}
    90  			if !floats.EqualWithinAbsOrRel(gotA, 0, 1e-14, 1e-14) {
    91  				t.Errorf("unexpected non-zero value for zero triangle of U^T*A*Q: %v", gotA)
    92  			}
    93  			if !floats.EqualWithinAbsOrRel(gotB, 0, 1e-14, 1e-14) {
    94  				t.Errorf("unexpected non-zero value for zero triangle of V^T*B*Q: %v", gotB)
    95  			}
    96  		}
    97  	}
    98  }
    99  
   100  func det2x2(a, b, c, d float64) float64 { return a*d - b*c }