github.com/gopherd/gonum@v0.0.4/lapack/gonum/dlatbs.go (about) 1 // Copyright ©2019 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package gonum 6 7 import ( 8 "math" 9 10 "github.com/gopherd/gonum/blas" 11 "github.com/gopherd/gonum/blas/blas64" 12 ) 13 14 // Dlatbs solves a triangular banded system of equations 15 // A * x = s*b if trans == blas.NoTrans 16 // Aᵀ * x = s*b if trans == blas.Trans or blas.ConjTrans 17 // where A is an upper or lower triangular band matrix, x and b are n-element 18 // vectors, and s is a scaling factor chosen so that the components of x will be 19 // less than the overflow threshold. 20 // 21 // On entry, x contains the right-hand side b of the triangular system. 22 // On return, x is overwritten by the solution vector x. 23 // 24 // normin specifies whether the cnorm parameter contains the column norms of A on 25 // entry. If it is true, cnorm[j] contains the norm of the off-diagonal part of 26 // the j-th column of A. If it is false, the norms will be computed and stored 27 // in cnorm. 28 // 29 // Dlatbs returns the scaling factor s for the triangular system. If the matrix 30 // A is singular (A[j,j]==0 for some j), then scale is set to 0 and a 31 // non-trivial solution to A*x = 0 is returned. 32 // 33 // Dlatbs is an internal routine. It is exported for testing purposes. 34 func (Implementation) Dlatbs(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, normin bool, n, kd int, ab []float64, ldab int, x, cnorm []float64) (scale float64) { 35 noTran := trans == blas.NoTrans 36 switch { 37 case uplo != blas.Upper && uplo != blas.Lower: 38 panic(badUplo) 39 case !noTran && trans != blas.Trans && trans != blas.ConjTrans: 40 panic(badTrans) 41 case diag != blas.NonUnit && diag != blas.Unit: 42 panic(badDiag) 43 case n < 0: 44 panic(nLT0) 45 case kd < 0: 46 panic(kdLT0) 47 case ldab < kd+1: 48 panic(badLdA) 49 } 50 51 // Quick return if possible. 52 if n == 0 { 53 return 0 54 } 55 56 switch { 57 case len(ab) < (n-1)*ldab+kd+1: 58 panic(shortAB) 59 case len(x) < n: 60 panic(shortX) 61 case len(cnorm) < n: 62 panic(shortCNorm) 63 } 64 65 // Parameters to control overflow. 66 smlnum := dlamchS / dlamchP 67 bignum := 1 / smlnum 68 69 bi := blas64.Implementation() 70 kld := max(1, ldab-1) 71 if !normin { 72 // Compute the 1-norm of each column, not including the diagonal. 73 if uplo == blas.Upper { 74 for j := 0; j < n; j++ { 75 jlen := min(j, kd) 76 if jlen > 0 { 77 cnorm[j] = bi.Dasum(jlen, ab[(j-jlen)*ldab+jlen:], kld) 78 } else { 79 cnorm[j] = 0 80 } 81 } 82 } else { 83 for j := 0; j < n; j++ { 84 jlen := min(n-j-1, kd) 85 if jlen > 0 { 86 cnorm[j] = bi.Dasum(jlen, ab[(j+1)*ldab+kd-1:], kld) 87 } else { 88 cnorm[j] = 0 89 } 90 } 91 } 92 } 93 94 // Set up indices and increments for loops below. 95 var ( 96 jFirst, jLast, jInc int 97 maind int 98 ) 99 if noTran { 100 if uplo == blas.Upper { 101 jFirst = n - 1 102 jLast = -1 103 jInc = -1 104 maind = 0 105 } else { 106 jFirst = 0 107 jLast = n 108 jInc = 1 109 maind = kd 110 } 111 } else { 112 if uplo == blas.Upper { 113 jFirst = 0 114 jLast = n 115 jInc = 1 116 maind = 0 117 } else { 118 jFirst = n - 1 119 jLast = -1 120 jInc = -1 121 maind = kd 122 } 123 } 124 125 // Scale the column norms by tscal if the maximum element in cnorm is 126 // greater than bignum. 127 tmax := cnorm[bi.Idamax(n, cnorm, 1)] 128 tscal := 1.0 129 if tmax > bignum { 130 tscal = 1 / (smlnum * tmax) 131 bi.Dscal(n, tscal, cnorm, 1) 132 } 133 134 // Compute a bound on the computed solution vector to see if the Level 2 135 // BLAS routine Dtbsv can be used. 136 137 xMax := math.Abs(x[bi.Idamax(n, x, 1)]) 138 xBnd := xMax 139 grow := 0.0 140 // Compute the growth only if the maximum element in cnorm is NOT greater 141 // than bignum. 142 if tscal != 1 { 143 goto skipComputeGrow 144 } 145 if noTran { 146 // Compute the growth in A * x = b. 147 if diag == blas.NonUnit { 148 // A is non-unit triangular. 149 // 150 // Compute grow = 1/G_j and xBnd = 1/M_j. 151 // Initially, G_0 = max{x(i), i=1,...,n}. 152 grow = 1 / math.Max(xBnd, smlnum) 153 xBnd = grow 154 for j := jFirst; j != jLast; j += jInc { 155 if grow <= smlnum { 156 // Exit the loop because the growth factor is too small. 157 goto skipComputeGrow 158 } 159 // M_j = G_{j-1} / abs(A[j,j]) 160 tjj := math.Abs(ab[j*ldab+maind]) 161 xBnd = math.Min(xBnd, math.Min(1, tjj)*grow) 162 if tjj+cnorm[j] >= smlnum { 163 // G_j = G_{j-1}*( 1 + cnorm[j] / abs(A[j,j]) ) 164 grow *= tjj / (tjj + cnorm[j]) 165 } else { 166 // G_j could overflow, set grow to 0. 167 grow = 0 168 } 169 } 170 grow = xBnd 171 } else { 172 // A is unit triangular. 173 // 174 // Compute grow = 1/G_j, where G_0 = max{x(i), i=1,...,n}. 175 grow = math.Min(1, 1/math.Max(xBnd, smlnum)) 176 for j := jFirst; j != jLast; j += jInc { 177 if grow <= smlnum { 178 // Exit the loop because the growth factor is too small. 179 goto skipComputeGrow 180 } 181 // G_j = G_{j-1}*( 1 + cnorm[j] ) 182 grow /= 1 + cnorm[j] 183 } 184 } 185 } else { 186 // Compute the growth in Aᵀ * x = b. 187 if diag == blas.NonUnit { 188 // A is non-unit triangular. 189 // 190 // Compute grow = 1/G_j and xBnd = 1/M_j. 191 // Initially, G_0 = max{x(i), i=1,...,n}. 192 grow = 1 / math.Max(xBnd, smlnum) 193 xBnd = grow 194 for j := jFirst; j != jLast; j += jInc { 195 if grow <= smlnum { 196 // Exit the loop because the growth factor is too small. 197 goto skipComputeGrow 198 } 199 // G_j = max( G_{j-1}, M_{j-1}*( 1 + cnorm[j] ) ) 200 xj := 1 + cnorm[j] 201 grow = math.Min(grow, xBnd/xj) 202 // M_j = M_{j-1}*( 1 + cnorm[j] ) / abs(A[j,j]) 203 tjj := math.Abs(ab[j*ldab+maind]) 204 if xj > tjj { 205 xBnd *= tjj / xj 206 } 207 } 208 grow = math.Min(grow, xBnd) 209 } else { 210 // A is unit triangular. 211 // 212 // Compute grow = 1/G_j, where G_0 = max{x(i), i=1,...,n}. 213 grow = math.Min(1, 1/math.Max(xBnd, smlnum)) 214 for j := jFirst; j != jLast; j += jInc { 215 if grow <= smlnum { 216 // Exit the loop because the growth factor is too small. 217 goto skipComputeGrow 218 } 219 // G_j = G_{j-1}*( 1 + cnorm[j] ) 220 grow /= 1 + cnorm[j] 221 } 222 } 223 } 224 skipComputeGrow: 225 226 if grow*tscal > smlnum { 227 // The reciprocal of the bound on elements of X is not too small, use 228 // the Level 2 BLAS solve. 229 bi.Dtbsv(uplo, trans, diag, n, kd, ab, ldab, x, 1) 230 // Scale the column norms by 1/tscal for return. 231 if tscal != 1 { 232 bi.Dscal(n, 1/tscal, cnorm, 1) 233 } 234 return 1 235 } 236 237 // Use a Level 1 BLAS solve, scaling intermediate results. 238 239 scale = 1 240 if xMax > bignum { 241 // Scale x so that its components are less than or equal to bignum in 242 // absolute value. 243 scale = bignum / xMax 244 bi.Dscal(n, scale, x, 1) 245 xMax = bignum 246 } 247 248 if noTran { 249 // Solve A * x = b. 250 for j := jFirst; j != jLast; j += jInc { 251 // Compute x[j] = b[j] / A[j,j], scaling x if necessary. 252 xj := math.Abs(x[j]) 253 tjjs := tscal 254 if diag == blas.NonUnit { 255 tjjs *= ab[j*ldab+maind] 256 } 257 tjj := math.Abs(tjjs) 258 switch { 259 case tjj > smlnum: 260 // smlnum < abs(A[j,j]) 261 if tjj < 1 && xj > tjj*bignum { 262 // Scale x by 1/b[j]. 263 rec := 1 / xj 264 bi.Dscal(n, rec, x, 1) 265 scale *= rec 266 xMax *= rec 267 } 268 x[j] /= tjjs 269 xj = math.Abs(x[j]) 270 case tjj > 0: 271 // 0 < abs(A[j,j]) <= smlnum 272 if xj > tjj*bignum { 273 // Scale x by (1/abs(x[j]))*abs(A[j,j])*bignum to avoid 274 // overflow when dividing by A[j,j]. 275 rec := tjj * bignum / xj 276 if cnorm[j] > 1 { 277 // Scale by 1/cnorm[j] to avoid overflow when 278 // multiplying x[j] times column j. 279 rec /= cnorm[j] 280 } 281 bi.Dscal(n, rec, x, 1) 282 scale *= rec 283 xMax *= rec 284 } 285 x[j] /= tjjs 286 xj = math.Abs(x[j]) 287 default: 288 // A[j,j] == 0: Set x[0:n] = 0, x[j] = 1, and scale = 0, and 289 // compute a solution to A*x = 0. 290 for i := range x[:n] { 291 x[i] = 0 292 } 293 x[j] = 1 294 xj = 1 295 scale = 0 296 xMax = 0 297 } 298 299 // Scale x if necessary to avoid overflow when adding a multiple of 300 // column j of A. 301 switch { 302 case xj > 1: 303 rec := 1 / xj 304 if cnorm[j] > (bignum-xMax)*rec { 305 // Scale x by 1/(2*abs(x[j])). 306 rec *= 0.5 307 bi.Dscal(n, rec, x, 1) 308 scale *= rec 309 } 310 case xj*cnorm[j] > bignum-xMax: 311 // Scale x by 1/2. 312 bi.Dscal(n, 0.5, x, 1) 313 scale *= 0.5 314 } 315 316 if uplo == blas.Upper { 317 if j > 0 { 318 // Compute the update 319 // x[max(0,j-kd):j] := x[max(0,j-kd):j] - x[j] * A[max(0,j-kd):j,j] 320 jlen := min(j, kd) 321 if jlen > 0 { 322 bi.Daxpy(jlen, -x[j]*tscal, ab[(j-jlen)*ldab+jlen:], kld, x[j-jlen:], 1) 323 } 324 i := bi.Idamax(j, x, 1) 325 xMax = math.Abs(x[i]) 326 } 327 } else if j < n-1 { 328 // Compute the update 329 // x[j+1:min(j+kd,n)] := x[j+1:min(j+kd,n)] - x[j] * A[j+1:min(j+kd,n),j] 330 jlen := min(kd, n-j-1) 331 if jlen > 0 { 332 bi.Daxpy(jlen, -x[j]*tscal, ab[(j+1)*ldab+kd-1:], kld, x[j+1:], 1) 333 } 334 i := j + 1 + bi.Idamax(n-j-1, x[j+1:], 1) 335 xMax = math.Abs(x[i]) 336 } 337 } 338 } else { 339 // Solve Aᵀ * x = b. 340 for j := jFirst; j != jLast; j += jInc { 341 // Compute x[j] = b[j] - sum A[k,j]*x[k]. 342 // k!=j 343 xj := math.Abs(x[j]) 344 tjjs := tscal 345 if diag == blas.NonUnit { 346 tjjs *= ab[j*ldab+maind] 347 } 348 tjj := math.Abs(tjjs) 349 rec := 1 / math.Max(1, xMax) 350 uscal := tscal 351 if cnorm[j] > (bignum-xj)*rec { 352 // If x[j] could overflow, scale x by 1/(2*xMax). 353 rec *= 0.5 354 if tjj > 1 { 355 // Divide by A[j,j] when scaling x if A[j,j] > 1. 356 rec = math.Min(1, rec*tjj) 357 uscal /= tjjs 358 } 359 if rec < 1 { 360 bi.Dscal(n, rec, x, 1) 361 scale *= rec 362 xMax *= rec 363 } 364 } 365 366 var sumj float64 367 if uscal == 1 { 368 // If the scaling needed for A in the dot product is 1, call 369 // Ddot to perform the dot product... 370 if uplo == blas.Upper { 371 jlen := min(j, kd) 372 if jlen > 0 { 373 sumj = bi.Ddot(jlen, ab[(j-jlen)*ldab+jlen:], kld, x[j-jlen:], 1) 374 } 375 } else { 376 jlen := min(n-j-1, kd) 377 if jlen > 0 { 378 sumj = bi.Ddot(jlen, ab[(j+1)*ldab+kd-1:], kld, x[j+1:], 1) 379 } 380 } 381 } else { 382 // ...otherwise, use in-line code for the dot product. 383 if uplo == blas.Upper { 384 jlen := min(j, kd) 385 for i := 0; i < jlen; i++ { 386 sumj += (ab[(j-jlen+i)*ldab+jlen-i] * uscal) * x[j-jlen+i] 387 } 388 } else { 389 jlen := min(n-j-1, kd) 390 for i := 0; i < jlen; i++ { 391 sumj += (ab[(j+1+i)*ldab+kd-1-i] * uscal) * x[j+i+1] 392 } 393 } 394 } 395 396 if uscal == tscal { 397 // Compute x[j] := ( x[j] - sumj ) / A[j,j] 398 // if 1/A[j,j] was not used to scale the dot product. 399 x[j] -= sumj 400 xj = math.Abs(x[j]) 401 // Compute x[j] = x[j] / A[j,j], scaling if necessary. 402 // Note: the reference implementation skips this step for blas.Unit matrices 403 // when tscal is equal to 1 but it complicates the logic and only saves 404 // the comparison and division in the first switch-case. Not skipping it 405 // is also consistent with the NoTrans case above. 406 switch { 407 case tjj > smlnum: 408 // smlnum < abs(A[j,j]): 409 if tjj < 1 && xj > tjj*bignum { 410 // Scale x by 1/abs(x[j]). 411 rec := 1 / xj 412 bi.Dscal(n, rec, x, 1) 413 scale *= rec 414 xMax *= rec 415 } 416 x[j] /= tjjs 417 case tjj > 0: 418 // 0 < abs(A[j,j]) <= smlnum: 419 if xj > tjj*bignum { 420 // Scale x by (1/abs(x[j]))*abs(A[j,j])*bignum. 421 rec := (tjj * bignum) / xj 422 bi.Dscal(n, rec, x, 1) 423 scale *= rec 424 xMax *= rec 425 } 426 x[j] /= tjjs 427 default: 428 // A[j,j] == 0: Set x[0:n] = 0, x[j] = 1, and scale = 0, and 429 // compute a solution Aᵀ * x = 0. 430 for i := range x[:n] { 431 x[i] = 0 432 } 433 x[j] = 1 434 scale = 0 435 xMax = 0 436 } 437 } else { 438 // Compute x[j] := x[j] / A[j,j] - sumj 439 // if the dot product has already been divided by 1/A[j,j]. 440 x[j] = x[j]/tjjs - sumj 441 } 442 xMax = math.Max(xMax, math.Abs(x[j])) 443 } 444 scale /= tscal 445 } 446 447 // Scale the column norms by 1/tscal for return. 448 if tscal != 1 { 449 bi.Dscal(n, 1/tscal, cnorm, 1) 450 } 451 return scale 452 }