github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/blas/blas32/blas32.go (about) 1 // Copyright ©2015 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package blas32 6 7 import ( 8 "github.com/jingcheng-WU/gonum/blas" 9 "github.com/jingcheng-WU/gonum/blas/gonum" 10 ) 11 12 var blas32 blas.Float32 = gonum.Implementation{} 13 14 // Use sets the BLAS float32 implementation to be used by subsequent BLAS calls. 15 // The default implementation is 16 // github.com/jingcheng-WU/gonum/blas/gonum.Implementation. 17 func Use(b blas.Float32) { 18 blas32 = b 19 } 20 21 // Implementation returns the current BLAS float32 implementation. 22 // 23 // Implementation allows direct calls to the current the BLAS float32 implementation 24 // giving finer control of parameters. 25 func Implementation() blas.Float32 { 26 return blas32 27 } 28 29 // Vector represents a vector with an associated element increment. 30 type Vector struct { 31 N int 32 Inc int 33 Data []float32 34 } 35 36 // General represents a matrix using the conventional storage scheme. 37 type General struct { 38 Rows, Cols int 39 Stride int 40 Data []float32 41 } 42 43 // Band represents a band matrix using the band storage scheme. 44 type Band struct { 45 Rows, Cols int 46 KL, KU int 47 Stride int 48 Data []float32 49 } 50 51 // Triangular represents a triangular matrix using the conventional storage scheme. 52 type Triangular struct { 53 N int 54 Stride int 55 Data []float32 56 Uplo blas.Uplo 57 Diag blas.Diag 58 } 59 60 // TriangularBand represents a triangular matrix using the band storage scheme. 61 type TriangularBand struct { 62 N, K int 63 Stride int 64 Data []float32 65 Uplo blas.Uplo 66 Diag blas.Diag 67 } 68 69 // TriangularPacked represents a triangular matrix using the packed storage scheme. 70 type TriangularPacked struct { 71 N int 72 Data []float32 73 Uplo blas.Uplo 74 Diag blas.Diag 75 } 76 77 // Symmetric represents a symmetric matrix using the conventional storage scheme. 78 type Symmetric struct { 79 N int 80 Stride int 81 Data []float32 82 Uplo blas.Uplo 83 } 84 85 // SymmetricBand represents a symmetric matrix using the band storage scheme. 86 type SymmetricBand struct { 87 N, K int 88 Stride int 89 Data []float32 90 Uplo blas.Uplo 91 } 92 93 // SymmetricPacked represents a symmetric matrix using the packed storage scheme. 94 type SymmetricPacked struct { 95 N int 96 Data []float32 97 Uplo blas.Uplo 98 } 99 100 // Level 1 101 102 const ( 103 negInc = "blas32: negative vector increment" 104 badLength = "blas32: vector length mismatch" 105 ) 106 107 // Dot computes the dot product of the two vectors: 108 // \sum_i x[i]*y[i]. 109 // Dot will panic if the lengths of x and y do not match. 110 func Dot(x, y Vector) float32 { 111 if x.N != y.N { 112 panic(badLength) 113 } 114 return blas32.Sdot(x.N, x.Data, x.Inc, y.Data, y.Inc) 115 } 116 117 // DDot computes the dot product of the two vectors: 118 // \sum_i x[i]*y[i]. 119 // DDot will panic if the lengths of x and y do not match. 120 func DDot(x, y Vector) float64 { 121 if x.N != y.N { 122 panic(badLength) 123 } 124 return blas32.Dsdot(x.N, x.Data, x.Inc, y.Data, y.Inc) 125 } 126 127 // SDDot computes the dot product of the two vectors adding a constant: 128 // alpha + \sum_i x[i]*y[i]. 129 // SDDot will panic if the lengths of x and y do not match. 130 func SDDot(alpha float32, x, y Vector) float32 { 131 if x.N != y.N { 132 panic(badLength) 133 } 134 return blas32.Sdsdot(x.N, alpha, x.Data, x.Inc, y.Data, y.Inc) 135 } 136 137 // Nrm2 computes the Euclidean norm of the vector x: 138 // sqrt(\sum_i x[i]*x[i]). 139 // 140 // Nrm2 will panic if the vector increment is negative. 141 func Nrm2(x Vector) float32 { 142 if x.Inc < 0 { 143 panic(negInc) 144 } 145 return blas32.Snrm2(x.N, x.Data, x.Inc) 146 } 147 148 // Asum computes the sum of the absolute values of the elements of x: 149 // \sum_i |x[i]|. 150 // 151 // Asum will panic if the vector increment is negative. 152 func Asum(x Vector) float32 { 153 if x.Inc < 0 { 154 panic(negInc) 155 } 156 return blas32.Sasum(x.N, x.Data, x.Inc) 157 } 158 159 // Iamax returns the index of an element of x with the largest absolute value. 160 // If there are multiple such indices the earliest is returned. 161 // Iamax returns -1 if n == 0. 162 // 163 // Iamax will panic if the vector increment is negative. 164 func Iamax(x Vector) int { 165 if x.Inc < 0 { 166 panic(negInc) 167 } 168 return blas32.Isamax(x.N, x.Data, x.Inc) 169 } 170 171 // Swap exchanges the elements of the two vectors: 172 // x[i], y[i] = y[i], x[i] for all i. 173 // Swap will panic if the lengths of x and y do not match. 174 func Swap(x, y Vector) { 175 if x.N != y.N { 176 panic(badLength) 177 } 178 blas32.Sswap(x.N, x.Data, x.Inc, y.Data, y.Inc) 179 } 180 181 // Copy copies the elements of x into the elements of y: 182 // y[i] = x[i] for all i. 183 // Copy will panic if the lengths of x and y do not match. 184 func Copy(x, y Vector) { 185 if x.N != y.N { 186 panic(badLength) 187 } 188 blas32.Scopy(x.N, x.Data, x.Inc, y.Data, y.Inc) 189 } 190 191 // Axpy adds x scaled by alpha to y: 192 // y[i] += alpha*x[i] for all i. 193 // Axpy will panic if the lengths of x and y do not match. 194 func Axpy(alpha float32, x, y Vector) { 195 if x.N != y.N { 196 panic(badLength) 197 } 198 blas32.Saxpy(x.N, alpha, x.Data, x.Inc, y.Data, y.Inc) 199 } 200 201 // Rotg computes the parameters of a Givens plane rotation so that 202 // ⎡ c s⎤ ⎡a⎤ ⎡r⎤ 203 // ⎣-s c⎦ * ⎣b⎦ = ⎣0⎦ 204 // where a and b are the Cartesian coordinates of a given point. 205 // c, s, and r are defined as 206 // r = ±Sqrt(a^2 + b^2), 207 // c = a/r, the cosine of the rotation angle, 208 // s = a/r, the sine of the rotation angle, 209 // and z is defined such that 210 // if |a| > |b|, z = s, 211 // otherwise if c != 0, z = 1/c, 212 // otherwise z = 1. 213 func Rotg(a, b float32) (c, s, r, z float32) { 214 return blas32.Srotg(a, b) 215 } 216 217 // Rotmg computes the modified Givens rotation. See 218 // http://www.netlib.org/lapack/explore-html/df/deb/drotmg_8f.html 219 // for more details. 220 func Rotmg(d1, d2, b1, b2 float32) (p blas.SrotmParams, rd1, rd2, rb1 float32) { 221 return blas32.Srotmg(d1, d2, b1, b2) 222 } 223 224 // Rot applies a plane transformation to n points represented by the vectors x 225 // and y: 226 // x[i] = c*x[i] + s*y[i], 227 // y[i] = -s*x[i] + c*y[i], for all i. 228 func Rot(n int, x, y Vector, c, s float32) { 229 blas32.Srot(n, x.Data, x.Inc, y.Data, y.Inc, c, s) 230 } 231 232 // Rotm applies the modified Givens rotation to n points represented by the 233 // vectors x and y. 234 func Rotm(n int, x, y Vector, p blas.SrotmParams) { 235 blas32.Srotm(n, x.Data, x.Inc, y.Data, y.Inc, p) 236 } 237 238 // Scal scales the vector x by alpha: 239 // x[i] *= alpha for all i. 240 // 241 // Scal will panic if the vector increment is negative. 242 func Scal(alpha float32, x Vector) { 243 if x.Inc < 0 { 244 panic(negInc) 245 } 246 blas32.Sscal(x.N, alpha, x.Data, x.Inc) 247 } 248 249 // Level 2 250 251 // Gemv computes 252 // y = alpha * A * x + beta * y if t == blas.NoTrans, 253 // y = alpha * Aᵀ * x + beta * y if t == blas.Trans or blas.ConjTrans, 254 // where A is an m×n dense matrix, x and y are vectors, and alpha and beta are scalars. 255 func Gemv(t blas.Transpose, alpha float32, a General, x Vector, beta float32, y Vector) { 256 blas32.Sgemv(t, a.Rows, a.Cols, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc) 257 } 258 259 // Gbmv computes 260 // y = alpha * A * x + beta * y if t == blas.NoTrans, 261 // y = alpha * Aᵀ * x + beta * y if t == blas.Trans or blas.ConjTrans, 262 // where A is an m×n band matrix, x and y are vectors, and alpha and beta are scalars. 263 func Gbmv(t blas.Transpose, alpha float32, a Band, x Vector, beta float32, y Vector) { 264 blas32.Sgbmv(t, a.Rows, a.Cols, a.KL, a.KU, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc) 265 } 266 267 // Trmv computes 268 // x = A * x if t == blas.NoTrans, 269 // x = Aᵀ * x if t == blas.Trans or blas.ConjTrans, 270 // where A is an n×n triangular matrix, and x is a vector. 271 func Trmv(t blas.Transpose, a Triangular, x Vector) { 272 blas32.Strmv(a.Uplo, t, a.Diag, a.N, a.Data, a.Stride, x.Data, x.Inc) 273 } 274 275 // Tbmv computes 276 // x = A * x if t == blas.NoTrans, 277 // x = Aᵀ * x if t == blas.Trans or blas.ConjTrans, 278 // where A is an n×n triangular band matrix, and x is a vector. 279 func Tbmv(t blas.Transpose, a TriangularBand, x Vector) { 280 blas32.Stbmv(a.Uplo, t, a.Diag, a.N, a.K, a.Data, a.Stride, x.Data, x.Inc) 281 } 282 283 // Tpmv computes 284 // x = A * x if t == blas.NoTrans, 285 // x = Aᵀ * x if t == blas.Trans or blas.ConjTrans, 286 // where A is an n×n triangular matrix in packed format, and x is a vector. 287 func Tpmv(t blas.Transpose, a TriangularPacked, x Vector) { 288 blas32.Stpmv(a.Uplo, t, a.Diag, a.N, a.Data, x.Data, x.Inc) 289 } 290 291 // Trsv solves 292 // A * x = b if t == blas.NoTrans, 293 // Aᵀ * x = b if t == blas.Trans or blas.ConjTrans, 294 // where A is an n×n triangular matrix, and x and b are vectors. 295 // 296 // At entry to the function, x contains the values of b, and the result is 297 // stored in-place into x. 298 // 299 // No test for singularity or near-singularity is included in this 300 // routine. Such tests must be performed before calling this routine. 301 func Trsv(t blas.Transpose, a Triangular, x Vector) { 302 blas32.Strsv(a.Uplo, t, a.Diag, a.N, a.Data, a.Stride, x.Data, x.Inc) 303 } 304 305 // Tbsv solves 306 // A * x = b if t == blas.NoTrans, 307 // Aᵀ * x = b if t == blas.Trans or blas.ConjTrans, 308 // where A is an n×n triangular band matrix, and x and b are vectors. 309 // 310 // At entry to the function, x contains the values of b, and the result is 311 // stored in place into x. 312 // 313 // No test for singularity or near-singularity is included in this 314 // routine. Such tests must be performed before calling this routine. 315 func Tbsv(t blas.Transpose, a TriangularBand, x Vector) { 316 blas32.Stbsv(a.Uplo, t, a.Diag, a.N, a.K, a.Data, a.Stride, x.Data, x.Inc) 317 } 318 319 // Tpsv solves 320 // A * x = b if t == blas.NoTrans, 321 // Aᵀ * x = b if t == blas.Trans or blas.ConjTrans, 322 // where A is an n×n triangular matrix in packed format, and x and b are 323 // vectors. 324 // 325 // At entry to the function, x contains the values of b, and the result is 326 // stored in place into x. 327 // 328 // No test for singularity or near-singularity is included in this 329 // routine. Such tests must be performed before calling this routine. 330 func Tpsv(t blas.Transpose, a TriangularPacked, x Vector) { 331 blas32.Stpsv(a.Uplo, t, a.Diag, a.N, a.Data, x.Data, x.Inc) 332 } 333 334 // Symv computes 335 // y = alpha * A * x + beta * y, 336 // where A is an n×n symmetric matrix, x and y are vectors, and alpha and 337 // beta are scalars. 338 func Symv(alpha float32, a Symmetric, x Vector, beta float32, y Vector) { 339 blas32.Ssymv(a.Uplo, a.N, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc) 340 } 341 342 // Sbmv performs 343 // y = alpha * A * x + beta * y, 344 // where A is an n×n symmetric band matrix, x and y are vectors, and alpha 345 // and beta are scalars. 346 func Sbmv(alpha float32, a SymmetricBand, x Vector, beta float32, y Vector) { 347 blas32.Ssbmv(a.Uplo, a.N, a.K, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc) 348 } 349 350 // Spmv performs 351 // y = alpha * A * x + beta * y, 352 // where A is an n×n symmetric matrix in packed format, x and y are vectors, 353 // and alpha and beta are scalars. 354 func Spmv(alpha float32, a SymmetricPacked, x Vector, beta float32, y Vector) { 355 blas32.Sspmv(a.Uplo, a.N, alpha, a.Data, x.Data, x.Inc, beta, y.Data, y.Inc) 356 } 357 358 // Ger performs a rank-1 update 359 // A += alpha * x * yᵀ, 360 // where A is an m×n dense matrix, x and y are vectors, and alpha is a scalar. 361 func Ger(alpha float32, x, y Vector, a General) { 362 blas32.Sger(a.Rows, a.Cols, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data, a.Stride) 363 } 364 365 // Syr performs a rank-1 update 366 // A += alpha * x * xᵀ, 367 // where A is an n×n symmetric matrix, x is a vector, and alpha is a scalar. 368 func Syr(alpha float32, x Vector, a Symmetric) { 369 blas32.Ssyr(a.Uplo, a.N, alpha, x.Data, x.Inc, a.Data, a.Stride) 370 } 371 372 // Spr performs the rank-1 update 373 // A += alpha * x * xᵀ, 374 // where A is an n×n symmetric matrix in packed format, x is a vector, and 375 // alpha is a scalar. 376 func Spr(alpha float32, x Vector, a SymmetricPacked) { 377 blas32.Sspr(a.Uplo, a.N, alpha, x.Data, x.Inc, a.Data) 378 } 379 380 // Syr2 performs a rank-2 update 381 // A += alpha * x * yᵀ + alpha * y * xᵀ, 382 // where A is a symmetric n×n matrix, x and y are vectors, and alpha is a scalar. 383 func Syr2(alpha float32, x, y Vector, a Symmetric) { 384 blas32.Ssyr2(a.Uplo, a.N, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data, a.Stride) 385 } 386 387 // Spr2 performs a rank-2 update 388 // A += alpha * x * yᵀ + alpha * y * xᵀ, 389 // where A is an n×n symmetric matrix in packed format, x and y are vectors, 390 // and alpha is a scalar. 391 func Spr2(alpha float32, x, y Vector, a SymmetricPacked) { 392 blas32.Sspr2(a.Uplo, a.N, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data) 393 } 394 395 // Level 3 396 397 // Gemm computes 398 // C = alpha * A * B + beta * C, 399 // where A, B, and C are dense matrices, and alpha and beta are scalars. 400 // tA and tB specify whether A or B are transposed. 401 func Gemm(tA, tB blas.Transpose, alpha float32, a, b General, beta float32, c General) { 402 var m, n, k int 403 if tA == blas.NoTrans { 404 m, k = a.Rows, a.Cols 405 } else { 406 m, k = a.Cols, a.Rows 407 } 408 if tB == blas.NoTrans { 409 n = b.Cols 410 } else { 411 n = b.Rows 412 } 413 blas32.Sgemm(tA, tB, m, n, k, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride) 414 } 415 416 // Symm performs 417 // C = alpha * A * B + beta * C if s == blas.Left, 418 // C = alpha * B * A + beta * C if s == blas.Right, 419 // where A is an n×n or m×m symmetric matrix, B and C are m×n matrices, and 420 // alpha is a scalar. 421 func Symm(s blas.Side, alpha float32, a Symmetric, b General, beta float32, c General) { 422 var m, n int 423 if s == blas.Left { 424 m, n = a.N, b.Cols 425 } else { 426 m, n = b.Rows, a.N 427 } 428 blas32.Ssymm(s, a.Uplo, m, n, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride) 429 } 430 431 // Syrk performs a symmetric rank-k update 432 // C = alpha * A * Aᵀ + beta * C if t == blas.NoTrans, 433 // C = alpha * Aᵀ * A + beta * C if t == blas.Trans or blas.ConjTrans, 434 // where C is an n×n symmetric matrix, A is an n×k matrix if t == blas.NoTrans and 435 // a k×n matrix otherwise, and alpha and beta are scalars. 436 func Syrk(t blas.Transpose, alpha float32, a General, beta float32, c Symmetric) { 437 var n, k int 438 if t == blas.NoTrans { 439 n, k = a.Rows, a.Cols 440 } else { 441 n, k = a.Cols, a.Rows 442 } 443 blas32.Ssyrk(c.Uplo, t, n, k, alpha, a.Data, a.Stride, beta, c.Data, c.Stride) 444 } 445 446 // Syr2k performs a symmetric rank-2k update 447 // C = alpha * A * Bᵀ + alpha * B * Aᵀ + beta * C if t == blas.NoTrans, 448 // C = alpha * Aᵀ * B + alpha * Bᵀ * A + beta * C if t == blas.Trans or blas.ConjTrans, 449 // where C is an n×n symmetric matrix, A and B are n×k matrices if t == NoTrans 450 // and k×n matrices otherwise, and alpha and beta are scalars. 451 func Syr2k(t blas.Transpose, alpha float32, a, b General, beta float32, c Symmetric) { 452 var n, k int 453 if t == blas.NoTrans { 454 n, k = a.Rows, a.Cols 455 } else { 456 n, k = a.Cols, a.Rows 457 } 458 blas32.Ssyr2k(c.Uplo, t, n, k, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride) 459 } 460 461 // Trmm performs 462 // B = alpha * A * B if tA == blas.NoTrans and s == blas.Left, 463 // B = alpha * Aᵀ * B if tA == blas.Trans or blas.ConjTrans, and s == blas.Left, 464 // B = alpha * B * A if tA == blas.NoTrans and s == blas.Right, 465 // B = alpha * B * Aᵀ if tA == blas.Trans or blas.ConjTrans, and s == blas.Right, 466 // where A is an n×n or m×m triangular matrix, B is an m×n matrix, and alpha is 467 // a scalar. 468 func Trmm(s blas.Side, tA blas.Transpose, alpha float32, a Triangular, b General) { 469 blas32.Strmm(s, a.Uplo, tA, a.Diag, b.Rows, b.Cols, alpha, a.Data, a.Stride, b.Data, b.Stride) 470 } 471 472 // Trsm solves 473 // A * X = alpha * B if tA == blas.NoTrans and s == blas.Left, 474 // Aᵀ * X = alpha * B if tA == blas.Trans or blas.ConjTrans, and s == blas.Left, 475 // X * A = alpha * B if tA == blas.NoTrans and s == blas.Right, 476 // X * Aᵀ = alpha * B if tA == blas.Trans or blas.ConjTrans, and s == blas.Right, 477 // where A is an n×n or m×m triangular matrix, X and B are m×n matrices, and 478 // alpha is a scalar. 479 // 480 // At entry to the function, X contains the values of B, and the result is 481 // stored in-place into X. 482 // 483 // No check is made that A is invertible. 484 func Trsm(s blas.Side, tA blas.Transpose, alpha float32, a Triangular, b General) { 485 blas32.Strsm(s, a.Uplo, tA, a.Diag, b.Rows, b.Cols, alpha, a.Data, a.Stride, b.Data, b.Stride) 486 }