github.com/gopherd/gonum@v0.0.4/lapack/testlapack/dpstrf.go (about)

     1  // Copyright ©2021 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"fmt"
     9  	"math"
    10  	"testing"
    11  
    12  	"math/rand"
    13  
    14  	"github.com/gopherd/gonum/blas"
    15  	"github.com/gopherd/gonum/blas/blas64"
    16  	"github.com/gopherd/gonum/lapack"
    17  )
    18  
    19  type Dpstrfer interface {
    20  	Dpstrf(uplo blas.Uplo, n int, a []float64, lda int, piv []int, tol float64, work []float64) (rank int, ok bool)
    21  }
    22  
    23  func DpstrfTest(t *testing.T, impl Dpstrfer) {
    24  	rnd := rand.New(rand.NewSource(1))
    25  	for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
    26  		t.Run(uploToString(uplo), func(t *testing.T) {
    27  			for _, n := range []int{0, 1, 2, 3, 4, 5, 31, 32, 33, 63, 64, 65, 127, 128, 129} {
    28  				for _, lda := range []int{max(1, n), n + 5} {
    29  					for _, rank := range []int{int(0.7 * float64(n)), n} {
    30  						dpstrfTest(t, impl, rnd, uplo, n, lda, rank)
    31  					}
    32  				}
    33  			}
    34  		})
    35  	}
    36  }
    37  
    38  func dpstrfTest(t *testing.T, impl Dpstrfer, rnd *rand.Rand, uplo blas.Uplo, n, lda, rankWant int) {
    39  	const tol = 1e-13
    40  
    41  	name := fmt.Sprintf("n=%v,lda=%v", n, lda)
    42  	bi := blas64.Implementation()
    43  
    44  	// Generate a random, symmetric A with the given rank by applying rankWant
    45  	// rank-1 updates to the zero matrix.
    46  	a := make([]float64, n*lda)
    47  	for i := 0; i < rankWant; i++ {
    48  		x := randomSlice(n, rnd)
    49  		bi.Dsyr(uplo, n, 1, x, 1, a, lda)
    50  	}
    51  
    52  	// Make a copy of A for storing the factorization.
    53  	aFac := make([]float64, len(a))
    54  	copy(aFac, a)
    55  
    56  	// Allocate a slice for pivots and fill it with invalid index values.
    57  	piv := make([]int, n)
    58  	for i := range piv {
    59  		piv[i] = -1
    60  	}
    61  
    62  	// Allocate the work slice.
    63  	work := make([]float64, 2*n)
    64  
    65  	// Call Dpstrf to Compute the Cholesky factorization with complete pivoting.
    66  	rank, ok := impl.Dpstrf(uplo, n, aFac, lda, piv, -1, work)
    67  
    68  	if ok != (rank == n) {
    69  		t.Errorf("%v: unexpected ok; got %v, want %v", name, ok, rank == n)
    70  	}
    71  	if rank != rankWant {
    72  		t.Errorf("%v: unexpected rank; got %v, want %v", name, rank, rankWant)
    73  	}
    74  
    75  	if n == 0 {
    76  		return
    77  	}
    78  
    79  	// Check that the residual |P*Uᵀ*U*Pᵀ - A| / n or |P*L*Lᵀ*Pᵀ - A| / n is
    80  	// sufficiently small.
    81  	resid := residualDpstrf(uplo, n, a, aFac, lda, rank, piv)
    82  	if resid > tol || math.IsNaN(resid) {
    83  		t.Errorf("%v: residual too large; got %v, want<=%v", name, resid, tol)
    84  	}
    85  }
    86  
    87  func residualDpstrf(uplo blas.Uplo, n int, a, aFac []float64, lda int, rank int, piv []int) float64 {
    88  	bi := blas64.Implementation()
    89  	// Reconstruct the symmetric positive semi-definite matrix A from its L or U
    90  	// factors and the permutation matrix P.
    91  	perm := zeros(n, n, n)
    92  	if uplo == blas.Upper {
    93  		// Change notation.
    94  		u, ldu := aFac, lda
    95  		// Zero out last n-rank rows of the factor U.
    96  		for i := rank; i < n; i++ {
    97  			for j := i; j < n; j++ {
    98  				u[i*ldu+j] = 0
    99  			}
   100  		}
   101  		// Extract U to aRec.
   102  		aRec := zeros(n, n, n)
   103  		for i := 0; i < n; i++ {
   104  			for j := i; j < n; j++ {
   105  				aRec.Data[i*aRec.Stride+j] = u[i*ldu+j]
   106  			}
   107  		}
   108  		// Multiply U by Uᵀ from the left.
   109  		bi.Dtrmm(blas.Left, blas.Upper, blas.Trans, blas.NonUnit, n, n,
   110  			1, u, ldu, aRec.Data, aRec.Stride)
   111  		// Form P * Uᵀ * U * Pᵀ.
   112  		for i := 0; i < n; i++ {
   113  			for j := 0; j < n; j++ {
   114  				if piv[i] > piv[j] {
   115  					// Don't set the lower triangle.
   116  					continue
   117  				}
   118  				if i <= j {
   119  					perm.Data[piv[i]*perm.Stride+piv[j]] = aRec.Data[i*aRec.Stride+j]
   120  				} else {
   121  					perm.Data[piv[i]*perm.Stride+piv[j]] = aRec.Data[j*aRec.Stride+i]
   122  				}
   123  			}
   124  		}
   125  		// Compute the difference P*Uᵀ*U*Pᵀ - A.
   126  		for i := 0; i < n; i++ {
   127  			for j := i; j < n; j++ {
   128  				perm.Data[i*perm.Stride+j] -= a[i*lda+j]
   129  			}
   130  		}
   131  	} else {
   132  		// Change notation.
   133  		l, ldl := aFac, lda
   134  		// Zero out last n-rank columns of the factor L.
   135  		for i := rank; i < n; i++ {
   136  			for j := rank; j <= i; j++ {
   137  				l[i*ldl+j] = 0
   138  			}
   139  		}
   140  		// Extract L to aRec.
   141  		aRec := zeros(n, n, n)
   142  		for i := 0; i < n; i++ {
   143  			for j := 0; j <= i; j++ {
   144  				aRec.Data[i*aRec.Stride+j] = l[i*ldl+j]
   145  			}
   146  		}
   147  		// Multiply L by Lᵀ from the right.
   148  		bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.NonUnit, n, n,
   149  			1, l, ldl, aRec.Data, aRec.Stride)
   150  		// Form P * L * Lᵀ * Pᵀ.
   151  		for i := 0; i < n; i++ {
   152  			for j := 0; j < n; j++ {
   153  				if piv[i] < piv[j] {
   154  					// Don't set the upper triangle.
   155  					continue
   156  				}
   157  				if i >= j {
   158  					perm.Data[piv[i]*perm.Stride+piv[j]] = aRec.Data[i*aRec.Stride+j]
   159  				} else {
   160  					perm.Data[piv[i]*perm.Stride+piv[j]] = aRec.Data[j*aRec.Stride+i]
   161  				}
   162  			}
   163  		}
   164  		// Compute the difference P*L*Lᵀ*Pᵀ - A.
   165  		for i := 0; i < n; i++ {
   166  			for j := 0; j <= i; j++ {
   167  				perm.Data[i*perm.Stride+j] -= a[i*lda+j]
   168  			}
   169  		}
   170  	}
   171  	// Compute |P*Uᵀ*U*Pᵀ - A| / n or |P*L*Lᵀ*Pᵀ - A| / n.
   172  	return dlansy(lapack.MaxColumnSum, uplo, n, perm.Data, perm.Stride) / float64(n)
   173  }