github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/native/dlasy2.go (about) 1 // Copyright ©2016 The gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package native 6 7 import ( 8 "math" 9 10 "github.com/gonum/blas/blas64" 11 ) 12 13 // Dlasy2 solves the Sylvester matrix equation where the matrices are of order 1 14 // or 2. It computes the unknown n1×n2 matrix X so that 15 // TL*X + sgn*X*TR = scale*B, if tranl == false and tranr == false, 16 // TL^T*X + sgn*X*TR = scale*B, if tranl == true and tranr == false, 17 // TL*X + sgn*X*TR^T = scale*B, if tranl == false and tranr == true, 18 // TL^T*X + sgn*X*TR^T = scale*B, if tranl == true and tranr == true, 19 // where TL is n1×n1, TR is n2×n2, B is n1×n2, and 1 <= n1,n2 <= 2. 20 // 21 // isgn must be 1 or -1, and n1 and n2 must be 0, 1, or 2, but these conditions 22 // are not checked. 23 // 24 // Dlasy2 returns three values, a scale factor that is chosen less than or equal 25 // to 1 to prevent the solution overflowing, the infinity norm of the solution, 26 // and an indicator of success. If ok is false, TL and TR have eigenvalues that 27 // are too close, so TL or TR is perturbed to get a non-singular equation. 28 // 29 // Dlasy2 is an internal routine. It is exported for testing purposes. 30 func (impl Implementation) Dlasy2(tranl, tranr bool, isgn, n1, n2 int, tl []float64, ldtl int, tr []float64, ldtr int, b []float64, ldb int, x []float64, ldx int) (scale, xnorm float64, ok bool) { 31 // TODO(vladimir-ch): Add input validation checks conditionally skipped 32 // using the build tag mechanism. 33 34 ok = true 35 // Quick return if possible. 36 if n1 == 0 || n2 == 0 { 37 return scale, xnorm, ok 38 } 39 40 // Set constants to control overflow. 41 eps := dlamchP 42 smlnum := dlamchS / eps 43 sgn := float64(isgn) 44 45 if n1 == 1 && n2 == 1 { 46 // 1×1 case: TL11*X + sgn*X*TR11 = B11. 47 tau1 := tl[0] + sgn*tr[0] 48 bet := math.Abs(tau1) 49 if bet <= smlnum { 50 tau1 = smlnum 51 bet = smlnum 52 ok = false 53 } 54 scale = 1 55 gam := math.Abs(b[0]) 56 if smlnum*gam > bet { 57 scale = 1 / gam 58 } 59 x[0] = b[0] * scale / tau1 60 xnorm = math.Abs(x[0]) 61 return scale, xnorm, ok 62 } 63 64 if n1+n2 == 3 { 65 // 1×2 or 2×1 case. 66 var ( 67 smin float64 68 tmp [4]float64 // tmp is used as a 2×2 row-major matrix. 69 btmp [2]float64 70 ) 71 if n1 == 1 && n2 == 2 { 72 // 1×2 case: TL11*[X11 X12] + sgn*[X11 X12]*op[TR11 TR12] = [B11 B12]. 73 // [TR21 TR22] 74 smin = math.Abs(tl[0]) 75 smin = math.Max(smin, math.Max(math.Abs(tr[0]), math.Abs(tr[1]))) 76 smin = math.Max(smin, math.Max(math.Abs(tr[ldtr]), math.Abs(tr[ldtr+1]))) 77 smin = math.Max(eps*smin, smlnum) 78 tmp[0] = tl[0] + sgn*tr[0] 79 tmp[3] = tl[0] + sgn*tr[ldtr+1] 80 if tranr { 81 tmp[1] = sgn * tr[1] 82 tmp[2] = sgn * tr[ldtr] 83 } else { 84 tmp[1] = sgn * tr[ldtr] 85 tmp[2] = sgn * tr[1] 86 } 87 btmp[0] = b[0] 88 btmp[1] = b[1] 89 } else { 90 // 2×1 case: op[TL11 TL12]*[X11] + sgn*[X11]*TR11 = [B11]. 91 // [TL21 TL22]*[X21] [X21] [B21] 92 smin = math.Abs(tr[0]) 93 smin = math.Max(smin, math.Max(math.Abs(tl[0]), math.Abs(tl[1]))) 94 smin = math.Max(smin, math.Max(math.Abs(tl[ldtl]), math.Abs(tl[ldtl+1]))) 95 smin = math.Max(eps*smin, smlnum) 96 tmp[0] = tl[0] + sgn*tr[0] 97 tmp[3] = tl[ldtl+1] + sgn*tr[0] 98 if tranl { 99 tmp[1] = tl[ldtl] 100 tmp[2] = tl[1] 101 } else { 102 tmp[1] = tl[1] 103 tmp[2] = tl[ldtl] 104 } 105 btmp[0] = b[0] 106 btmp[1] = b[ldb] 107 } 108 109 // Solve 2×2 system using complete pivoting. 110 // Set pivots less than smin to smin. 111 112 bi := blas64.Implementation() 113 ipiv := bi.Idamax(len(tmp), tmp[:], 1) 114 // Compute the upper triangular matrix [u11 u12]. 115 // [ 0 u22] 116 u11 := tmp[ipiv] 117 if math.Abs(u11) <= smin { 118 ok = false 119 u11 = smin 120 } 121 locu12 := [4]int{1, 0, 3, 2} // Index in tmp of the element on the same row as the pivot. 122 u12 := tmp[locu12[ipiv]] 123 locl21 := [4]int{2, 3, 0, 1} // Index in tmp of the element on the same column as the pivot. 124 l21 := tmp[locl21[ipiv]] / u11 125 locu22 := [4]int{3, 2, 1, 0} // Index in tmp of the remaining element. 126 u22 := tmp[locu22[ipiv]] - l21*u12 127 if math.Abs(u22) <= smin { 128 ok = false 129 u22 = smin 130 } 131 if ipiv&0x2 != 0 { // true for ipiv equal to 2 and 3. 132 // The pivot was in the second row, swap the elements of 133 // the right-hand side. 134 btmp[0], btmp[1] = btmp[1], btmp[0]-l21*btmp[1] 135 } else { 136 btmp[1] -= l21 * btmp[0] 137 } 138 scale = 1 139 if 2*smlnum*math.Abs(btmp[1]) > math.Abs(u22) || 2*smlnum*math.Abs(btmp[0]) > math.Abs(u11) { 140 scale = 0.5 / math.Max(math.Abs(btmp[0]), math.Abs(btmp[1])) 141 btmp[0] *= scale 142 btmp[1] *= scale 143 } 144 // Solve the system [u11 u12] [x21] = [ btmp[0] ]. 145 // [ 0 u22] [x22] [ btmp[1] ] 146 x22 := btmp[1] / u22 147 x21 := btmp[0]/u11 - (u12/u11)*x22 148 if ipiv&0x1 != 0 { // true for ipiv equal to 1 and 3. 149 // The pivot was in the second column, swap the elements 150 // of the solution. 151 x21, x22 = x22, x21 152 } 153 x[0] = x21 154 if n1 == 1 { 155 x[1] = x22 156 xnorm = math.Abs(x[0]) + math.Abs(x[1]) 157 } else { 158 x[ldx] = x22 159 xnorm = math.Max(math.Abs(x[0]), math.Abs(x[ldx])) 160 } 161 return scale, xnorm, ok 162 } 163 164 // 2×2 case: op[TL11 TL12]*[X11 X12] + SGN*[X11 X12]*op[TR11 TR12] = [B11 B12]. 165 // [TL21 TL22] [X21 X22] [X21 X22] [TR21 TR22] [B21 B22] 166 // 167 // Solve equivalent 4×4 system using complete pivoting. 168 // Set pivots less than smin to smin. 169 170 smin := math.Max(math.Abs(tr[0]), math.Abs(tr[1])) 171 smin = math.Max(smin, math.Max(math.Abs(tr[ldtr]), math.Abs(tr[ldtr+1]))) 172 smin = math.Max(smin, math.Max(math.Abs(tl[0]), math.Abs(tl[1]))) 173 smin = math.Max(smin, math.Max(math.Abs(tl[ldtl]), math.Abs(tl[ldtl+1]))) 174 smin = math.Max(eps*smin, smlnum) 175 176 var t [4][4]float64 177 t[0][0] = tl[0] + sgn*tr[0] 178 t[1][1] = tl[0] + sgn*tr[ldtr+1] 179 t[2][2] = tl[ldtl+1] + sgn*tr[0] 180 t[3][3] = tl[ldtl+1] + sgn*tr[ldtr+1] 181 if tranl { 182 t[0][2] = tl[ldtl] 183 t[1][3] = tl[ldtl] 184 t[2][0] = tl[1] 185 t[3][1] = tl[1] 186 } else { 187 t[0][2] = tl[1] 188 t[1][3] = tl[1] 189 t[2][0] = tl[ldtl] 190 t[3][1] = tl[ldtl] 191 } 192 if tranr { 193 t[0][1] = sgn * tr[1] 194 t[1][0] = sgn * tr[ldtr] 195 t[2][3] = sgn * tr[1] 196 t[3][2] = sgn * tr[ldtr] 197 } else { 198 t[0][1] = sgn * tr[ldtr] 199 t[1][0] = sgn * tr[1] 200 t[2][3] = sgn * tr[ldtr] 201 t[3][2] = sgn * tr[1] 202 } 203 204 var btmp [4]float64 205 btmp[0] = b[0] 206 btmp[1] = b[1] 207 btmp[2] = b[ldb] 208 btmp[3] = b[ldb+1] 209 210 // Perform elimination. 211 var jpiv [4]int // jpiv records any column swaps for pivoting. 212 for i := 0; i < 3; i++ { 213 var ( 214 xmax float64 215 ipsv, jpsv int 216 ) 217 for ip := i; ip < 4; ip++ { 218 for jp := i; jp < 4; jp++ { 219 if math.Abs(t[ip][jp]) >= xmax { 220 xmax = math.Abs(t[ip][jp]) 221 ipsv = ip 222 jpsv = jp 223 } 224 } 225 } 226 if ipsv != i { 227 // The pivot is not in the top row of the unprocessed 228 // block, swap rows ipsv and i of t and btmp. 229 t[ipsv], t[i] = t[i], t[ipsv] 230 btmp[ipsv], btmp[i] = btmp[i], btmp[ipsv] 231 } 232 if jpsv != i { 233 // The pivot is not in the left column of the 234 // unprocessed block, swap columns jpsv and i of t. 235 for k := 0; k < 4; k++ { 236 t[k][jpsv], t[k][i] = t[k][i], t[k][jpsv] 237 } 238 } 239 jpiv[i] = jpsv 240 if math.Abs(t[i][i]) < smin { 241 ok = false 242 t[i][i] = smin 243 } 244 for k := i + 1; k < 4; k++ { 245 t[k][i] /= t[i][i] 246 btmp[k] -= t[k][i] * btmp[i] 247 for j := i + 1; j < 4; j++ { 248 t[k][j] -= t[k][i] * t[i][j] 249 } 250 } 251 } 252 if math.Abs(t[3][3]) < smin { 253 ok = false 254 t[3][3] = smin 255 } 256 scale = 1 257 if 8*smlnum*math.Abs(btmp[0]) > math.Abs(t[0][0]) || 258 8*smlnum*math.Abs(btmp[1]) > math.Abs(t[1][1]) || 259 8*smlnum*math.Abs(btmp[2]) > math.Abs(t[2][2]) || 260 8*smlnum*math.Abs(btmp[3]) > math.Abs(t[3][3]) { 261 262 maxbtmp := math.Max(math.Abs(btmp[0]), math.Abs(btmp[1])) 263 maxbtmp = math.Max(maxbtmp, math.Max(math.Abs(btmp[2]), math.Abs(btmp[3]))) 264 scale = 1 / 8 / maxbtmp 265 btmp[0] *= scale 266 btmp[1] *= scale 267 btmp[2] *= scale 268 btmp[3] *= scale 269 } 270 // Compute the solution of the upper triangular system t * tmp = btmp. 271 var tmp [4]float64 272 for i := 3; i >= 0; i-- { 273 temp := 1 / t[i][i] 274 tmp[i] = btmp[i] * temp 275 for j := i + 1; j < 4; j++ { 276 tmp[i] -= temp * t[i][j] * tmp[j] 277 } 278 } 279 for i := 2; i >= 0; i-- { 280 if jpiv[i] != i { 281 tmp[i], tmp[jpiv[i]] = tmp[jpiv[i]], tmp[i] 282 } 283 } 284 x[0] = tmp[0] 285 x[1] = tmp[1] 286 x[ldx] = tmp[2] 287 x[ldx+1] = tmp[3] 288 xnorm = math.Max(math.Abs(tmp[0])+math.Abs(tmp[1]), math.Abs(tmp[2])+math.Abs(tmp[3])) 289 return scale, xnorm, ok 290 }