github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/lapack/testlapack/dsytrd.go (about)

     1  // Copyright ©2016 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
import (
	"fmt"
	"math"
	"testing"

	"golang.org/x/exp/rand"

	"github.com/jingcheng-WU/gonum/blas"
	"github.com/jingcheng-WU/gonum/blas/blas64"
	"github.com/jingcheng-WU/gonum/lapack"
)
    17  
    18  type Dsytrder interface {
    19  	Dsytrd(uplo blas.Uplo, n int, a []float64, lda int, d, e, tau, work []float64, lwork int)
    20  
    21  	Dorgqr(m, n, k int, a []float64, lda int, tau, work []float64, lwork int)
    22  	Dorgql(m, n, k int, a []float64, lda int, tau, work []float64, lwork int)
    23  }
    24  
    25  func DsytrdTest(t *testing.T, impl Dsytrder) {
    26  	const tol = 1e-14
    27  
    28  	rnd := rand.New(rand.NewSource(1))
    29  	for tc, test := range []struct {
    30  		n, lda int
    31  	}{
    32  		{1, 0},
    33  		{2, 0},
    34  		{3, 0},
    35  		{4, 0},
    36  		{10, 0},
    37  		{50, 0},
    38  		{100, 0},
    39  		{150, 0},
    40  		{300, 0},
    41  
    42  		{1, 3},
    43  		{2, 3},
    44  		{3, 7},
    45  		{4, 9},
    46  		{10, 20},
    47  		{50, 70},
    48  		{100, 120},
    49  		{150, 170},
    50  		{300, 320},
    51  	} {
    52  		for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
    53  			for _, wl := range []worklen{minimumWork, mediumWork, optimumWork} {
    54  				n := test.n
    55  				lda := test.lda
    56  				if lda == 0 {
    57  					lda = n
    58  				}
    59  				a := randomGeneral(n, n, lda, rnd)
    60  				for i := 1; i < n; i++ {
    61  					for j := 0; j < i; j++ {
    62  						a.Data[i*a.Stride+j] = a.Data[j*a.Stride+i]
    63  					}
    64  				}
    65  				aCopy := cloneGeneral(a)
    66  
    67  				d := nanSlice(n)
    68  				e := nanSlice(n - 1)
    69  				tau := nanSlice(n - 1)
    70  
    71  				var lwork int
    72  				switch wl {
    73  				case minimumWork:
    74  					lwork = 1
    75  				case mediumWork:
    76  					work := make([]float64, 1)
    77  					impl.Dsytrd(uplo, n, a.Data, a.Stride, d, e, tau, work, -1)
    78  					lwork = (int(work[0]) + 1) / 2
    79  					lwork = max(1, lwork)
    80  				case optimumWork:
    81  					work := make([]float64, 1)
    82  					impl.Dsytrd(uplo, n, a.Data, a.Stride, d, e, tau, work, -1)
    83  					lwork = int(work[0])
    84  				}
    85  				work := make([]float64, lwork)
    86  
    87  				impl.Dsytrd(uplo, n, a.Data, a.Stride, d, e, tau, work, lwork)
    88  
    89  				prefix := fmt.Sprintf("Case #%v: uplo=%c,n=%v,lda=%v,work=%v",
    90  					tc, uplo, n, lda, wl)
    91  
    92  				if !generalOutsideAllNaN(a) {
    93  					t.Errorf("%v: out-of-range write to A", prefix)
    94  				}
    95  
    96  				// Extract Q by doing what Dorgtr does.
    97  				q := cloneGeneral(a)
    98  				if uplo == blas.Upper {
    99  					for j := 0; j < n-1; j++ {
   100  						for i := 0; i < j; i++ {
   101  							q.Data[i*q.Stride+j] = q.Data[i*q.Stride+j+1]
   102  						}
   103  						q.Data[(n-1)*q.Stride+j] = 0
   104  					}
   105  					for i := 0; i < n-1; i++ {
   106  						q.Data[i*q.Stride+n-1] = 0
   107  					}
   108  					q.Data[(n-1)*q.Stride+n-1] = 1
   109  					if n > 1 {
   110  						work = make([]float64, n-1)
   111  						impl.Dorgql(n-1, n-1, n-1, q.Data, q.Stride, tau, work, len(work))
   112  					}
   113  				} else {
   114  					for j := n - 1; j > 0; j-- {
   115  						q.Data[j] = 0
   116  						for i := j + 1; i < n; i++ {
   117  							q.Data[i*q.Stride+j] = q.Data[i*q.Stride+j-1]
   118  						}
   119  					}
   120  					q.Data[0] = 1
   121  					for i := 1; i < n; i++ {
   122  						q.Data[i*q.Stride] = 0
   123  					}
   124  					if n > 1 {
   125  						work = make([]float64, n-1)
   126  						impl.Dorgqr(n-1, n-1, n-1, q.Data[q.Stride+1:], q.Stride, tau, work, len(work))
   127  					}
   128  				}
   129  				if resid := residualOrthogonal(q, false); resid > tol*float64(n) {
   130  					t.Errorf("%v: Q is not orthogonal; resid=%v, want<=%v", prefix, resid, tol*float64(n))
   131  				}
   132  
   133  				// Contruct symmetric tridiagonal T from d and e.
   134  				tMat := zeros(n, n, n)
   135  				for i := 0; i < n; i++ {
   136  					tMat.Data[i*tMat.Stride+i] = d[i]
   137  				}
   138  				if uplo == blas.Upper {
   139  					for j := 1; j < n; j++ {
   140  						tMat.Data[(j-1)*tMat.Stride+j] = e[j-1]
   141  						tMat.Data[j*tMat.Stride+j-1] = e[j-1]
   142  					}
   143  				} else {
   144  					for j := 0; j < n-1; j++ {
   145  						tMat.Data[(j+1)*tMat.Stride+j] = e[j]
   146  						tMat.Data[j*tMat.Stride+j+1] = e[j]
   147  					}
   148  				}
   149  				// Compute Qᵀ*A*Q - T.
   150  				qa := zeros(n, n, n)
   151  				blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, aCopy, 0, qa)
   152  				blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, qa, q, -1, tMat)
   153  				// Check that |Qᵀ*A*Q - T| is small.
   154  				resid := dlange(lapack.MaxColumnSum, n, n, tMat.Data, tMat.Stride)
   155  				if resid > tol*float64(n) {
   156  					t.Errorf("%v: |Qᵀ*A*Q - T|=%v, want<=%v", prefix, resid, tol*float64(n))
   157  				}
   158  			}
   159  		}
   160  	}
   161  }