gonum.org/v1/gonum@v0.14.0/lapack/testlapack/dsytd2.go (about)

     1  // Copyright ©2016 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"math"
     9  	"testing"
    10  
    11  	"golang.org/x/exp/rand"
    12  
    13  	"gonum.org/v1/gonum/blas"
    14  	"gonum.org/v1/gonum/blas/blas64"
    15  )
    16  
    17  type Dsytd2er interface {
    18  	Dsytd2(uplo blas.Uplo, n int, a []float64, lda int, d, e, tau []float64)
    19  }
    20  
    21  func Dsytd2Test(t *testing.T, impl Dsytd2er) {
    22  	const tol = 1e-14
    23  
    24  	rnd := rand.New(rand.NewSource(1))
    25  	for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
    26  		for _, test := range []struct {
    27  			n, lda int
    28  		}{
    29  			{3, 0},
    30  			{4, 0},
    31  			{5, 0},
    32  
    33  			{3, 10},
    34  			{4, 10},
    35  			{5, 10},
    36  		} {
    37  			n := test.n
    38  			lda := test.lda
    39  			if lda == 0 {
    40  				lda = n
    41  			}
    42  			a := make([]float64, n*lda)
    43  			for i := range a {
    44  				a[i] = rnd.NormFloat64()
    45  			}
    46  			aCopy := make([]float64, len(a))
    47  			copy(aCopy, a)
    48  
    49  			d := make([]float64, n)
    50  			for i := range d {
    51  				d[i] = math.NaN()
    52  			}
    53  			e := make([]float64, n-1)
    54  			for i := range e {
    55  				e[i] = math.NaN()
    56  			}
    57  			tau := make([]float64, n-1)
    58  			for i := range tau {
    59  				tau[i] = math.NaN()
    60  			}
    61  
    62  			impl.Dsytd2(uplo, n, a, lda, d, e, tau)
    63  
    64  			// Construct Q
    65  			qMat := blas64.General{
    66  				Rows:   n,
    67  				Cols:   n,
    68  				Stride: n,
    69  				Data:   make([]float64, n*n),
    70  			}
    71  			qCopy := blas64.General{
    72  				Rows:   n,
    73  				Cols:   n,
    74  				Stride: n,
    75  				Data:   make([]float64, len(qMat.Data)),
    76  			}
    77  			// Set Q to I.
    78  			for i := 0; i < n; i++ {
    79  				qMat.Data[i*qMat.Stride+i] = 1
    80  			}
    81  			for i := 0; i < n-1; i++ {
    82  				hMat := blas64.General{
    83  					Rows:   n,
    84  					Cols:   n,
    85  					Stride: n,
    86  					Data:   make([]float64, n*n),
    87  				}
    88  				// Set H to I.
    89  				for i := 0; i < n; i++ {
    90  					hMat.Data[i*hMat.Stride+i] = 1
    91  				}
    92  				var vi blas64.Vector
    93  				if uplo == blas.Upper {
    94  					vi = blas64.Vector{
    95  						Inc:  1,
    96  						Data: make([]float64, n),
    97  					}
    98  					for j := 0; j < i; j++ {
    99  						vi.Data[j] = a[j*lda+i+1]
   100  					}
   101  					vi.Data[i] = 1
   102  				} else {
   103  					vi = blas64.Vector{
   104  						Inc:  1,
   105  						Data: make([]float64, n),
   106  					}
   107  					vi.Data[i+1] = 1
   108  					for j := i + 2; j < n; j++ {
   109  						vi.Data[j] = a[j*lda+i]
   110  					}
   111  				}
   112  				blas64.Ger(-tau[i], vi, vi, hMat)
   113  				copy(qCopy.Data, qMat.Data)
   114  
   115  				// Multiply q by the new h.
   116  				if uplo == blas.Upper {
   117  					blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, hMat, qCopy, 0, qMat)
   118  				} else {
   119  					blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, qCopy, hMat, 0, qMat)
   120  				}
   121  			}
   122  
   123  			if resid := residualOrthogonal(qMat, false); resid > tol {
   124  				t.Errorf("Q is not orthogonal; resid=%v, want<=%v", resid, tol)
   125  			}
   126  
   127  			// Compute Qᵀ * A * Q.
   128  			aMat := blas64.General{
   129  				Rows:   n,
   130  				Cols:   n,
   131  				Stride: n,
   132  				Data:   make([]float64, len(a)),
   133  			}
   134  
   135  			for i := 0; i < n; i++ {
   136  				for j := i; j < n; j++ {
   137  					v := aCopy[i*lda+j]
   138  					if uplo == blas.Lower {
   139  						v = aCopy[j*lda+i]
   140  					}
   141  					aMat.Data[i*aMat.Stride+j] = v
   142  					aMat.Data[j*aMat.Stride+i] = v
   143  				}
   144  			}
   145  
   146  			tmp := blas64.General{
   147  				Rows:   n,
   148  				Cols:   n,
   149  				Stride: n,
   150  				Data:   make([]float64, n*n),
   151  			}
   152  
   153  			ans := blas64.General{
   154  				Rows:   n,
   155  				Cols:   n,
   156  				Stride: n,
   157  				Data:   make([]float64, n*n),
   158  			}
   159  
   160  			blas64.Gemm(blas.Trans, blas.NoTrans, 1, qMat, aMat, 0, tmp)
   161  			blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, qMat, 0, ans)
   162  
   163  			// Compare with T.
   164  			tMat := blas64.General{
   165  				Rows:   n,
   166  				Cols:   n,
   167  				Stride: n,
   168  				Data:   make([]float64, n*n),
   169  			}
   170  			for i := 0; i < n-1; i++ {
   171  				tMat.Data[i*tMat.Stride+i] = d[i]
   172  				tMat.Data[i*tMat.Stride+i+1] = e[i]
   173  				tMat.Data[(i+1)*tMat.Stride+i] = e[i]
   174  			}
   175  			tMat.Data[(n-1)*tMat.Stride+n-1] = d[n-1]
   176  
   177  			same := true
   178  			for i := 0; i < n; i++ {
   179  				for j := 0; j < n; j++ {
   180  					if math.Abs(ans.Data[i*ans.Stride+j]-tMat.Data[i*tMat.Stride+j]) > tol {
   181  						same = false
   182  					}
   183  				}
   184  			}
   185  			if !same {
   186  				t.Errorf("Matrix answer mismatch")
   187  			}
   188  		}
   189  	}
   190  }