gonum.org/v1/gonum@v0.14.0/lapack/testlapack/dtrexc.go (about) 1 // Copyright ©2016 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "fmt" 9 "testing" 10 11 "golang.org/x/exp/rand" 12 13 "gonum.org/v1/gonum/blas" 14 "gonum.org/v1/gonum/blas/blas64" 15 "gonum.org/v1/gonum/lapack" 16 ) 17 18 type Dtrexcer interface { 19 Dtrexc(compq lapack.UpdateSchurComp, n int, t []float64, ldt int, q []float64, ldq int, ifst, ilst int, work []float64) (ifstOut, ilstOut int, ok bool) 20 } 21 22 func DtrexcTest(t *testing.T, impl Dtrexcer) { 23 rnd := rand.New(rand.NewSource(1)) 24 25 for _, n := range []int{0, 1, 2, 3, 4, 5, 6, 10, 18, 31, 53} { 26 for _, extra := range []int{0, 3} { 27 for cas := 0; cas < 100; cas++ { 28 var ifst, ilst int 29 if n > 0 { 30 ifst = rnd.Intn(n) 31 ilst = rnd.Intn(n) 32 } 33 dtrexcTest(t, impl, rnd, n, ifst, ilst, extra) 34 } 35 } 36 } 37 } 38 39 func dtrexcTest(t *testing.T, impl Dtrexcer, rnd *rand.Rand, n, ifst, ilst, extra int) { 40 const tol = 1e-13 41 42 tmat, _, _ := randomSchurCanonical(n, n+extra, true, rnd) 43 tmatCopy := cloneGeneral(tmat) 44 45 fstSize, fstFirst := schurBlockSize(tmat, ifst) 46 lstSize, lstFirst := schurBlockSize(tmat, ilst) 47 48 name := fmt.Sprintf("Case n=%v,ifst=%v,nbfst=%v,ilst=%v,nblst=%v,extra=%v", 49 n, ifst, fstSize, ilst, lstSize, extra) 50 51 // 1. Test without accumulating Q. 52 53 compq := lapack.UpdateSchurNone 54 55 work := nanSlice(n) 56 57 ifstGot, ilstGot, ok := impl.Dtrexc(compq, n, tmat.Data, tmat.Stride, nil, 1, ifst, ilst, work) 58 59 if !generalOutsideAllNaN(tmat) { 60 t.Errorf("%v: out-of-range write to T", name) 61 } 62 63 // 2. Test with accumulating Q. 64 65 compq = lapack.UpdateSchur 66 67 tmat2 := cloneGeneral(tmatCopy) 68 q := eye(n, n+extra) 69 qCopy := cloneGeneral(q) 70 work = nanSlice(n) 71 72 ifstGot2, ilstGot2, ok2 := impl.Dtrexc(compq, n, tmat2.Data, tmat2.Stride, q.Data, q.Stride, ifst, ilst, work) 73 74 if !generalOutsideAllNaN(tmat2) { 75 t.Errorf("%v: out-of-range write to T2", name) 76 } 77 if !generalOutsideAllNaN(q) { 78 t.Errorf("%v: out-of-range write to Q", name) 79 } 80 81 // Check that outputs from cases 1. and 2. are exactly equal, then check one of them. 82 if ifstGot != ifstGot2 { 83 t.Errorf("%v: ifstGot != ifstGot2", name) 84 } 85 if ilstGot != ilstGot2 { 86 t.Errorf("%v: ilstGot != ilstGot2", name) 87 } 88 if ok != ok2 { 89 t.Errorf("%v: ok != ok2", name) 90 } 91 if !equalGeneral(tmat, tmat2) { 92 t.Errorf("%v: T != T2", name) 93 } 94 95 // Check that the index of the first block was correctly updated (if 96 // necessary). 97 ifstWant := ifst 98 if !fstFirst { 99 ifstWant = ifst - 1 100 } 101 if ifstWant != ifstGot { 102 t.Errorf("%v: unexpected ifst=%v, want %v", name, ifstGot, ifstWant) 103 } 104 105 // Check that the index of the last block is as expected when ok=true. 106 // When ok=false, we don't know at which block the algorithm failed, so 107 // we don't check. 108 ilstWant := ilst 109 if !lstFirst { 110 ilstWant-- 111 } 112 if ok { 113 if ifstWant < ilstWant { 114 // If the blocks are swapped backwards, these 115 // adjustments are not necessary, the first row of the 116 // last block will end up at ifst. 117 switch { 118 case fstSize == 2 && lstSize == 1: 119 ilstWant-- 120 case fstSize == 1 && lstSize == 2: 121 ilstWant++ 122 } 123 } 124 if ilstWant != ilstGot { 125 t.Errorf("%v: unexpected ilst=%v, want %v", name, ilstGot, ilstWant) 126 } 127 } 128 129 if n <= 1 || ifstGot == ilstGot { 130 // Too small matrix or no swapping. 131 // Check that T was not modified. 132 if !equalGeneral(tmat, tmatCopy) { 133 t.Errorf("%v: unexpected modification of T when no swapping", name) 134 } 135 // Check that Q was not modified. 136 if !equalGeneral(q, qCopy) { 137 t.Errorf("%v: unexpected modification of Q when no swapping", name) 138 } 139 // Nothing more to check 140 return 141 } 142 143 if !isSchurCanonicalGeneral(tmat) { 144 t.Errorf("%v: T is not in Schur canonical form", name) 145 } 146 147 // Check that T was not modified except above the second subdiagonal in 148 // rows and columns [modMin,modMax]. 149 modMin := min(ifstGot, ilstGot) 150 modMax := max(ifstGot, ilstGot) + fstSize 151 for i := 0; i < n; i++ { 152 for j := 0; j < n; j++ { 153 if modMin <= i && i < modMax && j+1 >= i { 154 continue 155 } 156 if modMin <= j && j < modMax && j+1 >= i { 157 continue 158 } 159 diff := tmat.Data[i*tmat.Stride+j] - tmatCopy.Data[i*tmatCopy.Stride+j] 160 if diff != 0 { 161 t.Errorf("%v: unexpected modification at T[%v,%v]", name, i, j) 162 } 163 } 164 } 165 166 // Check that Q is orthogonal. 167 resid := residualOrthogonal(q, false) 168 if resid > tol { 169 t.Errorf("%v: Q is not orthogonal; resid=%v, want<=%v", name, resid, tol) 170 } 171 172 // Check that Q is unchanged outside of columns [modMin,modMax]. 173 for i := 0; i < n; i++ { 174 for j := 0; j < n; j++ { 175 if modMin <= j && j < modMax { 176 continue 177 } 178 if q.Data[i*q.Stride+j] != qCopy.Data[i*qCopy.Stride+j] { 179 t.Errorf("%v: unexpected modification of Q[%v,%v]", name, i, j) 180 } 181 } 182 } 183 184 // Check that Qᵀ * TOrig * Q == T 185 qt := zeros(n, n, n) 186 blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, tmatCopy, 0, qt) 187 qtq := cloneGeneral(tmat) 188 blas64.Gemm(blas.NoTrans, blas.NoTrans, -1, qt, q, 1, qtq) 189 resid = dlange(lapack.MaxColumnSum, n, n, qtq.Data, qtq.Stride) 190 if resid > tol { 191 t.Errorf("%v: mismatch between Qᵀ*(initial T)*Q and (final T); resid=%v, want<=%v", 192 name, resid, tol) 193 } 194 }