gonum.org/v1/gonum@v0.14.0/lapack/testlapack/dlaexc.go (about) 1 // Copyright ©2016 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "fmt" 9 "testing" 10 11 "golang.org/x/exp/rand" 12 13 "gonum.org/v1/gonum/blas" 14 "gonum.org/v1/gonum/blas/blas64" 15 "gonum.org/v1/gonum/lapack" 16 ) 17 18 type Dlaexcer interface { 19 Dlaexc(wantq bool, n int, t []float64, ldt int, q []float64, ldq int, j1, n1, n2 int, work []float64) bool 20 } 21 22 func DlaexcTest(t *testing.T, impl Dlaexcer) { 23 rnd := rand.New(rand.NewSource(1)) 24 25 for _, n := range []int{1, 2, 3, 4, 5, 6, 10, 18, 31, 53} { 26 for _, extra := range []int{0, 3} { 27 for cas := 0; cas < 100; cas++ { 28 testDlaexc(t, impl, rnd, n, extra) 29 } 30 } 31 } 32 } 33 34 func testDlaexc(t *testing.T, impl Dlaexcer, rnd *rand.Rand, n, extra int) { 35 const tol = 1e-14 36 37 // Generate random T in Schur canonical form. 38 tmat, _, _ := randomSchurCanonical(n, n+extra, true, rnd) 39 tmatCopy := cloneGeneral(tmat) 40 41 // Randomly pick the index of the first block. 42 j1 := rnd.Intn(n) 43 if j1 > 0 && tmat.Data[j1*tmat.Stride+j1-1] != 0 { 44 // Adjust j1 if it points to the second row of a 2x2 block. 45 j1-- 46 } 47 // Read sizes of the two blocks based on properties of T. 48 var n1, n2 int 49 switch j1 { 50 case n - 1: 51 n1, n2 = 1, 0 52 case n - 2: 53 if tmat.Data[(j1+1)*tmat.Stride+j1] == 0 { 54 n1, n2 = 1, 1 55 } else { 56 n1, n2 = 2, 0 57 } 58 case n - 3: 59 if tmat.Data[(j1+1)*tmat.Stride+j1] == 0 { 60 n1, n2 = 1, 2 61 } else { 62 n1, n2 = 2, 1 63 } 64 default: 65 if tmat.Data[(j1+1)*tmat.Stride+j1] == 0 { 66 n1 = 1 67 if tmat.Data[(j1+2)*tmat.Stride+j1+1] == 0 { 68 n2 = 1 69 } else { 70 n2 = 2 71 } 72 } else { 73 n1 = 2 74 if tmat.Data[(j1+3)*tmat.Stride+j1+2] == 0 { 75 n2 = 1 76 } else { 77 n2 = 2 78 } 79 } 80 } 81 82 name := fmt.Sprintf("Case n=%v,j1=%v,n1=%v,n2=%v,extra=%v", n, j1, n1, n2, extra) 83 84 // 1. Test without accumulating Q. 85 86 wantq := false 87 88 work := nanSlice(n) 89 90 ok := impl.Dlaexc(wantq, n, tmat.Data, tmat.Stride, nil, 1, j1, n1, n2, work) 91 92 // 2. Test with accumulating Q. 93 94 wantq = true 95 96 tmat2 := cloneGeneral(tmatCopy) 97 q := eye(n, n+extra) 98 qCopy := cloneGeneral(q) 99 work = nanSlice(n) 100 101 ok2 := impl.Dlaexc(wantq, n, tmat2.Data, tmat2.Stride, q.Data, q.Stride, j1, n1, n2, work) 102 103 if !generalOutsideAllNaN(tmat) { 104 t.Errorf("%v: out-of-range write to T", name) 105 } 106 if !generalOutsideAllNaN(tmat2) { 107 t.Errorf("%v: out-of-range write to T2", name) 108 } 109 if !generalOutsideAllNaN(q) { 110 t.Errorf("%v: out-of-range write to Q", name) 111 } 112 113 // Check that outputs from cases 1. and 2. are exactly equal, then check one of them. 114 if ok != ok2 { 115 t.Errorf("%v: ok != ok2", name) 116 } 117 if !equalGeneral(tmat, tmat2) { 118 t.Errorf("%v: T != T2", name) 119 } 120 121 if !ok { 122 if n1 == 1 && n2 == 1 { 123 t.Errorf("%v: unexpected failure", name) 124 } else { 125 t.Logf("%v: Dlaexc returned false", name) 126 } 127 } 128 129 if !ok || n1 == 0 || n2 == 0 || j1+n1 >= n { 130 // Check that T is not modified. 131 if !equalGeneral(tmat, tmatCopy) { 132 t.Errorf("%v: unexpected modification of T", name) 133 } 134 // Check that Q is not modified. 135 if !equalGeneral(q, qCopy) { 136 t.Errorf("%v: unexpected modification of Q", name) 137 } 138 return 139 } 140 141 // Check that T is not modified outside of rows and columns [j1:j1+n1+n2]. 142 for i := 0; i < n; i++ { 143 if j1 <= i && i < j1+n1+n2 { 144 continue 145 } 146 for j := 0; j < n; j++ { 147 if j1 <= j && j < j1+n1+n2 { 148 continue 149 } 150 diff := tmat.Data[i*tmat.Stride+j] - tmatCopy.Data[i*tmatCopy.Stride+j] 151 if diff != 0 { 152 t.Errorf("%v: unexpected modification of T[%v,%v]", name, i, j) 153 } 154 } 155 } 156 157 if !isSchurCanonicalGeneral(tmat) { 158 t.Errorf("%v: T is not in Schur canonical form", name) 159 } 160 161 // Check that Q is orthogonal. 162 resid := residualOrthogonal(q, false) 163 if resid > tol { 164 t.Errorf("%v: Q is not orthogonal; resid=%v, want<=%v", name, resid, tol) 165 } 166 167 // Check that Q is unchanged outside of columns [j1:j1+n1+n2]. 168 for i := 0; i < n; i++ { 169 for j := 0; j < n; j++ { 170 if j1 <= j && j < j1+n1+n2 { 171 continue 172 } 173 diff := q.Data[i*q.Stride+j] - qCopy.Data[i*qCopy.Stride+j] 174 if diff != 0 { 175 t.Errorf("%v: unexpected modification of Q[%v,%v]", name, i, j) 176 } 177 } 178 } 179 180 // Check that Qᵀ * TOrig * Q == T 181 qt := zeros(n, n, n) 182 blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, tmatCopy, 0, qt) 183 qtq := cloneGeneral(tmat) 184 blas64.Gemm(blas.NoTrans, blas.NoTrans, -1, qt, q, 1, qtq) 185 resid = dlange(lapack.MaxColumnSum, n, n, qtq.Data, qtq.Stride) 186 if resid > float64(n)*tol { 187 t.Errorf("%v: mismatch between Qᵀ*(initial T)*Q and (final T); resid=%v, want<=%v", 188 name, resid, float64(n)*tol) 189 } 190 }