gonum.org/v1/gonum@v0.14.0/lapack/gonum/dlatbs.go (about) 1 // Copyright ©2019 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package gonum 6 7 import ( 8 "math" 9 10 "gonum.org/v1/gonum/blas" 11 "gonum.org/v1/gonum/blas/blas64" 12 ) 13 14 // Dlatbs solves a triangular banded system of equations 15 // 16 // A * x = s*b if trans == blas.NoTrans 17 // Aᵀ * x = s*b if trans == blas.Trans or blas.ConjTrans 18 // 19 // where A is an upper or lower triangular band matrix, x and b are n-element 20 // vectors, and s is a scaling factor chosen so that the components of x will be 21 // less than the overflow threshold. 22 // 23 // On entry, x contains the right-hand side b of the triangular system. 24 // On return, x is overwritten by the solution vector x. 25 // 26 // normin specifies whether the cnorm parameter contains the column norms of A on 27 // entry. If it is true, cnorm[j] contains the norm of the off-diagonal part of 28 // the j-th column of A. If it is false, the norms will be computed and stored 29 // in cnorm. 30 // 31 // Dlatbs returns the scaling factor s for the triangular system. If the matrix 32 // A is singular (A[j,j]==0 for some j), then scale is set to 0 and a 33 // non-trivial solution to A*x = 0 is returned. 34 // 35 // Dlatbs is an internal routine. It is exported for testing purposes. 36 func (Implementation) Dlatbs(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, normin bool, n, kd int, ab []float64, ldab int, x, cnorm []float64) (scale float64) { 37 noTran := trans == blas.NoTrans 38 switch { 39 case uplo != blas.Upper && uplo != blas.Lower: 40 panic(badUplo) 41 case !noTran && trans != blas.Trans && trans != blas.ConjTrans: 42 panic(badTrans) 43 case diag != blas.NonUnit && diag != blas.Unit: 44 panic(badDiag) 45 case n < 0: 46 panic(nLT0) 47 case kd < 0: 48 panic(kdLT0) 49 case ldab < kd+1: 50 panic(badLdA) 51 } 52 53 // Quick return if possible. 54 if n == 0 { 55 return 1 56 } 57 58 switch { 59 case len(ab) < (n-1)*ldab+kd+1: 60 panic(shortAB) 61 case len(x) < n: 62 panic(shortX) 63 case len(cnorm) < n: 64 panic(shortCNorm) 65 } 66 67 // Parameters to control overflow. 68 smlnum := dlamchS / dlamchP 69 bignum := 1 / smlnum 70 71 bi := blas64.Implementation() 72 kld := max(1, ldab-1) 73 if !normin { 74 // Compute the 1-norm of each column, not including the diagonal. 75 if uplo == blas.Upper { 76 for j := 0; j < n; j++ { 77 jlen := min(j, kd) 78 if jlen > 0 { 79 cnorm[j] = bi.Dasum(jlen, ab[(j-jlen)*ldab+jlen:], kld) 80 } else { 81 cnorm[j] = 0 82 } 83 } 84 } else { 85 for j := 0; j < n; j++ { 86 jlen := min(n-j-1, kd) 87 if jlen > 0 { 88 cnorm[j] = bi.Dasum(jlen, ab[(j+1)*ldab+kd-1:], kld) 89 } else { 90 cnorm[j] = 0 91 } 92 } 93 } 94 } 95 96 // Set up indices and increments for loops below. 97 var ( 98 jFirst, jLast, jInc int 99 maind int 100 ) 101 if noTran { 102 if uplo == blas.Upper { 103 jFirst = n - 1 104 jLast = -1 105 jInc = -1 106 maind = 0 107 } else { 108 jFirst = 0 109 jLast = n 110 jInc = 1 111 maind = kd 112 } 113 } else { 114 if uplo == blas.Upper { 115 jFirst = 0 116 jLast = n 117 jInc = 1 118 maind = 0 119 } else { 120 jFirst = n - 1 121 jLast = -1 122 jInc = -1 123 maind = kd 124 } 125 } 126 127 // Scale the column norms by tscal if the maximum element in cnorm is 128 // greater than bignum. 129 tmax := cnorm[bi.Idamax(n, cnorm, 1)] 130 tscal := 1.0 131 if tmax > bignum { 132 tscal = 1 / (smlnum * tmax) 133 bi.Dscal(n, tscal, cnorm, 1) 134 } 135 136 // Compute a bound on the computed solution vector to see if the Level 2 137 // BLAS routine Dtbsv can be used. 138 139 xMax := math.Abs(x[bi.Idamax(n, x, 1)]) 140 xBnd := xMax 141 grow := 0.0 142 // Compute the growth only if the maximum element in cnorm is NOT greater 143 // than bignum. 144 if tscal != 1 { 145 goto skipComputeGrow 146 } 147 if noTran { 148 // Compute the growth in A * x = b. 149 if diag == blas.NonUnit { 150 // A is non-unit triangular. 151 // 152 // Compute grow = 1/G_j and xBnd = 1/M_j. 153 // Initially, G_0 = max{x(i), i=1,...,n}. 154 grow = 1 / math.Max(xBnd, smlnum) 155 xBnd = grow 156 for j := jFirst; j != jLast; j += jInc { 157 if grow <= smlnum { 158 // Exit the loop because the growth factor is too small. 159 goto skipComputeGrow 160 } 161 // M_j = G_{j-1} / abs(A[j,j]) 162 tjj := math.Abs(ab[j*ldab+maind]) 163 xBnd = math.Min(xBnd, math.Min(1, tjj)*grow) 164 if tjj+cnorm[j] >= smlnum { 165 // G_j = G_{j-1}*( 1 + cnorm[j] / abs(A[j,j]) ) 166 grow *= tjj / (tjj + cnorm[j]) 167 } else { 168 // G_j could overflow, set grow to 0. 169 grow = 0 170 } 171 } 172 grow = xBnd 173 } else { 174 // A is unit triangular. 175 // 176 // Compute grow = 1/G_j, where G_0 = max{x(i), i=1,...,n}. 177 grow = math.Min(1, 1/math.Max(xBnd, smlnum)) 178 for j := jFirst; j != jLast; j += jInc { 179 if grow <= smlnum { 180 // Exit the loop because the growth factor is too small. 181 goto skipComputeGrow 182 } 183 // G_j = G_{j-1}*( 1 + cnorm[j] ) 184 grow /= 1 + cnorm[j] 185 } 186 } 187 } else { 188 // Compute the growth in Aᵀ * x = b. 189 if diag == blas.NonUnit { 190 // A is non-unit triangular. 191 // 192 // Compute grow = 1/G_j and xBnd = 1/M_j. 193 // Initially, G_0 = max{x(i), i=1,...,n}. 194 grow = 1 / math.Max(xBnd, smlnum) 195 xBnd = grow 196 for j := jFirst; j != jLast; j += jInc { 197 if grow <= smlnum { 198 // Exit the loop because the growth factor is too small. 199 goto skipComputeGrow 200 } 201 // G_j = max( G_{j-1}, M_{j-1}*( 1 + cnorm[j] ) ) 202 xj := 1 + cnorm[j] 203 grow = math.Min(grow, xBnd/xj) 204 // M_j = M_{j-1}*( 1 + cnorm[j] ) / abs(A[j,j]) 205 tjj := math.Abs(ab[j*ldab+maind]) 206 if xj > tjj { 207 xBnd *= tjj / xj 208 } 209 } 210 grow = math.Min(grow, xBnd) 211 } else { 212 // A is unit triangular. 213 // 214 // Compute grow = 1/G_j, where G_0 = max{x(i), i=1,...,n}. 215 grow = math.Min(1, 1/math.Max(xBnd, smlnum)) 216 for j := jFirst; j != jLast; j += jInc { 217 if grow <= smlnum { 218 // Exit the loop because the growth factor is too small. 219 goto skipComputeGrow 220 } 221 // G_j = G_{j-1}*( 1 + cnorm[j] ) 222 grow /= 1 + cnorm[j] 223 } 224 } 225 } 226 skipComputeGrow: 227 228 if grow*tscal > smlnum { 229 // The reciprocal of the bound on elements of X is not too small, use 230 // the Level 2 BLAS solve. 231 bi.Dtbsv(uplo, trans, diag, n, kd, ab, ldab, x, 1) 232 // Scale the column norms by 1/tscal for return. 233 if tscal != 1 { 234 bi.Dscal(n, 1/tscal, cnorm, 1) 235 } 236 return 1 237 } 238 239 // Use a Level 1 BLAS solve, scaling intermediate results. 240 241 scale = 1 242 if xMax > bignum { 243 // Scale x so that its components are less than or equal to bignum in 244 // absolute value. 245 scale = bignum / xMax 246 bi.Dscal(n, scale, x, 1) 247 xMax = bignum 248 } 249 250 if noTran { 251 // Solve A * x = b. 252 for j := jFirst; j != jLast; j += jInc { 253 // Compute x[j] = b[j] / A[j,j], scaling x if necessary. 254 xj := math.Abs(x[j]) 255 tjjs := tscal 256 if diag == blas.NonUnit { 257 tjjs *= ab[j*ldab+maind] 258 } 259 tjj := math.Abs(tjjs) 260 switch { 261 case tjj > smlnum: 262 // smlnum < abs(A[j,j]) 263 if tjj < 1 && xj > tjj*bignum { 264 // Scale x by 1/b[j]. 265 rec := 1 / xj 266 bi.Dscal(n, rec, x, 1) 267 scale *= rec 268 xMax *= rec 269 } 270 x[j] /= tjjs 271 xj = math.Abs(x[j]) 272 case tjj > 0: 273 // 0 < abs(A[j,j]) <= smlnum 274 if xj > tjj*bignum { 275 // Scale x by (1/abs(x[j]))*abs(A[j,j])*bignum to avoid 276 // overflow when dividing by A[j,j]. 277 rec := tjj * bignum / xj 278 if cnorm[j] > 1 { 279 // Scale by 1/cnorm[j] to avoid overflow when 280 // multiplying x[j] times column j. 281 rec /= cnorm[j] 282 } 283 bi.Dscal(n, rec, x, 1) 284 scale *= rec 285 xMax *= rec 286 } 287 x[j] /= tjjs 288 xj = math.Abs(x[j]) 289 default: 290 // A[j,j] == 0: Set x[0:n] = 0, x[j] = 1, and scale = 0, and 291 // compute a solution to A*x = 0. 292 for i := range x[:n] { 293 x[i] = 0 294 } 295 x[j] = 1 296 xj = 1 297 scale = 0 298 xMax = 0 299 } 300 301 // Scale x if necessary to avoid overflow when adding a multiple of 302 // column j of A. 303 switch { 304 case xj > 1: 305 rec := 1 / xj 306 if cnorm[j] > (bignum-xMax)*rec { 307 // Scale x by 1/(2*abs(x[j])). 308 rec *= 0.5 309 bi.Dscal(n, rec, x, 1) 310 scale *= rec 311 } 312 case xj*cnorm[j] > bignum-xMax: 313 // Scale x by 1/2. 314 bi.Dscal(n, 0.5, x, 1) 315 scale *= 0.5 316 } 317 318 if uplo == blas.Upper { 319 if j > 0 { 320 // Compute the update 321 // x[max(0,j-kd):j] := x[max(0,j-kd):j] - x[j] * A[max(0,j-kd):j,j] 322 jlen := min(j, kd) 323 if jlen > 0 { 324 bi.Daxpy(jlen, -x[j]*tscal, ab[(j-jlen)*ldab+jlen:], kld, x[j-jlen:], 1) 325 } 326 i := bi.Idamax(j, x, 1) 327 xMax = math.Abs(x[i]) 328 } 329 } else if j < n-1 { 330 // Compute the update 331 // x[j+1:min(j+kd,n)] := x[j+1:min(j+kd,n)] - x[j] * A[j+1:min(j+kd,n),j] 332 jlen := min(kd, n-j-1) 333 if jlen > 0 { 334 bi.Daxpy(jlen, -x[j]*tscal, ab[(j+1)*ldab+kd-1:], kld, x[j+1:], 1) 335 } 336 i := j + 1 + bi.Idamax(n-j-1, x[j+1:], 1) 337 xMax = math.Abs(x[i]) 338 } 339 } 340 } else { 341 // Solve Aᵀ * x = b. 342 for j := jFirst; j != jLast; j += jInc { 343 // Compute x[j] = b[j] - sum A[k,j]*x[k]. 344 // k!=j 345 xj := math.Abs(x[j]) 346 tjjs := tscal 347 if diag == blas.NonUnit { 348 tjjs *= ab[j*ldab+maind] 349 } 350 tjj := math.Abs(tjjs) 351 rec := 1 / math.Max(1, xMax) 352 uscal := tscal 353 if cnorm[j] > (bignum-xj)*rec { 354 // If x[j] could overflow, scale x by 1/(2*xMax). 355 rec *= 0.5 356 if tjj > 1 { 357 // Divide by A[j,j] when scaling x if A[j,j] > 1. 358 rec = math.Min(1, rec*tjj) 359 uscal /= tjjs 360 } 361 if rec < 1 { 362 bi.Dscal(n, rec, x, 1) 363 scale *= rec 364 xMax *= rec 365 } 366 } 367 368 var sumj float64 369 if uscal == 1 { 370 // If the scaling needed for A in the dot product is 1, call 371 // Ddot to perform the dot product... 372 if uplo == blas.Upper { 373 jlen := min(j, kd) 374 if jlen > 0 { 375 sumj = bi.Ddot(jlen, ab[(j-jlen)*ldab+jlen:], kld, x[j-jlen:], 1) 376 } 377 } else { 378 jlen := min(n-j-1, kd) 379 if jlen > 0 { 380 sumj = bi.Ddot(jlen, ab[(j+1)*ldab+kd-1:], kld, x[j+1:], 1) 381 } 382 } 383 } else { 384 // ...otherwise, use in-line code for the dot product. 385 if uplo == blas.Upper { 386 jlen := min(j, kd) 387 for i := 0; i < jlen; i++ { 388 sumj += (ab[(j-jlen+i)*ldab+jlen-i] * uscal) * x[j-jlen+i] 389 } 390 } else { 391 jlen := min(n-j-1, kd) 392 for i := 0; i < jlen; i++ { 393 sumj += (ab[(j+1+i)*ldab+kd-1-i] * uscal) * x[j+i+1] 394 } 395 } 396 } 397 398 if uscal == tscal { 399 // Compute x[j] := ( x[j] - sumj ) / A[j,j] 400 // if 1/A[j,j] was not used to scale the dot product. 401 x[j] -= sumj 402 xj = math.Abs(x[j]) 403 // Compute x[j] = x[j] / A[j,j], scaling if necessary. 404 // Note: the reference implementation skips this step for blas.Unit matrices 405 // when tscal is equal to 1 but it complicates the logic and only saves 406 // the comparison and division in the first switch-case. Not skipping it 407 // is also consistent with the NoTrans case above. 408 switch { 409 case tjj > smlnum: 410 // smlnum < abs(A[j,j]): 411 if tjj < 1 && xj > tjj*bignum { 412 // Scale x by 1/abs(x[j]). 413 rec := 1 / xj 414 bi.Dscal(n, rec, x, 1) 415 scale *= rec 416 xMax *= rec 417 } 418 x[j] /= tjjs 419 case tjj > 0: 420 // 0 < abs(A[j,j]) <= smlnum: 421 if xj > tjj*bignum { 422 // Scale x by (1/abs(x[j]))*abs(A[j,j])*bignum. 423 rec := (tjj * bignum) / xj 424 bi.Dscal(n, rec, x, 1) 425 scale *= rec 426 xMax *= rec 427 } 428 x[j] /= tjjs 429 default: 430 // A[j,j] == 0: Set x[0:n] = 0, x[j] = 1, and scale = 0, and 431 // compute a solution Aᵀ * x = 0. 432 for i := range x[:n] { 433 x[i] = 0 434 } 435 x[j] = 1 436 scale = 0 437 xMax = 0 438 } 439 } else { 440 // Compute x[j] := x[j] / A[j,j] - sumj 441 // if the dot product has already been divided by 1/A[j,j]. 442 x[j] = x[j]/tjjs - sumj 443 } 444 xMax = math.Max(xMax, math.Abs(x[j])) 445 } 446 scale /= tscal 447 } 448 449 // Scale the column norms by 1/tscal for return. 450 if tscal != 1 { 451 bi.Dscal(n, 1/tscal, cnorm, 1) 452 } 453 return scale 454 }