github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/lapack/gonum/dlatrs.go (about) 1 // Copyright ©2015 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package gonum 6 7 import ( 8 "math" 9 10 "github.com/jingcheng-WU/gonum/blas" 11 "github.com/jingcheng-WU/gonum/blas/blas64" 12 ) 13 14 // Dlatrs solves a triangular system of equations scaled to prevent overflow. It 15 // solves 16 // A * x = scale * b if trans == blas.NoTrans 17 // Aᵀ * x = scale * b if trans == blas.Trans 18 // where the scale s is set for numeric stability. 19 // 20 // A is an n×n triangular matrix. On entry, the slice x contains the values of 21 // b, and on exit it contains the solution vector x. 22 // 23 // If normin == true, cnorm is an input and cnorm[j] contains the norm of the off-diagonal 24 // part of the j^th column of A. If trans == blas.NoTrans, cnorm[j] must be greater 25 // than or equal to the infinity norm, and greater than or equal to the one-norm 26 // otherwise. If normin == false, then cnorm is treated as an output, and is set 27 // to contain the 1-norm of the off-diagonal part of the j^th column of A. 28 // 29 // Dlatrs is an internal routine. It is exported for testing purposes. 30 func (impl Implementation) Dlatrs(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, normin bool, n int, a []float64, lda int, x []float64, cnorm []float64) (scale float64) { 31 switch { 32 case uplo != blas.Upper && uplo != blas.Lower: 33 panic(badUplo) 34 case trans != blas.NoTrans && trans != blas.Trans && trans != blas.ConjTrans: 35 panic(badTrans) 36 case diag != blas.Unit && diag != blas.NonUnit: 37 panic(badDiag) 38 case n < 0: 39 panic(nLT0) 40 case lda < max(1, n): 41 panic(badLdA) 42 } 43 44 // Quick return if possible. 45 if n == 0 { 46 return 0 47 } 48 49 switch { 50 case len(a) < (n-1)*lda+n: 51 panic(shortA) 52 case len(x) < n: 53 panic(shortX) 54 case len(cnorm) < n: 55 panic(shortCNorm) 56 } 57 58 upper := uplo == blas.Upper 59 nonUnit := diag == blas.NonUnit 60 61 smlnum := dlamchS / dlamchP 62 bignum := 1 / smlnum 63 scale = 1 64 65 bi := blas64.Implementation() 66 67 if !normin { 68 if upper { 69 cnorm[0] = 0 70 for j := 1; j < n; j++ { 71 cnorm[j] = bi.Dasum(j, a[j:], lda) 72 } 73 } else { 74 for j := 0; j < n-1; j++ { 75 cnorm[j] = bi.Dasum(n-j-1, a[(j+1)*lda+j:], lda) 76 } 77 cnorm[n-1] = 0 78 } 79 } 80 // Scale the column norms by tscal if the maximum element in cnorm is greater than bignum. 81 imax := bi.Idamax(n, cnorm, 1) 82 tmax := cnorm[imax] 83 var tscal float64 84 if tmax <= bignum { 85 tscal = 1 86 } else { 87 tscal = 1 / (smlnum * tmax) 88 bi.Dscal(n, tscal, cnorm, 1) 89 } 90 91 // Compute a bound on the computed solution vector to see if bi.Dtrsv can be used. 92 j := bi.Idamax(n, x, 1) 93 xmax := math.Abs(x[j]) 94 xbnd := xmax 95 var grow float64 96 var jfirst, jlast, jinc int 97 if trans == blas.NoTrans { 98 if upper { 99 jfirst = n - 1 100 jlast = -1 101 jinc = -1 102 } else { 103 jfirst = 0 104 jlast = n 105 jinc = 1 106 } 107 // Compute the growth in A * x = b. 108 if tscal != 1 { 109 grow = 0 110 goto Solve 111 } 112 if nonUnit { 113 grow = 1 / math.Max(xbnd, smlnum) 114 xbnd = grow 115 for j := jfirst; j != jlast; j += jinc { 116 if grow <= smlnum { 117 goto Solve 118 } 119 tjj := math.Abs(a[j*lda+j]) 120 xbnd = math.Min(xbnd, math.Min(1, tjj)*grow) 121 if tjj+cnorm[j] >= smlnum { 122 grow *= tjj / (tjj + cnorm[j]) 123 } else { 124 grow = 0 125 } 126 } 127 grow = xbnd 128 } else { 129 grow = math.Min(1, 1/math.Max(xbnd, smlnum)) 130 for j := jfirst; j != jlast; j += jinc { 131 if grow <= smlnum { 132 goto Solve 133 } 134 grow *= 1 / (1 + cnorm[j]) 135 } 136 } 137 } else { 138 if upper { 139 jfirst = 0 140 jlast = n 141 jinc = 1 142 } else { 143 jfirst = n - 1 144 jlast = -1 145 jinc = -1 146 } 147 if tscal != 1 { 148 grow = 0 149 goto Solve 150 } 151 if nonUnit { 152 grow = 1 / (math.Max(xbnd, smlnum)) 153 xbnd = grow 154 for j := jfirst; j != jlast; j += jinc { 155 if grow <= smlnum { 156 goto Solve 157 } 158 xj := 1 + cnorm[j] 159 grow = math.Min(grow, xbnd/xj) 160 tjj := math.Abs(a[j*lda+j]) 161 if xj > tjj { 162 xbnd *= tjj / xj 163 } 164 } 165 grow = math.Min(grow, xbnd) 166 } else { 167 grow = math.Min(1, 1/math.Max(xbnd, smlnum)) 168 for j := jfirst; j != jlast; j += jinc { 169 if grow <= smlnum { 170 goto Solve 171 } 172 xj := 1 + cnorm[j] 173 grow /= xj 174 } 175 } 176 } 177 178 Solve: 179 if grow*tscal > smlnum { 180 // Use the Level 2 BLAS solve if the reciprocal of the bound on 181 // elements of X is not too small. 182 bi.Dtrsv(uplo, trans, diag, n, a, lda, x, 1) 183 if tscal != 1 { 184 bi.Dscal(n, 1/tscal, cnorm, 1) 185 } 186 return scale 187 } 188 189 // Use a Level 1 BLAS solve, scaling intermediate results. 190 if xmax > bignum { 191 scale = bignum / xmax 192 bi.Dscal(n, scale, x, 1) 193 xmax = bignum 194 } 195 if trans == blas.NoTrans { 196 for j := jfirst; j != jlast; j += jinc { 197 xj := math.Abs(x[j]) 198 var tjj, tjjs float64 199 if nonUnit { 200 tjjs = a[j*lda+j] * tscal 201 } else { 202 tjjs = tscal 203 if tscal == 1 { 204 goto Skip1 205 } 206 } 207 tjj = math.Abs(tjjs) 208 if tjj > smlnum { 209 if tjj < 1 { 210 if xj > tjj*bignum { 211 rec := 1 / xj 212 bi.Dscal(n, rec, x, 1) 213 scale *= rec 214 xmax *= rec 215 } 216 } 217 x[j] /= tjjs 218 xj = math.Abs(x[j]) 219 } else if tjj > 0 { 220 if xj > tjj*bignum { 221 rec := (tjj * bignum) / xj 222 if cnorm[j] > 1 { 223 rec /= cnorm[j] 224 } 225 bi.Dscal(n, rec, x, 1) 226 scale *= rec 227 xmax *= rec 228 } 229 x[j] /= tjjs 230 xj = math.Abs(x[j]) 231 } else { 232 for i := 0; i < n; i++ { 233 x[i] = 0 234 } 235 x[j] = 1 236 xj = 1 237 scale = 0 238 xmax = 0 239 } 240 Skip1: 241 if xj > 1 { 242 rec := 1 / xj 243 if cnorm[j] > (bignum-xmax)*rec { 244 rec *= 0.5 245 bi.Dscal(n, rec, x, 1) 246 scale *= rec 247 } 248 } else if xj*cnorm[j] > bignum-xmax { 249 bi.Dscal(n, 0.5, x, 1) 250 scale *= 0.5 251 } 252 if upper { 253 if j > 0 { 254 bi.Daxpy(j, -x[j]*tscal, a[j:], lda, x, 1) 255 i := bi.Idamax(j, x, 1) 256 xmax = math.Abs(x[i]) 257 } 258 } else { 259 if j < n-1 { 260 bi.Daxpy(n-j-1, -x[j]*tscal, a[(j+1)*lda+j:], lda, x[j+1:], 1) 261 i := j + bi.Idamax(n-j-1, x[j+1:], 1) 262 xmax = math.Abs(x[i]) 263 } 264 } 265 } 266 } else { 267 for j := jfirst; j != jlast; j += jinc { 268 xj := math.Abs(x[j]) 269 uscal := tscal 270 rec := 1 / math.Max(xmax, 1) 271 var tjjs float64 272 if cnorm[j] > (bignum-xj)*rec { 273 rec *= 0.5 274 if nonUnit { 275 tjjs = a[j*lda+j] * tscal 276 } else { 277 tjjs = tscal 278 } 279 tjj := math.Abs(tjjs) 280 if tjj > 1 { 281 rec = math.Min(1, rec*tjj) 282 uscal /= tjjs 283 } 284 if rec < 1 { 285 bi.Dscal(n, rec, x, 1) 286 scale *= rec 287 xmax *= rec 288 } 289 } 290 var sumj float64 291 if uscal == 1 { 292 if upper { 293 sumj = bi.Ddot(j, a[j:], lda, x, 1) 294 } else if j < n-1 { 295 sumj = bi.Ddot(n-j-1, a[(j+1)*lda+j:], lda, x[j+1:], 1) 296 } 297 } else { 298 if upper { 299 for i := 0; i < j; i++ { 300 sumj += (a[i*lda+j] * uscal) * x[i] 301 } 302 } else if j < n { 303 for i := j + 1; i < n; i++ { 304 sumj += (a[i*lda+j] * uscal) * x[i] 305 } 306 } 307 } 308 if uscal == tscal { 309 x[j] -= sumj 310 xj := math.Abs(x[j]) 311 var tjjs float64 312 if nonUnit { 313 tjjs = a[j*lda+j] * tscal 314 } else { 315 tjjs = tscal 316 if tscal == 1 { 317 goto Skip2 318 } 319 } 320 tjj := math.Abs(tjjs) 321 if tjj > smlnum { 322 if tjj < 1 { 323 if xj > tjj*bignum { 324 rec = 1 / xj 325 bi.Dscal(n, rec, x, 1) 326 scale *= rec 327 xmax *= rec 328 } 329 } 330 x[j] /= tjjs 331 } else if tjj > 0 { 332 if xj > tjj*bignum { 333 rec = (tjj * bignum) / xj 334 bi.Dscal(n, rec, x, 1) 335 scale *= rec 336 xmax *= rec 337 } 338 x[j] /= tjjs 339 } else { 340 for i := 0; i < n; i++ { 341 x[i] = 0 342 } 343 x[j] = 1 344 scale = 0 345 xmax = 0 346 } 347 } else { 348 x[j] = x[j]/tjjs - sumj 349 } 350 Skip2: 351 xmax = math.Max(xmax, math.Abs(x[j])) 352 } 353 } 354 scale /= tscal 355 if tscal != 1 { 356 bi.Dscal(n, 1/tscal, cnorm, 1) 357 } 358 return scale 359 }