// Source: github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/native/dlatrs.go

// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package native

import (
	"math"

	"github.com/gonum/blas"
	"github.com/gonum/blas/blas64"
)

// Dlatrs solves a triangular system of equations scaled to prevent overflow. It
// solves
//	A * x = scale * b if trans == blas.NoTrans
//	A^T * x = scale * b if trans == blas.Trans
// where scale is a scaling factor, usually less than or equal to 1, chosen so
// that the elements of x do not overflow.
//
// A is an n×n triangular matrix. On entry, the slice x contains the values of
// b, and on exit it contains the solution vector x. If a diagonal element of A
// is zero (which can only happen when diag == blas.NonUnit), scale is set to 0
// and x is set to a non-trivial solution of the homogeneous system
// A * x = 0 or A^T * x = 0.
//
// If normin == true, cnorm is an input and cnorm[j] contains the norm of the off-diagonal
// part of the j^th column of A. If trans == blas.NoTrans, cnorm[j] must be greater
// than or equal to the infinity norm, and greater than or equal to the one-norm
// otherwise. If normin == false, then cnorm is treated as an output, and is set
// to contain the 1-norm of the off-diagonal part of the j^th column of A.
//
// Dlatrs is an internal routine. It is exported for testing purposes.
30 func (impl Implementation) Dlatrs(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, normin bool, n int, a []float64, lda int, x []float64, cnorm []float64) (scale float64) { 31 if uplo != blas.Upper && uplo != blas.Lower { 32 panic(badUplo) 33 } 34 if trans != blas.Trans && trans != blas.NoTrans { 35 panic(badTrans) 36 } 37 if diag != blas.Unit && diag != blas.NonUnit { 38 panic(badDiag) 39 } 40 upper := uplo == blas.Upper 41 noTrans := trans == blas.NoTrans 42 nonUnit := diag == blas.NonUnit 43 44 if n < 0 { 45 panic(nLT0) 46 } 47 checkMatrix(n, n, a, lda) 48 checkVector(n, x, 1) 49 checkVector(n, cnorm, 1) 50 51 if n == 0 { 52 return 0 53 } 54 smlnum := dlamchS / dlamchP 55 bignum := 1 / smlnum 56 scale = 1 57 bi := blas64.Implementation() 58 if !normin { 59 if upper { 60 cnorm[0] = 0 61 for j := 1; j < n; j++ { 62 cnorm[j] = bi.Dasum(j, a[j:], lda) 63 } 64 } else { 65 for j := 0; j < n-1; j++ { 66 cnorm[j] = bi.Dasum(n-j-1, a[(j+1)*lda+j:], lda) 67 } 68 cnorm[n-1] = 0 69 } 70 } 71 // Scale the column norms by tscal if the maximum element in cnorm is greater than bignum. 72 imax := bi.Idamax(n, cnorm, 1) 73 tmax := cnorm[imax] 74 var tscal float64 75 if tmax <= bignum { 76 tscal = 1 77 } else { 78 tscal = 1 / (smlnum * tmax) 79 bi.Dscal(n, tscal, cnorm, 1) 80 } 81 82 // Compute a bound on the computed solution vector to see if bi.Dtrsv can be used. 83 j := bi.Idamax(n, x, 1) 84 xmax := math.Abs(x[j]) 85 xbnd := xmax 86 var grow float64 87 var jfirst, jlast, jinc int 88 if noTrans { 89 if upper { 90 jfirst = n - 1 91 jlast = -1 92 jinc = -1 93 } else { 94 jfirst = 0 95 jlast = n 96 jinc = 1 97 } 98 // Compute the growth in A * x = b. 
99 if tscal != 1 { 100 grow = 0 101 goto Solve 102 } 103 if nonUnit { 104 grow = 1 / math.Max(xbnd, smlnum) 105 xbnd = grow 106 for j := jfirst; j != jlast; j += jinc { 107 if grow <= smlnum { 108 goto Solve 109 } 110 tjj := math.Abs(a[j*lda+j]) 111 xbnd = math.Min(xbnd, math.Min(1, tjj)*grow) 112 if tjj+cnorm[j] >= smlnum { 113 grow *= tjj / (tjj + cnorm[j]) 114 } else { 115 grow = 0 116 } 117 } 118 grow = xbnd 119 } else { 120 grow = math.Min(1, 1/math.Max(xbnd, smlnum)) 121 for j := jfirst; j != jlast; j += jinc { 122 if grow <= smlnum { 123 goto Solve 124 } 125 grow *= 1 / (1 + cnorm[j]) 126 } 127 } 128 } else { 129 if upper { 130 jfirst = 0 131 jlast = n 132 jinc = 1 133 } else { 134 jfirst = n - 1 135 jlast = -1 136 jinc = -1 137 } 138 if tscal != 1 { 139 grow = 0 140 goto Solve 141 } 142 if nonUnit { 143 grow = 1 / (math.Max(xbnd, smlnum)) 144 xbnd = grow 145 for j := jfirst; j != jlast; j += jinc { 146 if grow <= smlnum { 147 goto Solve 148 } 149 xj := 1 + cnorm[j] 150 grow = math.Min(grow, xbnd/xj) 151 tjj := math.Abs(a[j*lda+j]) 152 if xj > tjj { 153 xbnd *= tjj / xj 154 } 155 } 156 grow = math.Min(grow, xbnd) 157 } else { 158 grow = math.Min(1, 1/math.Max(xbnd, smlnum)) 159 for j := jfirst; j != jlast; j += jinc { 160 if grow <= smlnum { 161 goto Solve 162 } 163 xj := 1 + cnorm[j] 164 grow /= xj 165 } 166 } 167 } 168 169 Solve: 170 if grow*tscal > smlnum { 171 // Use the Level 2 BLAS solve if the reciprocal of the bound on 172 // elements of X is not too small. 173 bi.Dtrsv(uplo, trans, diag, n, a, lda, x, 1) 174 if tscal != 1 { 175 bi.Dscal(n, 1/tscal, cnorm, 1) 176 } 177 return scale 178 } 179 180 // Use a Level 1 BLAS solve, scaling intermediate results. 
181 if xmax > bignum { 182 scale = bignum / xmax 183 bi.Dscal(n, scale, x, 1) 184 xmax = bignum 185 } 186 if noTrans { 187 for j := jfirst; j != jlast; j += jinc { 188 xj := math.Abs(x[j]) 189 var tjj, tjjs float64 190 if nonUnit { 191 tjjs = a[j*lda+j] * tscal 192 } else { 193 tjjs = tscal 194 if tscal == 1 { 195 goto Skip1 196 } 197 } 198 tjj = math.Abs(tjjs) 199 if tjj > smlnum { 200 if tjj < 1 { 201 if xj > tjj*bignum { 202 rec := 1 / xj 203 bi.Dscal(n, rec, x, 1) 204 scale *= rec 205 xmax *= rec 206 } 207 } 208 x[j] /= tjjs 209 xj = math.Abs(x[j]) 210 } else if tjj > 0 { 211 if xj > tjj*bignum { 212 rec := (tjj * bignum) / xj 213 if cnorm[j] > 1 { 214 rec /= cnorm[j] 215 } 216 bi.Dscal(n, rec, x, 1) 217 scale *= rec 218 xmax *= rec 219 } 220 x[j] /= tjjs 221 xj = math.Abs(x[j]) 222 } else { 223 for i := 0; i < n; i++ { 224 x[i] = 0 225 } 226 x[j] = 1 227 xj = 1 228 scale = 0 229 xmax = 0 230 } 231 Skip1: 232 if xj > 1 { 233 rec := 1 / xj 234 if cnorm[j] > (bignum-xmax)*rec { 235 rec *= 0.5 236 bi.Dscal(n, rec, x, 1) 237 scale *= rec 238 } 239 } else if xj*cnorm[j] > bignum-xmax { 240 bi.Dscal(n, 0.5, x, 1) 241 scale *= 0.5 242 } 243 if upper { 244 if j > 0 { 245 bi.Daxpy(j, -x[j]*tscal, a[j:], lda, x, 1) 246 i := bi.Idamax(j, x, 1) 247 xmax = math.Abs(x[i]) 248 } 249 } else { 250 if j < n-1 { 251 bi.Daxpy(n-j-1, -x[j]*tscal, a[(j+1)*lda+j:], lda, x[j+1:], 1) 252 i := j + bi.Idamax(n-j-1, x[j+1:], 1) 253 xmax = math.Abs(x[i]) 254 } 255 } 256 } 257 } else { 258 for j := jfirst; j != jlast; j += jinc { 259 xj := math.Abs(x[j]) 260 uscal := tscal 261 rec := 1 / math.Max(xmax, 1) 262 var tjjs float64 263 if cnorm[j] > (bignum-xj)*rec { 264 rec *= 0.5 265 if nonUnit { 266 tjjs = a[j*lda+j] * tscal 267 } else { 268 tjjs = tscal 269 } 270 tjj := math.Abs(tjjs) 271 if tjj > 1 { 272 rec = math.Min(1, rec*tjj) 273 uscal /= tjjs 274 } 275 if rec < 1 { 276 bi.Dscal(n, rec, x, 1) 277 scale *= rec 278 xmax *= rec 279 } 280 } 281 var sumj float64 282 if uscal == 1 { 283 if 
upper { 284 sumj = bi.Ddot(j, a[j:], lda, x, 1) 285 } else if j < n-1 { 286 sumj = bi.Ddot(n-j-1, a[(j+1)*lda+j:], lda, x[j+1:], 1) 287 } 288 } else { 289 if upper { 290 for i := 0; i < j; i++ { 291 sumj += (a[i*lda+j] * uscal) * x[i] 292 } 293 } else if j < n { 294 for i := j + 1; i < n; i++ { 295 sumj += (a[i*lda+j] * uscal) * x[i] 296 } 297 } 298 } 299 if uscal == tscal { 300 x[j] -= sumj 301 xj := math.Abs(x[j]) 302 var tjjs float64 303 if nonUnit { 304 tjjs = a[j*lda+j] * tscal 305 } else { 306 tjjs = tscal 307 if tscal == 1 { 308 goto Skip2 309 } 310 } 311 tjj := math.Abs(tjjs) 312 if tjj > smlnum { 313 if tjj < 1 { 314 if xj > tjj*bignum { 315 rec = 1 / xj 316 bi.Dscal(n, rec, x, 1) 317 scale *= rec 318 xmax *= rec 319 } 320 } 321 x[j] /= tjjs 322 } else if tjj > 0 { 323 if xj > tjj*bignum { 324 rec = (tjj * bignum) / xj 325 bi.Dscal(n, rec, x, 1) 326 scale *= rec 327 xmax *= rec 328 } 329 x[j] /= tjjs 330 } else { 331 for i := 0; i < n; i++ { 332 x[i] = 0 333 } 334 x[j] = 1 335 scale = 0 336 xmax = 0 337 } 338 } else { 339 x[j] = x[j]/tjjs - sumj 340 } 341 Skip2: 342 xmax = math.Max(xmax, math.Abs(x[j])) 343 } 344 } 345 scale /= tscal 346 if tscal != 1 { 347 bi.Dscal(n, 1/tscal, cnorm, 1) 348 } 349 return scale 350 }