gonum.org/v1/gonum@v0.15.1-0.20240517103525-f853624cb1bb/blas/blas32/blas32.go (about) 1 // Copyright ©2015 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package blas32 6 7 import ( 8 "gonum.org/v1/gonum/blas" 9 "gonum.org/v1/gonum/blas/gonum" 10 ) 11 12 var blas32 blas.Float32 = gonum.Implementation{} 13 14 // Use sets the BLAS float32 implementation to be used by subsequent BLAS calls. 15 // The default implementation is 16 // gonum.org/v1/gonum/blas/gonum.Implementation. 17 func Use(b blas.Float32) { 18 blas32 = b 19 } 20 21 // Implementation returns the current BLAS float32 implementation. 22 // 23 // Implementation allows direct calls to the current the BLAS float32 implementation 24 // giving finer control of parameters. 25 func Implementation() blas.Float32 { 26 return blas32 27 } 28 29 // Vector represents a vector with an associated element increment. 30 type Vector struct { 31 N int 32 Inc int 33 Data []float32 34 } 35 36 // General represents a matrix using the conventional storage scheme. 37 type General struct { 38 Rows, Cols int 39 Stride int 40 Data []float32 41 } 42 43 // Band represents a band matrix using the band storage scheme. 44 type Band struct { 45 Rows, Cols int 46 KL, KU int 47 Stride int 48 Data []float32 49 } 50 51 // Triangular represents a triangular matrix using the conventional storage scheme. 52 type Triangular struct { 53 N int 54 Stride int 55 Data []float32 56 Uplo blas.Uplo 57 Diag blas.Diag 58 } 59 60 // TriangularBand represents a triangular matrix using the band storage scheme. 61 type TriangularBand struct { 62 N, K int 63 Stride int 64 Data []float32 65 Uplo blas.Uplo 66 Diag blas.Diag 67 } 68 69 // TriangularPacked represents a triangular matrix using the packed storage scheme. 70 type TriangularPacked struct { 71 N int 72 Data []float32 73 Uplo blas.Uplo 74 Diag blas.Diag 75 } 76 77 // Symmetric represents a symmetric matrix using the conventional storage scheme. 78 type Symmetric struct { 79 N int 80 Stride int 81 Data []float32 82 Uplo blas.Uplo 83 } 84 85 // SymmetricBand represents a symmetric matrix using the band storage scheme. 86 type SymmetricBand struct { 87 N, K int 88 Stride int 89 Data []float32 90 Uplo blas.Uplo 91 } 92 93 // SymmetricPacked represents a symmetric matrix using the packed storage scheme. 94 type SymmetricPacked struct { 95 N int 96 Data []float32 97 Uplo blas.Uplo 98 } 99 100 // Level 1 101 102 const ( 103 negInc = "blas32: negative vector increment" 104 badLength = "blas32: vector length mismatch" 105 ) 106 107 // Dot computes the dot product of the two vectors: 108 // 109 // \sum_i x[i]*y[i]. 110 // 111 // Dot will panic if the lengths of x and y do not match. 112 func Dot(x, y Vector) float32 { 113 if x.N != y.N { 114 panic(badLength) 115 } 116 return blas32.Sdot(x.N, x.Data, x.Inc, y.Data, y.Inc) 117 } 118 119 // DDot computes the dot product of the two vectors: 120 // 121 // \sum_i x[i]*y[i]. 122 // 123 // DDot will panic if the lengths of x and y do not match. 124 func DDot(x, y Vector) float64 { 125 if x.N != y.N { 126 panic(badLength) 127 } 128 return blas32.Dsdot(x.N, x.Data, x.Inc, y.Data, y.Inc) 129 } 130 131 // SDDot computes the dot product of the two vectors adding a constant: 132 // 133 // alpha + \sum_i x[i]*y[i]. 134 // 135 // SDDot will panic if the lengths of x and y do not match. 136 func SDDot(alpha float32, x, y Vector) float32 { 137 if x.N != y.N { 138 panic(badLength) 139 } 140 return blas32.Sdsdot(x.N, alpha, x.Data, x.Inc, y.Data, y.Inc) 141 } 142 143 // Nrm2 computes the Euclidean norm of the vector x: 144 // 145 // sqrt(\sum_i x[i]*x[i]). 146 // 147 // Nrm2 will panic if the vector increment is negative. 148 func Nrm2(x Vector) float32 { 149 if x.Inc < 0 { 150 panic(negInc) 151 } 152 return blas32.Snrm2(x.N, x.Data, x.Inc) 153 } 154 155 // Asum computes the sum of the absolute values of the elements of x: 156 // 157 // \sum_i |x[i]|. 158 // 159 // Asum will panic if the vector increment is negative. 160 func Asum(x Vector) float32 { 161 if x.Inc < 0 { 162 panic(negInc) 163 } 164 return blas32.Sasum(x.N, x.Data, x.Inc) 165 } 166 167 // Iamax returns the index of an element of x with the largest absolute value. 168 // If there are multiple such indices the earliest is returned. 169 // Iamax returns -1 if n == 0. 170 // 171 // Iamax will panic if the vector increment is negative. 172 func Iamax(x Vector) int { 173 if x.Inc < 0 { 174 panic(negInc) 175 } 176 return blas32.Isamax(x.N, x.Data, x.Inc) 177 } 178 179 // Swap exchanges the elements of the two vectors: 180 // 181 // x[i], y[i] = y[i], x[i] for all i. 182 // 183 // Swap will panic if the lengths of x and y do not match. 184 func Swap(x, y Vector) { 185 if x.N != y.N { 186 panic(badLength) 187 } 188 blas32.Sswap(x.N, x.Data, x.Inc, y.Data, y.Inc) 189 } 190 191 // Copy copies the elements of x into the elements of y: 192 // 193 // y[i] = x[i] for all i. 194 // 195 // Copy will panic if the lengths of x and y do not match. 196 func Copy(x, y Vector) { 197 if x.N != y.N { 198 panic(badLength) 199 } 200 blas32.Scopy(x.N, x.Data, x.Inc, y.Data, y.Inc) 201 } 202 203 // Axpy adds x scaled by alpha to y: 204 // 205 // y[i] += alpha*x[i] for all i. 206 // 207 // Axpy will panic if the lengths of x and y do not match. 208 func Axpy(alpha float32, x, y Vector) { 209 if x.N != y.N { 210 panic(badLength) 211 } 212 blas32.Saxpy(x.N, alpha, x.Data, x.Inc, y.Data, y.Inc) 213 } 214 215 // Rotg computes the parameters of a Givens plane rotation so that 216 // 217 // ⎡ c s⎤ ⎡a⎤ ⎡r⎤ 218 // ⎣-s c⎦ * ⎣b⎦ = ⎣0⎦ 219 // 220 // where a and b are the Cartesian coordinates of a given point. 221 // c, s, and r are defined as 222 // 223 // r = ±Sqrt(a^2 + b^2), 224 // c = a/r, the cosine of the rotation angle, 225 // s = a/r, the sine of the rotation angle, 226 // 227 // and z is defined such that 228 // 229 // if |a| > |b|, z = s, 230 // otherwise if c != 0, z = 1/c, 231 // otherwise z = 1. 232 func Rotg(a, b float32) (c, s, r, z float32) { 233 return blas32.Srotg(a, b) 234 } 235 236 // Rotmg computes the modified Givens rotation. See 237 // http://www.netlib.org/lapack/explore-html/df/deb/drotmg_8f.html 238 // for more details. 239 func Rotmg(d1, d2, b1, b2 float32) (p blas.SrotmParams, rd1, rd2, rb1 float32) { 240 return blas32.Srotmg(d1, d2, b1, b2) 241 } 242 243 // Rot applies a plane transformation to n points represented by the vectors x 244 // and y: 245 // 246 // x[i] = c*x[i] + s*y[i], 247 // y[i] = -s*x[i] + c*y[i], for all i. 248 func Rot(n int, x, y Vector, c, s float32) { 249 blas32.Srot(n, x.Data, x.Inc, y.Data, y.Inc, c, s) 250 } 251 252 // Rotm applies the modified Givens rotation to n points represented by the 253 // vectors x and y. 254 func Rotm(n int, x, y Vector, p blas.SrotmParams) { 255 blas32.Srotm(n, x.Data, x.Inc, y.Data, y.Inc, p) 256 } 257 258 // Scal scales the vector x by alpha: 259 // 260 // x[i] *= alpha for all i. 261 // 262 // Scal will panic if the vector increment is negative. 263 func Scal(alpha float32, x Vector) { 264 if x.Inc < 0 { 265 panic(negInc) 266 } 267 blas32.Sscal(x.N, alpha, x.Data, x.Inc) 268 } 269 270 // Level 2 271 272 // Gemv computes 273 // 274 // y = alpha * A * x + beta * y if t == blas.NoTrans, 275 // y = alpha * Aᵀ * x + beta * y if t == blas.Trans or blas.ConjTrans, 276 // 277 // where A is an m×n dense matrix, x and y are vectors, and alpha and beta are scalars. 278 func Gemv(t blas.Transpose, alpha float32, a General, x Vector, beta float32, y Vector) { 279 blas32.Sgemv(t, a.Rows, a.Cols, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc) 280 } 281 282 // Gbmv computes 283 // 284 // y = alpha * A * x + beta * y if t == blas.NoTrans, 285 // y = alpha * Aᵀ * x + beta * y if t == blas.Trans or blas.ConjTrans, 286 // 287 // where A is an m×n band matrix, x and y are vectors, and alpha and beta are scalars. 288 func Gbmv(t blas.Transpose, alpha float32, a Band, x Vector, beta float32, y Vector) { 289 blas32.Sgbmv(t, a.Rows, a.Cols, a.KL, a.KU, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc) 290 } 291 292 // Trmv computes 293 // 294 // x = A * x if t == blas.NoTrans, 295 // x = Aᵀ * x if t == blas.Trans or blas.ConjTrans, 296 // 297 // where A is an n×n triangular matrix, and x is a vector. 298 func Trmv(t blas.Transpose, a Triangular, x Vector) { 299 blas32.Strmv(a.Uplo, t, a.Diag, a.N, a.Data, a.Stride, x.Data, x.Inc) 300 } 301 302 // Tbmv computes 303 // 304 // x = A * x if t == blas.NoTrans, 305 // x = Aᵀ * x if t == blas.Trans or blas.ConjTrans, 306 // 307 // where A is an n×n triangular band matrix, and x is a vector. 308 func Tbmv(t blas.Transpose, a TriangularBand, x Vector) { 309 blas32.Stbmv(a.Uplo, t, a.Diag, a.N, a.K, a.Data, a.Stride, x.Data, x.Inc) 310 } 311 312 // Tpmv computes 313 // 314 // x = A * x if t == blas.NoTrans, 315 // x = Aᵀ * x if t == blas.Trans or blas.ConjTrans, 316 // 317 // where A is an n×n triangular matrix in packed format, and x is a vector. 318 func Tpmv(t blas.Transpose, a TriangularPacked, x Vector) { 319 blas32.Stpmv(a.Uplo, t, a.Diag, a.N, a.Data, x.Data, x.Inc) 320 } 321 322 // Trsv solves 323 // 324 // A * x = b if t == blas.NoTrans, 325 // Aᵀ * x = b if t == blas.Trans or blas.ConjTrans, 326 // 327 // where A is an n×n triangular matrix, and x and b are vectors. 328 // 329 // At entry to the function, x contains the values of b, and the result is 330 // stored in-place into x. 331 // 332 // No test for singularity or near-singularity is included in this 333 // routine. Such tests must be performed before calling this routine. 334 func Trsv(t blas.Transpose, a Triangular, x Vector) { 335 blas32.Strsv(a.Uplo, t, a.Diag, a.N, a.Data, a.Stride, x.Data, x.Inc) 336 } 337 338 // Tbsv solves 339 // 340 // A * x = b if t == blas.NoTrans, 341 // Aᵀ * x = b if t == blas.Trans or blas.ConjTrans, 342 // 343 // where A is an n×n triangular band matrix, and x and b are vectors. 344 // 345 // At entry to the function, x contains the values of b, and the result is 346 // stored in place into x. 347 // 348 // No test for singularity or near-singularity is included in this 349 // routine. Such tests must be performed before calling this routine. 350 func Tbsv(t blas.Transpose, a TriangularBand, x Vector) { 351 blas32.Stbsv(a.Uplo, t, a.Diag, a.N, a.K, a.Data, a.Stride, x.Data, x.Inc) 352 } 353 354 // Tpsv solves 355 // 356 // A * x = b if t == blas.NoTrans, 357 // Aᵀ * x = b if t == blas.Trans or blas.ConjTrans, 358 // 359 // where A is an n×n triangular matrix in packed format, and x and b are 360 // vectors. 361 // 362 // At entry to the function, x contains the values of b, and the result is 363 // stored in place into x. 364 // 365 // No test for singularity or near-singularity is included in this 366 // routine. Such tests must be performed before calling this routine. 367 func Tpsv(t blas.Transpose, a TriangularPacked, x Vector) { 368 blas32.Stpsv(a.Uplo, t, a.Diag, a.N, a.Data, x.Data, x.Inc) 369 } 370 371 // Symv computes 372 // 373 // y = alpha * A * x + beta * y, 374 // 375 // where A is an n×n symmetric matrix, x and y are vectors, and alpha and 376 // beta are scalars. 377 func Symv(alpha float32, a Symmetric, x Vector, beta float32, y Vector) { 378 blas32.Ssymv(a.Uplo, a.N, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc) 379 } 380 381 // Sbmv performs 382 // 383 // y = alpha * A * x + beta * y, 384 // 385 // where A is an n×n symmetric band matrix, x and y are vectors, and alpha 386 // and beta are scalars. 387 func Sbmv(alpha float32, a SymmetricBand, x Vector, beta float32, y Vector) { 388 blas32.Ssbmv(a.Uplo, a.N, a.K, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc) 389 } 390 391 // Spmv performs 392 // 393 // y = alpha * A * x + beta * y, 394 // 395 // where A is an n×n symmetric matrix in packed format, x and y are vectors, 396 // and alpha and beta are scalars. 397 func Spmv(alpha float32, a SymmetricPacked, x Vector, beta float32, y Vector) { 398 blas32.Sspmv(a.Uplo, a.N, alpha, a.Data, x.Data, x.Inc, beta, y.Data, y.Inc) 399 } 400 401 // Ger performs a rank-1 update 402 // 403 // A += alpha * x * yᵀ, 404 // 405 // where A is an m×n dense matrix, x and y are vectors, and alpha is a scalar. 406 func Ger(alpha float32, x, y Vector, a General) { 407 blas32.Sger(a.Rows, a.Cols, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data, a.Stride) 408 } 409 410 // Syr performs a rank-1 update 411 // 412 // A += alpha * x * xᵀ, 413 // 414 // where A is an n×n symmetric matrix, x is a vector, and alpha is a scalar. 415 func Syr(alpha float32, x Vector, a Symmetric) { 416 blas32.Ssyr(a.Uplo, a.N, alpha, x.Data, x.Inc, a.Data, a.Stride) 417 } 418 419 // Spr performs the rank-1 update 420 // 421 // A += alpha * x * xᵀ, 422 // 423 // where A is an n×n symmetric matrix in packed format, x is a vector, and 424 // alpha is a scalar. 425 func Spr(alpha float32, x Vector, a SymmetricPacked) { 426 blas32.Sspr(a.Uplo, a.N, alpha, x.Data, x.Inc, a.Data) 427 } 428 429 // Syr2 performs a rank-2 update 430 // 431 // A += alpha * x * yᵀ + alpha * y * xᵀ, 432 // 433 // where A is a symmetric n×n matrix, x and y are vectors, and alpha is a scalar. 434 func Syr2(alpha float32, x, y Vector, a Symmetric) { 435 blas32.Ssyr2(a.Uplo, a.N, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data, a.Stride) 436 } 437 438 // Spr2 performs a rank-2 update 439 // 440 // A += alpha * x * yᵀ + alpha * y * xᵀ, 441 // 442 // where A is an n×n symmetric matrix in packed format, x and y are vectors, 443 // and alpha is a scalar. 444 func Spr2(alpha float32, x, y Vector, a SymmetricPacked) { 445 blas32.Sspr2(a.Uplo, a.N, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data) 446 } 447 448 // Level 3 449 450 // Gemm computes 451 // 452 // C = alpha * A * B + beta * C, 453 // 454 // where A, B, and C are dense matrices, and alpha and beta are scalars. 455 // tA and tB specify whether A or B are transposed. 456 func Gemm(tA, tB blas.Transpose, alpha float32, a, b General, beta float32, c General) { 457 var m, n, k int 458 if tA == blas.NoTrans { 459 m, k = a.Rows, a.Cols 460 } else { 461 m, k = a.Cols, a.Rows 462 } 463 if tB == blas.NoTrans { 464 n = b.Cols 465 } else { 466 n = b.Rows 467 } 468 blas32.Sgemm(tA, tB, m, n, k, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride) 469 } 470 471 // Symm performs 472 // 473 // C = alpha * A * B + beta * C if s == blas.Left, 474 // C = alpha * B * A + beta * C if s == blas.Right, 475 // 476 // where A is an n×n or m×m symmetric matrix, B and C are m×n matrices, and 477 // alpha is a scalar. 478 func Symm(s blas.Side, alpha float32, a Symmetric, b General, beta float32, c General) { 479 var m, n int 480 if s == blas.Left { 481 m, n = a.N, b.Cols 482 } else { 483 m, n = b.Rows, a.N 484 } 485 blas32.Ssymm(s, a.Uplo, m, n, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride) 486 } 487 488 // Syrk performs a symmetric rank-k update 489 // 490 // C = alpha * A * Aᵀ + beta * C if t == blas.NoTrans, 491 // C = alpha * Aᵀ * A + beta * C if t == blas.Trans or blas.ConjTrans, 492 // 493 // where C is an n×n symmetric matrix, A is an n×k matrix if t == blas.NoTrans and 494 // a k×n matrix otherwise, and alpha and beta are scalars. 495 func Syrk(t blas.Transpose, alpha float32, a General, beta float32, c Symmetric) { 496 var n, k int 497 if t == blas.NoTrans { 498 n, k = a.Rows, a.Cols 499 } else { 500 n, k = a.Cols, a.Rows 501 } 502 blas32.Ssyrk(c.Uplo, t, n, k, alpha, a.Data, a.Stride, beta, c.Data, c.Stride) 503 } 504 505 // Syr2k performs a symmetric rank-2k update 506 // 507 // C = alpha * A * Bᵀ + alpha * B * Aᵀ + beta * C if t == blas.NoTrans, 508 // C = alpha * Aᵀ * B + alpha * Bᵀ * A + beta * C if t == blas.Trans or blas.ConjTrans, 509 // 510 // where C is an n×n symmetric matrix, A and B are n×k matrices if t == NoTrans 511 // and k×n matrices otherwise, and alpha and beta are scalars. 512 func Syr2k(t blas.Transpose, alpha float32, a, b General, beta float32, c Symmetric) { 513 var n, k int 514 if t == blas.NoTrans { 515 n, k = a.Rows, a.Cols 516 } else { 517 n, k = a.Cols, a.Rows 518 } 519 blas32.Ssyr2k(c.Uplo, t, n, k, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride) 520 } 521 522 // Trmm performs 523 // 524 // B = alpha * A * B if tA == blas.NoTrans and s == blas.Left, 525 // B = alpha * Aᵀ * B if tA == blas.Trans or blas.ConjTrans, and s == blas.Left, 526 // B = alpha * B * A if tA == blas.NoTrans and s == blas.Right, 527 // B = alpha * B * Aᵀ if tA == blas.Trans or blas.ConjTrans, and s == blas.Right, 528 // 529 // where A is an n×n or m×m triangular matrix, B is an m×n matrix, and alpha is 530 // a scalar. 531 func Trmm(s blas.Side, tA blas.Transpose, alpha float32, a Triangular, b General) { 532 blas32.Strmm(s, a.Uplo, tA, a.Diag, b.Rows, b.Cols, alpha, a.Data, a.Stride, b.Data, b.Stride) 533 } 534 535 // Trsm solves 536 // 537 // A * X = alpha * B if tA == blas.NoTrans and s == blas.Left, 538 // Aᵀ * X = alpha * B if tA == blas.Trans or blas.ConjTrans, and s == blas.Left, 539 // X * A = alpha * B if tA == blas.NoTrans and s == blas.Right, 540 // X * Aᵀ = alpha * B if tA == blas.Trans or blas.ConjTrans, and s == blas.Right, 541 // 542 // where A is an n×n or m×m triangular matrix, X and B are m×n matrices, and 543 // alpha is a scalar. 544 // 545 // At entry to the function, X contains the values of B, and the result is 546 // stored in-place into X. 547 // 548 // No check is made that A is invertible. 549 func Trsm(s blas.Side, tA blas.Transpose, alpha float32, a Triangular, b General) { 550 blas32.Strsm(s, a.Uplo, tA, a.Diag, b.Rows, b.Cols, alpha, a.Data, a.Stride, b.Data, b.Stride) 551 }