github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dsytrd.go (about)

     1  // Copyright ©2016 The gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"fmt"
     9  	"math/rand"
    10  	"testing"
    11  
    12  	"github.com/gonum/blas"
    13  	"github.com/gonum/blas/blas64"
    14  )
    15  
    16  type Dsytrder interface {
    17  	Dsytrd(uplo blas.Uplo, n int, a []float64, lda int, d, e, tau, work []float64, lwork int)
    18  
    19  	Dorgqr(m, n, k int, a []float64, lda int, tau, work []float64, lwork int)
    20  	Dorgql(m, n, k int, a []float64, lda int, tau, work []float64, lwork int)
    21  }
    22  
    23  func DsytrdTest(t *testing.T, impl Dsytrder) {
    24  	const tol = 1e-13
    25  	rnd := rand.New(rand.NewSource(1))
    26  	for tc, test := range []struct {
    27  		n, lda int
    28  	}{
    29  		{1, 0},
    30  		{2, 0},
    31  		{3, 0},
    32  		{4, 0},
    33  		{10, 0},
    34  		{50, 0},
    35  		{100, 0},
    36  		{150, 0},
    37  		{300, 0},
    38  
    39  		{1, 3},
    40  		{2, 3},
    41  		{3, 7},
    42  		{4, 9},
    43  		{10, 20},
    44  		{50, 70},
    45  		{100, 120},
    46  		{150, 170},
    47  		{300, 320},
    48  	} {
    49  		for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
    50  			for _, wl := range []worklen{minimumWork, mediumWork, optimumWork} {
    51  				n := test.n
    52  				lda := test.lda
    53  				if lda == 0 {
    54  					lda = n
    55  				}
    56  				a := randomGeneral(n, n, lda, rnd)
    57  				for i := 1; i < n; i++ {
    58  					for j := 0; j < i; j++ {
    59  						a.Data[i*a.Stride+j] = a.Data[j*a.Stride+i]
    60  					}
    61  				}
    62  				aCopy := cloneGeneral(a)
    63  
    64  				d := nanSlice(n)
    65  				e := nanSlice(n - 1)
    66  				tau := nanSlice(n - 1)
    67  
    68  				var lwork int
    69  				switch wl {
    70  				case minimumWork:
    71  					lwork = 1
    72  				case mediumWork:
    73  					work := make([]float64, 1)
    74  					impl.Dsytrd(uplo, n, a.Data, a.Stride, d, e, tau, work, -1)
    75  					lwork = (int(work[0]) + 1) / 2
    76  					lwork = max(1, lwork)
    77  				case optimumWork:
    78  					work := make([]float64, 1)
    79  					impl.Dsytrd(uplo, n, a.Data, a.Stride, d, e, tau, work, -1)
    80  					lwork = int(work[0])
    81  				}
    82  				work := make([]float64, lwork)
    83  
    84  				impl.Dsytrd(uplo, n, a.Data, a.Stride, d, e, tau, work, lwork)
    85  
    86  				prefix := fmt.Sprintf("Case #%v: uplo=%v,n=%v,lda=%v,work=%v",
    87  					tc, uplo, n, lda, wl)
    88  
    89  				if !generalOutsideAllNaN(a) {
    90  					t.Errorf("%v: out-of-range write to A", prefix)
    91  				}
    92  
    93  				// Extract Q by doing what Dorgtr does.
    94  				q := cloneGeneral(a)
    95  				if uplo == blas.Upper {
    96  					for j := 0; j < n-1; j++ {
    97  						for i := 0; i < j; i++ {
    98  							q.Data[i*q.Stride+j] = q.Data[i*q.Stride+j+1]
    99  						}
   100  						q.Data[(n-1)*q.Stride+j] = 0
   101  					}
   102  					for i := 0; i < n-1; i++ {
   103  						q.Data[i*q.Stride+n-1] = 0
   104  					}
   105  					q.Data[(n-1)*q.Stride+n-1] = 1
   106  					if n > 1 {
   107  						work = make([]float64, n-1)
   108  						impl.Dorgql(n-1, n-1, n-1, q.Data, q.Stride, tau, work, len(work))
   109  					}
   110  				} else {
   111  					for j := n - 1; j > 0; j-- {
   112  						q.Data[j] = 0
   113  						for i := j + 1; i < n; i++ {
   114  							q.Data[i*q.Stride+j] = q.Data[i*q.Stride+j-1]
   115  						}
   116  					}
   117  					q.Data[0] = 1
   118  					for i := 1; i < n; i++ {
   119  						q.Data[i*q.Stride] = 0
   120  					}
   121  					if n > 1 {
   122  						work = make([]float64, n-1)
   123  						impl.Dorgqr(n-1, n-1, n-1, q.Data[q.Stride+1:], q.Stride, tau, work, len(work))
   124  					}
   125  				}
   126  				if !isOrthonormal(q) {
   127  					t.Errorf("%v: Q not orthogonal", prefix)
   128  				}
   129  
   130  				// Contruct symmetric tridiagonal T from d and e.
   131  				tMat := zeros(n, n, n)
   132  				for i := 0; i < n; i++ {
   133  					tMat.Data[i*tMat.Stride+i] = d[i]
   134  				}
   135  				if uplo == blas.Upper {
   136  					for j := 1; j < n; j++ {
   137  						tMat.Data[(j-1)*tMat.Stride+j] = e[j-1]
   138  						tMat.Data[j*tMat.Stride+j-1] = e[j-1]
   139  					}
   140  				} else {
   141  					for j := 0; j < n-1; j++ {
   142  						tMat.Data[(j+1)*tMat.Stride+j] = e[j]
   143  						tMat.Data[j*tMat.Stride+j+1] = e[j]
   144  					}
   145  				}
   146  
   147  				// Compute Q^T * A * Q.
   148  				tmp := zeros(n, n, n)
   149  				blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, aCopy, 0, tmp)
   150  				got := zeros(n, n, n)
   151  				blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, q, 0, got)
   152  
   153  				// Compare with T.
   154  				if !equalApproxGeneral(got, tMat, tol) {
   155  					t.Errorf("%v: Q^T*A*Q != T", prefix)
   156  				}
   157  			}
   158  		}
   159  	}
   160  }