gonum.org/v1/gonum@v0.14.0/lapack/testlapack/dlatrd.go (about)

     1  // Copyright ©2016 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"fmt"
     9  	"math"
    10  	"testing"
    11  
    12  	"golang.org/x/exp/rand"
    13  
    14  	"gonum.org/v1/gonum/blas"
    15  	"gonum.org/v1/gonum/blas/blas64"
    16  )
    17  
    18  type Dlatrder interface {
    19  	Dlatrd(uplo blas.Uplo, n, nb int, a []float64, lda int, e, tau, w []float64, ldw int)
    20  }
    21  
    22  func DlatrdTest(t *testing.T, impl Dlatrder) {
    23  	const tol = 1e-14
    24  
    25  	rnd := rand.New(rand.NewSource(1))
    26  	for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
    27  		for _, test := range []struct {
    28  			n, nb, lda, ldw int
    29  		}{
    30  			{5, 2, 0, 0},
    31  			{5, 5, 0, 0},
    32  
    33  			{5, 3, 10, 11},
    34  			{5, 5, 10, 11},
    35  		} {
    36  			n := test.n
    37  			nb := test.nb
    38  			lda := test.lda
    39  			if lda == 0 {
    40  				lda = n
    41  			}
    42  			ldw := test.ldw
    43  			if ldw == 0 {
    44  				ldw = nb
    45  			}
    46  
    47  			// Allocate n×n matrix A and fill it with random numbers.
    48  			a := make([]float64, n*lda)
    49  			for i := range a {
    50  				a[i] = rnd.NormFloat64()
    51  			}
    52  
    53  			// Allocate output slices and matrix W and fill them
    54  			// with NaN. All their elements should be overwritten by
    55  			// Dlatrd.
    56  			e := make([]float64, n-1)
    57  			for i := range e {
    58  				e[i] = math.NaN()
    59  			}
    60  			tau := make([]float64, n-1)
    61  			for i := range tau {
    62  				tau[i] = math.NaN()
    63  			}
    64  			w := make([]float64, n*ldw)
    65  			for i := range w {
    66  				w[i] = math.NaN()
    67  			}
    68  
    69  			aCopy := make([]float64, len(a))
    70  			copy(aCopy, a)
    71  
    72  			// Reduce nb rows and columns of the symmetric matrix A
    73  			// defined by uplo triangle to symmetric tridiagonal
    74  			// form.
    75  			impl.Dlatrd(uplo, n, nb, a, lda, e, tau, w, ldw)
    76  
    77  			// Construct Q from elementary reflectors stored in
    78  			// columns of A.
    79  			q := blas64.General{
    80  				Rows:   n,
    81  				Cols:   n,
    82  				Stride: n,
    83  				Data:   make([]float64, n*n),
    84  			}
    85  			// Initialize Q to the identity matrix.
    86  			for i := 0; i < n; i++ {
    87  				q.Data[i*q.Stride+i] = 1
    88  			}
    89  			if uplo == blas.Upper {
    90  				for i := n - 1; i >= n-nb; i-- {
    91  					if i == 0 {
    92  						continue
    93  					}
    94  
    95  					// Extract the elementary reflector v from A.
    96  					v := blas64.Vector{
    97  						Inc:  1,
    98  						Data: make([]float64, n),
    99  					}
   100  					for j := 0; j < i-1; j++ {
   101  						v.Data[j] = a[j*lda+i]
   102  					}
   103  					v.Data[i-1] = 1
   104  
   105  					// Compute H = I - tau[i-1] * v * vᵀ.
   106  					h := blas64.General{
   107  						Rows: n, Cols: n, Stride: n, Data: make([]float64, n*n),
   108  					}
   109  					for j := 0; j < n; j++ {
   110  						h.Data[j*n+j] = 1
   111  					}
   112  					blas64.Ger(-tau[i-1], v, v, h)
   113  
   114  					// Update Q <- Q * H.
   115  					qTmp := blas64.General{
   116  						Rows: n, Cols: n, Stride: n, Data: make([]float64, n*n),
   117  					}
   118  					copy(qTmp.Data, q.Data)
   119  					blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, qTmp, h, 0, q)
   120  				}
   121  			} else {
   122  				for i := 0; i < nb; i++ {
   123  					if i == n-1 {
   124  						continue
   125  					}
   126  
   127  					// Extract the elementary reflector v from A.
   128  					v := blas64.Vector{
   129  						Inc:  1,
   130  						Data: make([]float64, n),
   131  					}
   132  					v.Data[i+1] = 1
   133  					for j := i + 2; j < n; j++ {
   134  						v.Data[j] = a[j*lda+i]
   135  					}
   136  
   137  					// Compute H = I - tau[i] * v * vᵀ.
   138  					h := blas64.General{
   139  						Rows: n, Cols: n, Stride: n, Data: make([]float64, n*n),
   140  					}
   141  					for j := 0; j < n; j++ {
   142  						h.Data[j*n+j] = 1
   143  					}
   144  					blas64.Ger(-tau[i], v, v, h)
   145  
   146  					// Update Q <- Q * H.
   147  					qTmp := blas64.General{
   148  						Rows: n, Cols: n, Stride: n, Data: make([]float64, n*n),
   149  					}
   150  					copy(qTmp.Data, q.Data)
   151  					blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, qTmp, h, 0, q)
   152  				}
   153  			}
   154  			name := fmt.Sprintf("uplo=%c,n=%v,nb=%v", uplo, n, nb)
   155  			if resid := residualOrthogonal(q, false); resid > tol {
   156  				t.Errorf("Case %v: Q is not orthogonal; resid=%v, want<=%v", name, resid, tol)
   157  			}
   158  			aGen := genFromSym(blas64.Symmetric{N: n, Stride: lda, Uplo: uplo, Data: aCopy})
   159  			if !dlatrdCheckDecomposition(t, uplo, n, nb, e, a, lda, aGen, q, tol) {
   160  				t.Errorf("Case %v: Decomposition mismatch", name)
   161  			}
   162  		}
   163  	}
   164  }
   165  
   166  // dlatrdCheckDecomposition checks that the first nb rows have been successfully
   167  // reduced.
   168  func dlatrdCheckDecomposition(t *testing.T, uplo blas.Uplo, n, nb int, e, a []float64, lda int, aGen, q blas64.General, tol float64) bool {
   169  	// Compute ans = Qᵀ * A * Q.
   170  	// ans should be a tridiagonal matrix in the first or last nb rows and
   171  	// columns, depending on uplo.
   172  	tmp := blas64.General{
   173  		Rows:   n,
   174  		Cols:   n,
   175  		Stride: n,
   176  		Data:   make([]float64, n*n),
   177  	}
   178  	ans := blas64.General{
   179  		Rows:   n,
   180  		Cols:   n,
   181  		Stride: n,
   182  		Data:   make([]float64, n*n),
   183  	}
   184  	blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, aGen, 0, tmp)
   185  	blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, q, 0, ans)
   186  
   187  	// Compare the output of Dlatrd (stored in a and e) with the explicit
   188  	// reduction to tridiagonal matrix Qᵀ * A * Q (stored in ans).
   189  	if uplo == blas.Upper {
   190  		for i := n - nb; i < n; i++ {
   191  			for j := 0; j < n; j++ {
   192  				v := ans.Data[i*ans.Stride+j]
   193  				switch {
   194  				case i == j:
   195  					// Diagonal elements of a and ans should match.
   196  					if math.Abs(v-a[i*lda+j]) > tol {
   197  						return false
   198  					}
   199  				case i == j-1:
   200  					// Superdiagonal elements in a should be 1.
   201  					if math.Abs(a[i*lda+j]-1) > tol {
   202  						return false
   203  					}
   204  					// Superdiagonal elements of ans should match e.
   205  					if math.Abs(v-e[i]) > tol {
   206  						return false
   207  					}
   208  				case i == j+1:
   209  				default:
   210  					// All other elements should be 0.
   211  					if math.Abs(v) > tol {
   212  						return false
   213  					}
   214  				}
   215  			}
   216  		}
   217  	} else {
   218  		for i := 0; i < nb; i++ {
   219  			for j := 0; j < n; j++ {
   220  				v := ans.Data[i*ans.Stride+j]
   221  				switch {
   222  				case i == j:
   223  					// Diagonal elements of a and ans should match.
   224  					if math.Abs(v-a[i*lda+j]) > tol {
   225  						return false
   226  					}
   227  				case i == j-1:
   228  				case i == j+1:
   229  					// Subdiagonal elements in a should be 1.
   230  					if math.Abs(a[i*lda+j]-1) > tol {
   231  						return false
   232  					}
   233  					// Subdiagonal elements of ans should match e.
   234  					if math.Abs(v-e[i-1]) > tol {
   235  						return false
   236  					}
   237  				default:
   238  					// All other elements should be 0.
   239  					if math.Abs(v) > tol {
   240  						return false
   241  					}
   242  				}
   243  			}
   244  		}
   245  	}
   246  	return true
   247  }
   248  
   249  // genFromSym constructs a (symmetric) general matrix from the data in the
   250  // symmetric.
   251  // TODO(btracey): Replace other constructions of this with a call to this function.
   252  func genFromSym(a blas64.Symmetric) blas64.General {
   253  	n := a.N
   254  	lda := a.Stride
   255  	uplo := a.Uplo
   256  	b := blas64.General{
   257  		Rows:   n,
   258  		Cols:   n,
   259  		Stride: n,
   260  		Data:   make([]float64, n*n),
   261  	}
   262  
   263  	for i := 0; i < n; i++ {
   264  		for j := i; j < n; j++ {
   265  			v := a.Data[i*lda+j]
   266  			if uplo == blas.Lower {
   267  				v = a.Data[j*lda+i]
   268  			}
   269  			b.Data[i*n+j] = v
   270  			b.Data[j*n+i] = v
   271  		}
   272  	}
   273  	return b
   274  }