github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dsytd2.go (about)

     1  // Copyright ©2016 The gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"math"
     9  	"math/rand"
    10  	"testing"
    11  
    12  	"github.com/gonum/blas"
    13  	"github.com/gonum/blas/blas64"
    14  )
    15  
    16  type Dsytd2er interface {
    17  	Dsytd2(uplo blas.Uplo, n int, a []float64, lda int, d, e, tau []float64)
    18  }
    19  
    20  func Dsytd2Test(t *testing.T, impl Dsytd2er) {
    21  	rnd := rand.New(rand.NewSource(1))
    22  	for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
    23  		for _, test := range []struct {
    24  			n, lda int
    25  		}{
    26  			{3, 0},
    27  			{4, 0},
    28  			{5, 0},
    29  
    30  			{3, 10},
    31  			{4, 10},
    32  			{5, 10},
    33  		} {
    34  			n := test.n
    35  			lda := test.lda
    36  			if lda == 0 {
    37  				lda = n
    38  			}
    39  			a := make([]float64, n*lda)
    40  			for i := range a {
    41  				a[i] = rnd.NormFloat64()
    42  			}
    43  			aCopy := make([]float64, len(a))
    44  			copy(aCopy, a)
    45  
    46  			d := make([]float64, n)
    47  			for i := range d {
    48  				d[i] = math.NaN()
    49  			}
    50  			e := make([]float64, n-1)
    51  			for i := range e {
    52  				e[i] = math.NaN()
    53  			}
    54  			tau := make([]float64, n-1)
    55  			for i := range tau {
    56  				tau[i] = math.NaN()
    57  			}
    58  
    59  			impl.Dsytd2(uplo, n, a, lda, d, e, tau)
    60  
    61  			// Construct Q
    62  			qMat := blas64.General{
    63  				Rows:   n,
    64  				Cols:   n,
    65  				Stride: n,
    66  				Data:   make([]float64, n*n),
    67  			}
    68  			qCopy := blas64.General{
    69  				Rows:   n,
    70  				Cols:   n,
    71  				Stride: n,
    72  				Data:   make([]float64, len(qMat.Data)),
    73  			}
    74  			// Set Q to I.
    75  			for i := 0; i < n; i++ {
    76  				qMat.Data[i*qMat.Stride+i] = 1
    77  			}
    78  			for i := 0; i < n-1; i++ {
    79  				hMat := blas64.General{
    80  					Rows:   n,
    81  					Cols:   n,
    82  					Stride: n,
    83  					Data:   make([]float64, n*n),
    84  				}
    85  				// Set H to I.
    86  				for i := 0; i < n; i++ {
    87  					hMat.Data[i*hMat.Stride+i] = 1
    88  				}
    89  				var vi blas64.Vector
    90  				if uplo == blas.Upper {
    91  					vi = blas64.Vector{
    92  						Inc:  1,
    93  						Data: make([]float64, n),
    94  					}
    95  					for j := 0; j < i; j++ {
    96  						vi.Data[j] = a[j*lda+i+1]
    97  					}
    98  					vi.Data[i] = 1
    99  				} else {
   100  					vi = blas64.Vector{
   101  						Inc:  1,
   102  						Data: make([]float64, n),
   103  					}
   104  					vi.Data[i+1] = 1
   105  					for j := i + 2; j < n; j++ {
   106  						vi.Data[j] = a[j*lda+i]
   107  					}
   108  				}
   109  				blas64.Ger(-tau[i], vi, vi, hMat)
   110  				copy(qCopy.Data, qMat.Data)
   111  
   112  				// Multiply q by the new h.
   113  				if uplo == blas.Upper {
   114  					blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, hMat, qCopy, 0, qMat)
   115  				} else {
   116  					blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, qCopy, hMat, 0, qMat)
   117  				}
   118  			}
   119  			// Check that Q is orthonormal
   120  			othonormal := true
   121  			for i := 0; i < n; i++ {
   122  				for j := i; j < n; j++ {
   123  					dot := blas64.Dot(n,
   124  						blas64.Vector{Inc: 1, Data: qMat.Data[i*qMat.Stride:]},
   125  						blas64.Vector{Inc: 1, Data: qMat.Data[j*qMat.Stride:]},
   126  					)
   127  					if i == j {
   128  						if math.Abs(dot-1) > 1e-10 {
   129  							othonormal = false
   130  						}
   131  					} else {
   132  						if math.Abs(dot) > 1e-10 {
   133  							othonormal = false
   134  						}
   135  					}
   136  				}
   137  			}
   138  			if !othonormal {
   139  				t.Errorf("Q not orthonormal")
   140  			}
   141  
   142  			// Compute Q^T * A * Q.
   143  			aMat := blas64.General{
   144  				Rows:   n,
   145  				Cols:   n,
   146  				Stride: n,
   147  				Data:   make([]float64, len(a)),
   148  			}
   149  
   150  			for i := 0; i < n; i++ {
   151  				for j := i; j < n; j++ {
   152  					v := aCopy[i*lda+j]
   153  					if uplo == blas.Lower {
   154  						v = aCopy[j*lda+i]
   155  					}
   156  					aMat.Data[i*aMat.Stride+j] = v
   157  					aMat.Data[j*aMat.Stride+i] = v
   158  				}
   159  			}
   160  
   161  			tmp := blas64.General{
   162  				Rows:   n,
   163  				Cols:   n,
   164  				Stride: n,
   165  				Data:   make([]float64, n*n),
   166  			}
   167  
   168  			ans := blas64.General{
   169  				Rows:   n,
   170  				Cols:   n,
   171  				Stride: n,
   172  				Data:   make([]float64, n*n),
   173  			}
   174  
   175  			blas64.Gemm(blas.Trans, blas.NoTrans, 1, qMat, aMat, 0, tmp)
   176  			blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, qMat, 0, ans)
   177  
   178  			// Compare with T.
   179  			tMat := blas64.General{
   180  				Rows:   n,
   181  				Cols:   n,
   182  				Stride: n,
   183  				Data:   make([]float64, n*n),
   184  			}
   185  			for i := 0; i < n-1; i++ {
   186  				tMat.Data[i*tMat.Stride+i] = d[i]
   187  				tMat.Data[i*tMat.Stride+i+1] = e[i]
   188  				tMat.Data[(i+1)*tMat.Stride+i] = e[i]
   189  			}
   190  			tMat.Data[(n-1)*tMat.Stride+n-1] = d[n-1]
   191  
   192  			same := true
   193  			for i := 0; i < n; i++ {
   194  				for j := 0; j < n; j++ {
   195  					if math.Abs(ans.Data[i*ans.Stride+j]-tMat.Data[i*tMat.Stride+j]) > 1e-10 {
   196  						same = false
   197  					}
   198  				}
   199  			}
   200  			if !same {
   201  				t.Errorf("Matrix answer mismatch")
   202  			}
   203  		}
   204  	}
   205  }