github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dbdsqr.go (about)

     1  // Copyright ©2015 The gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"fmt"
     9  	"math/rand"
    10  	"sort"
    11  	"testing"
    12  
    13  	"github.com/gonum/blas"
    14  	"github.com/gonum/blas/blas64"
    15  	"github.com/gonum/floats"
    16  )
    17  
    18  type Dbdsqrer interface {
    19  	Dbdsqr(uplo blas.Uplo, n, ncvt, nru, ncc int, d, e, vt []float64, ldvt int, u []float64, ldu int, c []float64, ldc int, work []float64) (ok bool)
    20  }
    21  
    22  func DbdsqrTest(t *testing.T, impl Dbdsqrer) {
    23  	rnd := rand.New(rand.NewSource(1))
    24  	bi := blas64.Implementation()
    25  	_ = bi
    26  	for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
    27  		for _, test := range []struct {
    28  			n, ncvt, nru, ncc, ldvt, ldu, ldc int
    29  		}{
    30  			{5, 5, 5, 5, 0, 0, 0},
    31  			{10, 10, 10, 10, 0, 0, 0},
    32  			{10, 11, 12, 13, 0, 0, 0},
    33  			{20, 13, 12, 11, 0, 0, 0},
    34  
    35  			{5, 5, 5, 5, 6, 7, 8},
    36  			{10, 10, 10, 10, 30, 40, 50},
    37  			{10, 12, 11, 13, 30, 40, 50},
    38  			{20, 12, 13, 11, 30, 40, 50},
    39  
    40  			{130, 130, 130, 500, 900, 900, 500},
    41  		} {
    42  			for cas := 0; cas < 100; cas++ {
    43  				n := test.n
    44  				ncvt := test.ncvt
    45  				nru := test.nru
    46  				ncc := test.ncc
    47  				ldvt := test.ldvt
    48  				ldu := test.ldu
    49  				ldc := test.ldc
    50  				if ldvt == 0 {
    51  					ldvt = ncvt
    52  				}
    53  				if ldu == 0 {
    54  					ldu = n
    55  				}
    56  				if ldc == 0 {
    57  					ldc = ncc
    58  				}
    59  
    60  				d := make([]float64, n)
    61  				for i := range d {
    62  					d[i] = rnd.NormFloat64()
    63  				}
    64  				e := make([]float64, n-1)
    65  				for i := range e {
    66  					e[i] = rnd.NormFloat64()
    67  				}
    68  				dCopy := make([]float64, len(d))
    69  				copy(dCopy, d)
    70  				eCopy := make([]float64, len(e))
    71  				copy(eCopy, e)
    72  				work := make([]float64, 4*n)
    73  				for i := range work {
    74  					work[i] = rnd.NormFloat64()
    75  				}
    76  
    77  				// First test the decomposition of the bidiagonal matrix. Set
    78  				// pt and u equal to I with the correct size. At the result
    79  				// of Dbdsqr, p and u  will contain the data of P^T and Q, which
    80  				// will be used in the next step to test the multiplication
    81  				// with Q and VT.
    82  
    83  				q := make([]float64, n*n)
    84  				ldq := n
    85  				pt := make([]float64, n*n)
    86  				ldpt := n
    87  				for i := 0; i < n; i++ {
    88  					q[i*ldq+i] = 1
    89  				}
    90  				for i := 0; i < n; i++ {
    91  					pt[i*ldpt+i] = 1
    92  				}
    93  
    94  				ok := impl.Dbdsqr(uplo, n, n, n, 0, d, e, pt, ldpt, q, ldq, nil, 0, work)
    95  
    96  				isUpper := uplo == blas.Upper
    97  				errStr := fmt.Sprintf("isUpper = %v, n = %v, ncvt = %v, nru = %v, ncc = %v", isUpper, n, ncvt, nru, ncc)
    98  				if !ok {
    99  					t.Errorf("Unexpected Dbdsqr failure: %s", errStr)
   100  				}
   101  
   102  				bMat := constructBidiagonal(uplo, n, dCopy, eCopy)
   103  				sMat := constructBidiagonal(uplo, n, d, e)
   104  
   105  				tmp := blas64.General{
   106  					Rows:   n,
   107  					Cols:   n,
   108  					Stride: n,
   109  					Data:   make([]float64, n*n),
   110  				}
   111  				ansMat := blas64.General{
   112  					Rows:   n,
   113  					Cols:   n,
   114  					Stride: n,
   115  					Data:   make([]float64, n*n),
   116  				}
   117  
   118  				bi.Dgemm(blas.NoTrans, blas.NoTrans, n, n, n, 1, q, ldq, sMat.Data, sMat.Stride, 0, tmp.Data, tmp.Stride)
   119  				bi.Dgemm(blas.NoTrans, blas.NoTrans, n, n, n, 1, tmp.Data, tmp.Stride, pt, ldpt, 0, ansMat.Data, ansMat.Stride)
   120  
   121  				same := true
   122  				for i := 0; i < n; i++ {
   123  					for j := 0; j < n; j++ {
   124  						if !floats.EqualWithinAbsOrRel(ansMat.Data[i*ansMat.Stride+j], bMat.Data[i*bMat.Stride+j], 1e-8, 1e-8) {
   125  							same = false
   126  						}
   127  					}
   128  				}
   129  				if !same {
   130  					t.Errorf("Bidiagonal mismatch. %s", errStr)
   131  				}
   132  				if !sort.IsSorted(sort.Reverse(sort.Float64Slice(d))) {
   133  					t.Errorf("D is not sorted. %s", errStr)
   134  				}
   135  
   136  				// The above computed the real P and Q. Now input data for V^T,
   137  				// U, and C to check that the multiplications happen properly.
   138  				dAns := make([]float64, len(d))
   139  				copy(dAns, d)
   140  				eAns := make([]float64, len(e))
   141  				copy(eAns, e)
   142  
   143  				u := make([]float64, nru*ldu)
   144  				for i := range u {
   145  					u[i] = rnd.NormFloat64()
   146  				}
   147  				uCopy := make([]float64, len(u))
   148  				copy(uCopy, u)
   149  				vt := make([]float64, n*ldvt)
   150  				for i := range vt {
   151  					vt[i] = rnd.NormFloat64()
   152  				}
   153  				vtCopy := make([]float64, len(vt))
   154  				copy(vtCopy, vt)
   155  				c := make([]float64, n*ldc)
   156  				for i := range c {
   157  					c[i] = rnd.NormFloat64()
   158  				}
   159  				cCopy := make([]float64, len(c))
   160  				copy(cCopy, c)
   161  
   162  				// Reset input data
   163  				copy(d, dCopy)
   164  				copy(e, eCopy)
   165  				impl.Dbdsqr(uplo, n, ncvt, nru, ncc, d, e, vt, ldvt, u, ldu, c, ldc, work)
   166  
   167  				// Check result.
   168  				if !floats.EqualApprox(d, dAns, 1e-14) {
   169  					t.Errorf("D mismatch second time. %s", errStr)
   170  				}
   171  				if !floats.EqualApprox(e, eAns, 1e-14) {
   172  					t.Errorf("E mismatch second time. %s", errStr)
   173  				}
   174  				ans := make([]float64, len(vtCopy))
   175  				copy(ans, vtCopy)
   176  				ldans := ldvt
   177  				bi.Dgemm(blas.NoTrans, blas.NoTrans, n, ncvt, n, 1, pt, ldpt, vtCopy, ldvt, 0, ans, ldans)
   178  				if !floats.EqualApprox(ans, vt, 1e-10) {
   179  					t.Errorf("Vt result mismatch. %s", errStr)
   180  				}
   181  				ans = make([]float64, len(uCopy))
   182  				copy(ans, uCopy)
   183  				ldans = ldu
   184  				bi.Dgemm(blas.NoTrans, blas.NoTrans, nru, n, n, 1, uCopy, ldu, q, ldq, 0, ans, ldans)
   185  				if !floats.EqualApprox(ans, u, 1e-10) {
   186  					t.Errorf("U result mismatch. %s", errStr)
   187  				}
   188  				ans = make([]float64, len(cCopy))
   189  				copy(ans, cCopy)
   190  				ldans = ldc
   191  				bi.Dgemm(blas.Trans, blas.NoTrans, n, ncc, n, 1, q, ldq, cCopy, ldc, 0, ans, ldans)
   192  				if !floats.EqualApprox(ans, c, 1e-10) {
   193  					t.Errorf("C result mismatch. %s", errStr)
   194  				}
   195  			}
   196  		}
   197  	}
   198  }