github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dlatrd.go (about)

     1  // Copyright ©2016 The gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"fmt"
     9  	"math"
    10  	"math/rand"
    11  	"testing"
    12  
    13  	"github.com/gonum/blas"
    14  	"github.com/gonum/blas/blas64"
    15  )
    16  
    17  type Dlatrder interface {
    18  	Dlatrd(uplo blas.Uplo, n, nb int, a []float64, lda int, e, tau, w []float64, ldw int)
    19  }
    20  
    21  func DlatrdTest(t *testing.T, impl Dlatrder) {
    22  	rnd := rand.New(rand.NewSource(1))
    23  	for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
    24  		for _, test := range []struct {
    25  			n, nb, lda, ldw int
    26  		}{
    27  			{5, 2, 0, 0},
    28  			{5, 5, 0, 0},
    29  
    30  			{5, 3, 10, 11},
    31  			{5, 5, 10, 11},
    32  		} {
    33  			n := test.n
    34  			nb := test.nb
    35  			lda := test.lda
    36  			if lda == 0 {
    37  				lda = n
    38  			}
    39  			ldw := test.ldw
    40  			if ldw == 0 {
    41  				ldw = nb
    42  			}
    43  
    44  			a := make([]float64, n*lda)
    45  			for i := range a {
    46  				a[i] = rnd.NormFloat64()
    47  			}
    48  
    49  			e := make([]float64, n-1)
    50  			for i := range e {
    51  				e[i] = math.NaN()
    52  			}
    53  			tau := make([]float64, n-1)
    54  			for i := range tau {
    55  				tau[i] = math.NaN()
    56  			}
    57  			w := make([]float64, n*ldw)
    58  			for i := range w {
    59  				w[i] = math.NaN()
    60  			}
    61  
    62  			aCopy := make([]float64, len(a))
    63  			copy(aCopy, a)
    64  
    65  			impl.Dlatrd(uplo, n, nb, a, lda, e, tau, w, ldw)
    66  
    67  			// Construct Q.
    68  			ldq := n
    69  			q := blas64.General{
    70  				Rows:   n,
    71  				Cols:   n,
    72  				Stride: ldq,
    73  				Data:   make([]float64, n*ldq),
    74  			}
    75  			for i := 0; i < n; i++ {
    76  				q.Data[i*ldq+i] = 1
    77  			}
    78  			if uplo == blas.Upper {
    79  				for i := n - 1; i >= n-nb; i-- {
    80  					if i == 0 {
    81  						continue
    82  					}
    83  					h := blas64.General{
    84  						Rows: n, Cols: n, Stride: n, Data: make([]float64, n*n),
    85  					}
    86  					for j := 0; j < n; j++ {
    87  						h.Data[j*n+j] = 1
    88  					}
    89  					v := blas64.Vector{
    90  						Inc:  1,
    91  						Data: make([]float64, n),
    92  					}
    93  					for j := 0; j < i-1; j++ {
    94  						v.Data[j] = a[j*lda+i]
    95  					}
    96  					v.Data[i-1] = 1
    97  
    98  					blas64.Ger(-tau[i-1], v, v, h)
    99  
   100  					qTmp := blas64.General{
   101  						Rows: n, Cols: n, Stride: n, Data: make([]float64, n*n),
   102  					}
   103  					copy(qTmp.Data, q.Data)
   104  					blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, qTmp, h, 0, q)
   105  				}
   106  			} else {
   107  				for i := 0; i < nb; i++ {
   108  					if i == n-1 {
   109  						continue
   110  					}
   111  					h := blas64.General{
   112  						Rows: n, Cols: n, Stride: n, Data: make([]float64, n*n),
   113  					}
   114  					for j := 0; j < n; j++ {
   115  						h.Data[j*n+j] = 1
   116  					}
   117  					v := blas64.Vector{
   118  						Inc:  1,
   119  						Data: make([]float64, n),
   120  					}
   121  					v.Data[i+1] = 1
   122  					for j := i + 2; j < n; j++ {
   123  						v.Data[j] = a[j*lda+i]
   124  					}
   125  					blas64.Ger(-tau[i], v, v, h)
   126  
   127  					qTmp := blas64.General{
   128  						Rows: n, Cols: n, Stride: n, Data: make([]float64, n*n),
   129  					}
   130  					copy(qTmp.Data, q.Data)
   131  					blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, qTmp, h, 0, q)
   132  				}
   133  			}
   134  			errStr := fmt.Sprintf("isUpper = %v, n = %v, nb = %v", uplo == blas.Upper, n, nb)
   135  			if !isOrthonormal(q) {
   136  				t.Errorf("Q not orthonormal. %s", errStr)
   137  			}
   138  			aGen := genFromSym(blas64.Symmetric{N: n, Stride: lda, Uplo: uplo, Data: aCopy})
   139  			if !dlatrdCheckDecomposition(t, uplo, n, nb, e, tau, a, lda, aGen, q) {
   140  				t.Errorf("Decomposition mismatch. %s", errStr)
   141  			}
   142  		}
   143  	}
   144  }
   145  
   146  // dlatrdCheckDecomposition checks that the first nb rows have been successfully
   147  // reduced.
   148  func dlatrdCheckDecomposition(t *testing.T, uplo blas.Uplo, n, nb int, e, tau, a []float64, lda int, aGen, q blas64.General) bool {
   149  	// Compute Q^T * A * Q.
   150  	tmp := blas64.General{
   151  		Rows:   n,
   152  		Cols:   n,
   153  		Stride: n,
   154  		Data:   make([]float64, n*n),
   155  	}
   156  
   157  	ans := blas64.General{
   158  		Rows:   n,
   159  		Cols:   n,
   160  		Stride: n,
   161  		Data:   make([]float64, n*n),
   162  	}
   163  
   164  	blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, aGen, 0, tmp)
   165  	blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, q, 0, ans)
   166  
   167  	// Compare with T.
   168  	if uplo == blas.Upper {
   169  		for i := n - 1; i >= n-nb; i-- {
   170  			for j := 0; j < n; j++ {
   171  				v := ans.Data[i*ans.Stride+j]
   172  				switch {
   173  				case i == j:
   174  					if math.Abs(v-a[i*lda+j]) > 1e-10 {
   175  						return false
   176  					}
   177  				case i == j-1:
   178  					if math.Abs(a[i*lda+j]-1) > 1e-10 {
   179  						return false
   180  					}
   181  					if math.Abs(v-e[i]) > 1e-10 {
   182  						return false
   183  					}
   184  				case i == j+1:
   185  				default:
   186  					if math.Abs(v) > 1e-10 {
   187  						return false
   188  					}
   189  				}
   190  			}
   191  		}
   192  	} else {
   193  		for i := 0; i < nb; i++ {
   194  			for j := 0; j < n; j++ {
   195  				v := ans.Data[i*ans.Stride+j]
   196  				switch {
   197  				case i == j:
   198  					if math.Abs(v-a[i*lda+j]) > 1e-10 {
   199  						return false
   200  					}
   201  				case i == j-1:
   202  				case i == j+1:
   203  					if math.Abs(a[i*lda+j]-1) > 1e-10 {
   204  						return false
   205  					}
   206  					if math.Abs(v-e[i-1]) > 1e-10 {
   207  						return false
   208  					}
   209  				default:
   210  					if math.Abs(v) > 1e-10 {
   211  						return false
   212  					}
   213  				}
   214  			}
   215  		}
   216  	}
   217  	return true
   218  }
   219  
   220  // genFromSym constructs a (symmetric) general matrix from the data in the
   221  // symmetric.
   222  // TODO(btracey): Replace other constructions of this with a call to this function.
   223  func genFromSym(a blas64.Symmetric) blas64.General {
   224  	n := a.N
   225  	lda := a.Stride
   226  	uplo := a.Uplo
   227  	b := blas64.General{
   228  		Rows:   n,
   229  		Cols:   n,
   230  		Stride: n,
   231  		Data:   make([]float64, n*n),
   232  	}
   233  
   234  	for i := 0; i < n; i++ {
   235  		for j := i; j < n; j++ {
   236  			v := a.Data[i*lda+j]
   237  			if uplo == blas.Lower {
   238  				v = a.Data[j*lda+i]
   239  			}
   240  			b.Data[i*n+j] = v
   241  			b.Data[j*n+i] = v
   242  		}
   243  	}
   244  	return b
   245  }