gonum.org/v1/gonum@v0.14.0/lapack/gonum/dlatrs.go (about) 1 // Copyright ©2015 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package gonum 6 7 import ( 8 "math" 9 10 "gonum.org/v1/gonum/blas" 11 "gonum.org/v1/gonum/blas/blas64" 12 "gonum.org/v1/gonum/lapack" 13 ) 14 15 // Dlatrs solves a triangular system of equations scaled to prevent overflow. It 16 // solves 17 // 18 // A * x = scale * b if trans == blas.NoTrans 19 // Aᵀ * x = scale * b if trans == blas.Trans 20 // 21 // where the scale s is set for numeric stability. 22 // 23 // A is an n×n triangular matrix. On entry, the slice x contains the values of 24 // b, and on exit it contains the solution vector x. 25 // 26 // If normin == true, cnorm is an input and cnorm[j] contains the norm of the off-diagonal 27 // part of the j^th column of A. If trans == blas.NoTrans, cnorm[j] must be greater 28 // than or equal to the infinity norm, and greater than or equal to the one-norm 29 // otherwise. If normin == false, then cnorm is treated as an output, and is set 30 // to contain the 1-norm of the off-diagonal part of the j^th column of A. 31 // 32 // Dlatrs is an internal routine. It is exported for testing purposes. 33 func (impl Implementation) Dlatrs(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, normin bool, n int, a []float64, lda int, x []float64, cnorm []float64) (scale float64) { 34 switch { 35 case uplo != blas.Upper && uplo != blas.Lower: 36 panic(badUplo) 37 case trans != blas.NoTrans && trans != blas.Trans && trans != blas.ConjTrans: 38 panic(badTrans) 39 case diag != blas.Unit && diag != blas.NonUnit: 40 panic(badDiag) 41 case n < 0: 42 panic(nLT0) 43 case lda < max(1, n): 44 panic(badLdA) 45 } 46 47 // Quick return if possible. 48 if n == 0 { 49 return 1 50 } 51 52 switch { 53 case len(a) < (n-1)*lda+n: 54 panic(shortA) 55 case len(x) < n: 56 panic(shortX) 57 case len(cnorm) < n: 58 panic(shortCNorm) 59 } 60 61 upper := uplo == blas.Upper 62 nonUnit := diag == blas.NonUnit 63 64 smlnum := dlamchS / dlamchP 65 bignum := 1 / smlnum 66 scale = 1 67 68 bi := blas64.Implementation() 69 70 if !normin { 71 if upper { 72 cnorm[0] = 0 73 for j := 1; j < n; j++ { 74 cnorm[j] = bi.Dasum(j, a[j:], lda) 75 } 76 } else { 77 for j := 0; j < n-1; j++ { 78 cnorm[j] = bi.Dasum(n-j-1, a[(j+1)*lda+j:], lda) 79 } 80 cnorm[n-1] = 0 81 } 82 } 83 // Scale the column norms by tscal if the maximum element in cnorm is greater than bignum. 84 imax := bi.Idamax(n, cnorm, 1) 85 var tscal float64 86 if cnorm[imax] <= bignum { 87 tscal = 1 88 } else { 89 tmax := cnorm[imax] 90 // Avoid NaN generation if entries in cnorm exceed the overflow 91 // threshold. 92 if tmax <= math.MaxFloat64 { 93 // Case 1: All entries in cnorm are valid floating-point numbers. 94 tscal = 1 / (smlnum * tmax) 95 bi.Dscal(n, tscal, cnorm, 1) 96 } else { 97 // Case 2: At least one column norm of A cannot be represented as 98 // floating-point number. Find the offdiagonal entry A[i,j] with the 99 // largest absolute value. If this entry is not +/- Infinity, use 100 // this value as tscal. 101 tmax = 0 102 if upper { 103 // A is upper triangular. 104 for j := 1; j < n; j++ { 105 tmax = math.Max(impl.Dlange(lapack.MaxAbs, j, 1, a[j:], lda, nil), tmax) 106 } 107 } else { 108 // A is lower triangular. 109 for j := 0; j < n-1; j++ { 110 tmax = math.Max(impl.Dlange(lapack.MaxAbs, n-j-1, 1, a[(j+1)*lda+j:], lda, nil), tmax) 111 } 112 } 113 if tmax <= math.MaxFloat64 { 114 tscal = 1 / (smlnum * tmax) 115 for j := 0; j < n; j++ { 116 if cnorm[j] <= math.MaxFloat64 { 117 cnorm[j] *= tscal 118 } else { 119 // Recompute the 1-norm without introducing Infinity in 120 // the summation. 121 cnorm[j] = 0 122 if upper { 123 for i := 0; i < j; i++ { 124 cnorm[j] += tscal * math.Abs(a[i*lda+j]) 125 } 126 } else { 127 for i := j + 1; i < n; i++ { 128 cnorm[j] += tscal * math.Abs(a[i*lda+j]) 129 } 130 } 131 } 132 } 133 } else { 134 // At least one entry of A is not a valid floating-point entry. 135 // Rely on Dtrsv to propagate Inf and NaN. 136 bi.Dtrsv(uplo, trans, diag, n, a, lda, x, 1) 137 return 138 } 139 } 140 } 141 142 // Compute a bound on the computed solution vector to see if bi.Dtrsv can be used. 143 j := bi.Idamax(n, x, 1) 144 xmax := math.Abs(x[j]) 145 xbnd := xmax 146 var grow float64 147 var jfirst, jlast, jinc int 148 if trans == blas.NoTrans { 149 if upper { 150 jfirst = n - 1 151 jlast = -1 152 jinc = -1 153 } else { 154 jfirst = 0 155 jlast = n 156 jinc = 1 157 } 158 // Compute the growth in A * x = b. 159 if tscal != 1 { 160 grow = 0 161 goto Solve 162 } 163 if nonUnit { 164 grow = 1 / math.Max(xbnd, smlnum) 165 xbnd = grow 166 for j := jfirst; j != jlast; j += jinc { 167 if grow <= smlnum { 168 goto Solve 169 } 170 tjj := math.Abs(a[j*lda+j]) 171 xbnd = math.Min(xbnd, math.Min(1, tjj)*grow) 172 if tjj+cnorm[j] >= smlnum { 173 grow *= tjj / (tjj + cnorm[j]) 174 } else { 175 grow = 0 176 } 177 } 178 grow = xbnd 179 } else { 180 grow = math.Min(1, 1/math.Max(xbnd, smlnum)) 181 for j := jfirst; j != jlast; j += jinc { 182 if grow <= smlnum { 183 goto Solve 184 } 185 grow *= 1 / (1 + cnorm[j]) 186 } 187 } 188 } else { 189 if upper { 190 jfirst = 0 191 jlast = n 192 jinc = 1 193 } else { 194 jfirst = n - 1 195 jlast = -1 196 jinc = -1 197 } 198 if tscal != 1 { 199 grow = 0 200 goto Solve 201 } 202 if nonUnit { 203 grow = 1 / (math.Max(xbnd, smlnum)) 204 xbnd = grow 205 for j := jfirst; j != jlast; j += jinc { 206 if grow <= smlnum { 207 goto Solve 208 } 209 xj := 1 + cnorm[j] 210 grow = math.Min(grow, xbnd/xj) 211 tjj := math.Abs(a[j*lda+j]) 212 if xj > tjj { 213 xbnd *= tjj / xj 214 } 215 } 216 grow = math.Min(grow, xbnd) 217 } else { 218 grow = math.Min(1, 1/math.Max(xbnd, smlnum)) 219 for j := jfirst; j != jlast; j += jinc { 220 if grow <= smlnum { 221 goto Solve 222 } 223 xj := 1 + cnorm[j] 224 grow /= xj 225 } 226 } 227 } 228 229 Solve: 230 if grow*tscal > smlnum { 231 // Use the Level 2 BLAS solve if the reciprocal of the bound on 232 // elements of X is not too small. 233 bi.Dtrsv(uplo, trans, diag, n, a, lda, x, 1) 234 if tscal != 1 { 235 bi.Dscal(n, 1/tscal, cnorm, 1) 236 } 237 return scale 238 } 239 240 // Use a Level 1 BLAS solve, scaling intermediate results. 241 if xmax > bignum { 242 scale = bignum / xmax 243 bi.Dscal(n, scale, x, 1) 244 xmax = bignum 245 } 246 if trans == blas.NoTrans { 247 for j := jfirst; j != jlast; j += jinc { 248 xj := math.Abs(x[j]) 249 var tjj, tjjs float64 250 if nonUnit { 251 tjjs = a[j*lda+j] * tscal 252 } else { 253 tjjs = tscal 254 if tscal == 1 { 255 goto Skip1 256 } 257 } 258 tjj = math.Abs(tjjs) 259 if tjj > smlnum { 260 if tjj < 1 { 261 if xj > tjj*bignum { 262 rec := 1 / xj 263 bi.Dscal(n, rec, x, 1) 264 scale *= rec 265 xmax *= rec 266 } 267 } 268 x[j] /= tjjs 269 xj = math.Abs(x[j]) 270 } else if tjj > 0 { 271 if xj > tjj*bignum { 272 rec := (tjj * bignum) / xj 273 if cnorm[j] > 1 { 274 rec /= cnorm[j] 275 } 276 bi.Dscal(n, rec, x, 1) 277 scale *= rec 278 xmax *= rec 279 } 280 x[j] /= tjjs 281 xj = math.Abs(x[j]) 282 } else { 283 for i := 0; i < n; i++ { 284 x[i] = 0 285 } 286 x[j] = 1 287 xj = 1 288 scale = 0 289 xmax = 0 290 } 291 Skip1: 292 if xj > 1 { 293 rec := 1 / xj 294 if cnorm[j] > (bignum-xmax)*rec { 295 rec *= 0.5 296 bi.Dscal(n, rec, x, 1) 297 scale *= rec 298 } 299 } else if xj*cnorm[j] > bignum-xmax { 300 bi.Dscal(n, 0.5, x, 1) 301 scale *= 0.5 302 } 303 if upper { 304 if j > 0 { 305 bi.Daxpy(j, -x[j]*tscal, a[j:], lda, x, 1) 306 i := bi.Idamax(j, x, 1) 307 xmax = math.Abs(x[i]) 308 } 309 } else { 310 if j < n-1 { 311 bi.Daxpy(n-j-1, -x[j]*tscal, a[(j+1)*lda+j:], lda, x[j+1:], 1) 312 i := j + bi.Idamax(n-j-1, x[j+1:], 1) 313 xmax = math.Abs(x[i]) 314 } 315 } 316 } 317 } else { 318 for j := jfirst; j != jlast; j += jinc { 319 xj := math.Abs(x[j]) 320 uscal := tscal 321 rec := 1 / math.Max(xmax, 1) 322 var tjjs float64 323 if cnorm[j] > (bignum-xj)*rec { 324 rec *= 0.5 325 if nonUnit { 326 tjjs = a[j*lda+j] * tscal 327 } else { 328 tjjs = tscal 329 } 330 tjj := math.Abs(tjjs) 331 if tjj > 1 { 332 rec = math.Min(1, rec*tjj) 333 uscal /= tjjs 334 } 335 if rec < 1 { 336 bi.Dscal(n, rec, x, 1) 337 scale *= rec 338 xmax *= rec 339 } 340 } 341 var sumj float64 342 if uscal == 1 { 343 if upper { 344 sumj = bi.Ddot(j, a[j:], lda, x, 1) 345 } else if j < n-1 { 346 sumj = bi.Ddot(n-j-1, a[(j+1)*lda+j:], lda, x[j+1:], 1) 347 } 348 } else { 349 if upper { 350 for i := 0; i < j; i++ { 351 sumj += (a[i*lda+j] * uscal) * x[i] 352 } 353 } else if j < n { 354 for i := j + 1; i < n; i++ { 355 sumj += (a[i*lda+j] * uscal) * x[i] 356 } 357 } 358 } 359 if uscal == tscal { 360 x[j] -= sumj 361 xj := math.Abs(x[j]) 362 var tjjs float64 363 if nonUnit { 364 tjjs = a[j*lda+j] * tscal 365 } else { 366 tjjs = tscal 367 if tscal == 1 { 368 goto Skip2 369 } 370 } 371 tjj := math.Abs(tjjs) 372 if tjj > smlnum { 373 if tjj < 1 { 374 if xj > tjj*bignum { 375 rec = 1 / xj 376 bi.Dscal(n, rec, x, 1) 377 scale *= rec 378 xmax *= rec 379 } 380 } 381 x[j] /= tjjs 382 } else if tjj > 0 { 383 if xj > tjj*bignum { 384 rec = (tjj * bignum) / xj 385 bi.Dscal(n, rec, x, 1) 386 scale *= rec 387 xmax *= rec 388 } 389 x[j] /= tjjs 390 } else { 391 for i := 0; i < n; i++ { 392 x[i] = 0 393 } 394 x[j] = 1 395 scale = 0 396 xmax = 0 397 } 398 } else { 399 x[j] = x[j]/tjjs - sumj 400 } 401 Skip2: 402 xmax = math.Max(xmax, math.Abs(x[j])) 403 } 404 } 405 scale /= tscal 406 if tscal != 1 { 407 bi.Dscal(n, 1/tscal, cnorm, 1) 408 } 409 return scale 410 }