gonum.org/v1/gonum@v0.14.0/lapack/testlapack/dgetc2.go (about) 1 // Copyright ©2021 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "fmt" 9 "math" 10 "testing" 11 12 "golang.org/x/exp/rand" 13 14 "gonum.org/v1/gonum/blas" 15 "gonum.org/v1/gonum/blas/blas64" 16 "gonum.org/v1/gonum/lapack" 17 ) 18 19 type Dgetc2er interface { 20 Dgetc2(n int, a []float64, lda int, ipiv, jpiv []int) (k int) 21 } 22 23 func Dgetc2Test(t *testing.T, impl Dgetc2er) { 24 rnd := rand.New(rand.NewSource(1)) 25 for _, n := range []int{0, 1, 2, 3, 4, 5, 10, 20} { 26 for _, lda := range []int{n, n + 5} { 27 dgetc2Test(t, impl, rnd, n, lda, false) 28 dgetc2Test(t, impl, rnd, n, lda, true) 29 } 30 } 31 } 32 33 func dgetc2Test(t *testing.T, impl Dgetc2er, rnd *rand.Rand, n, lda int, perturb bool) { 34 const tol = 1e-14 35 36 name := fmt.Sprintf("n=%v,lda=%v,perturb=%v", n, lda, perturb) 37 38 // Generate a random lower-triangular matrix with unit diagonal. 39 l := randomGeneral(n, n, max(1, n), rnd) 40 for i := 0; i < n; i++ { 41 l.Data[i*l.Stride+i] = 1 42 for j := i + 1; j < n; j++ { 43 l.Data[i*l.Stride+j] = 0 44 } 45 } 46 47 // Generate a random upper-triangular matrix. 48 u := randomGeneral(n, n, max(1, n), rnd) 49 for i := 0; i < n; i++ { 50 for j := 0; j < i; j++ { 51 u.Data[i*u.Stride+j] = 0 52 } 53 } 54 if perturb && n > 0 { 55 // Make U singular by randomly placing a zero on the diagonal. 56 i := rnd.Intn(n) 57 u.Data[i*u.Stride+i] = 0 58 } 59 60 // Construct A = L*U. 61 a := zeros(n, n, max(1, lda)) 62 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, l, u, 0, a) 63 64 // Allocate slices for pivots and pre-fill them with invalid indices. 65 ipiv := make([]int, n) 66 jpiv := make([]int, n) 67 for i := 0; i < n; i++ { 68 ipiv[i] = -1 69 jpiv[i] = -1 70 } 71 72 // Call Dgetc2 to compute the LU decomposition. 73 lu := cloneGeneral(a) 74 k := impl.Dgetc2(n, lu.Data, lu.Stride, ipiv, jpiv) 75 76 if n == 0 { 77 return 78 } 79 80 if perturb && k < 0 { 81 t.Errorf("%v: expected matrix perturbation", name) 82 } 83 84 // Verify all indices have been set. 85 for i := 0; i < n; i++ { 86 if ipiv[i] < 0 { 87 t.Errorf("%v: ipiv[%d] is not set", name, i) 88 } 89 if jpiv[i] < 0 { 90 t.Errorf("%v: jpiv[%d] is not set", name, i) 91 } 92 } 93 94 // Construct L and U matrices from Dgetc2 output. 95 l = zeros(n, n, n) 96 u = zeros(n, n, n) 97 for i := 0; i < n; i++ { 98 for j := 0; j < i; j++ { 99 l.Data[i*l.Stride+j] = lu.Data[i*lu.Stride+j] 100 } 101 l.Data[i*l.Stride+i] = 1 102 for j := i; j < n; j++ { 103 u.Data[i*u.Stride+j] = lu.Data[i*lu.Stride+j] 104 } 105 } 106 diff := zeros(n, n, n) 107 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, l, u, 0, diff) 108 109 // Apply permutation matrices P and Q to L*U. 110 for i := n - 1; i >= 0; i-- { 111 ipv := ipiv[i] 112 if ipv != i { 113 row1 := blas64.Vector{N: n, Data: diff.Data[i*diff.Stride:], Inc: 1} 114 row2 := blas64.Vector{N: n, Data: diff.Data[ipv*diff.Stride:], Inc: 1} 115 blas64.Swap(row1, row2) 116 } 117 jpv := jpiv[i] 118 if jpv != i { 119 col1 := blas64.Vector{N: n, Data: diff.Data[i:], Inc: diff.Stride} 120 col2 := blas64.Vector{N: n, Data: diff.Data[jpv:], Inc: diff.Stride} 121 blas64.Swap(col1, col2) 122 } 123 } 124 125 // Compute the residual |P*L*U*Q - A| and check that it is small. 126 for i := 0; i < n; i++ { 127 for j := 0; j < n; j++ { 128 diff.Data[i*diff.Stride+j] -= a.Data[i*a.Stride+j] 129 } 130 } 131 resid := dlange(lapack.MaxColumnSum, n, n, diff.Data, diff.Stride) 132 if resid > tol || math.IsNaN(resid) { 133 t.Errorf("%v: unexpected result; resid=%v, want<=%v", name, resid, tol) 134 } 135 }