github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/blas/blas32/blas32.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package blas32
     6  
     7  import (
     8  	"github.com/jingcheng-WU/gonum/blas"
     9  	"github.com/jingcheng-WU/gonum/blas/gonum"
    10  )
    11  
    12  var blas32 blas.Float32 = gonum.Implementation{}
    13  
    14  // Use sets the BLAS float32 implementation to be used by subsequent BLAS calls.
    15  // The default implementation is
    16  // github.com/jingcheng-WU/gonum/blas/gonum.Implementation.
    17  func Use(b blas.Float32) {
    18  	blas32 = b
    19  }
    20  
    21  // Implementation returns the current BLAS float32 implementation.
    22  //
    23  // Implementation allows direct calls to the current the BLAS float32 implementation
    24  // giving finer control of parameters.
    25  func Implementation() blas.Float32 {
    26  	return blas32
    27  }
    28  
    29  // Vector represents a vector with an associated element increment.
    30  type Vector struct {
    31  	N    int
    32  	Inc  int
    33  	Data []float32
    34  }
    35  
    36  // General represents a matrix using the conventional storage scheme.
    37  type General struct {
    38  	Rows, Cols int
    39  	Stride     int
    40  	Data       []float32
    41  }
    42  
    43  // Band represents a band matrix using the band storage scheme.
    44  type Band struct {
    45  	Rows, Cols int
    46  	KL, KU     int
    47  	Stride     int
    48  	Data       []float32
    49  }
    50  
    51  // Triangular represents a triangular matrix using the conventional storage scheme.
    52  type Triangular struct {
    53  	N      int
    54  	Stride int
    55  	Data   []float32
    56  	Uplo   blas.Uplo
    57  	Diag   blas.Diag
    58  }
    59  
    60  // TriangularBand represents a triangular matrix using the band storage scheme.
    61  type TriangularBand struct {
    62  	N, K   int
    63  	Stride int
    64  	Data   []float32
    65  	Uplo   blas.Uplo
    66  	Diag   blas.Diag
    67  }
    68  
    69  // TriangularPacked represents a triangular matrix using the packed storage scheme.
    70  type TriangularPacked struct {
    71  	N    int
    72  	Data []float32
    73  	Uplo blas.Uplo
    74  	Diag blas.Diag
    75  }
    76  
    77  // Symmetric represents a symmetric matrix using the conventional storage scheme.
    78  type Symmetric struct {
    79  	N      int
    80  	Stride int
    81  	Data   []float32
    82  	Uplo   blas.Uplo
    83  }
    84  
    85  // SymmetricBand represents a symmetric matrix using the band storage scheme.
    86  type SymmetricBand struct {
    87  	N, K   int
    88  	Stride int
    89  	Data   []float32
    90  	Uplo   blas.Uplo
    91  }
    92  
    93  // SymmetricPacked represents a symmetric matrix using the packed storage scheme.
    94  type SymmetricPacked struct {
    95  	N    int
    96  	Data []float32
    97  	Uplo blas.Uplo
    98  }
    99  
// Level 1 BLAS: vector-vector operations.
   101  
   102  const (
   103  	negInc    = "blas32: negative vector increment"
   104  	badLength = "blas32: vector length mismatch"
   105  )
   106  
   107  // Dot computes the dot product of the two vectors:
   108  //  \sum_i x[i]*y[i].
   109  // Dot will panic if the lengths of x and y do not match.
   110  func Dot(x, y Vector) float32 {
   111  	if x.N != y.N {
   112  		panic(badLength)
   113  	}
   114  	return blas32.Sdot(x.N, x.Data, x.Inc, y.Data, y.Inc)
   115  }
   116  
   117  // DDot computes the dot product of the two vectors:
   118  //  \sum_i x[i]*y[i].
   119  // DDot will panic if the lengths of x and y do not match.
   120  func DDot(x, y Vector) float64 {
   121  	if x.N != y.N {
   122  		panic(badLength)
   123  	}
   124  	return blas32.Dsdot(x.N, x.Data, x.Inc, y.Data, y.Inc)
   125  }
   126  
   127  // SDDot computes the dot product of the two vectors adding a constant:
   128  //  alpha + \sum_i x[i]*y[i].
   129  // SDDot will panic if the lengths of x and y do not match.
   130  func SDDot(alpha float32, x, y Vector) float32 {
   131  	if x.N != y.N {
   132  		panic(badLength)
   133  	}
   134  	return blas32.Sdsdot(x.N, alpha, x.Data, x.Inc, y.Data, y.Inc)
   135  }
   136  
   137  // Nrm2 computes the Euclidean norm of the vector x:
   138  //  sqrt(\sum_i x[i]*x[i]).
   139  //
   140  // Nrm2 will panic if the vector increment is negative.
   141  func Nrm2(x Vector) float32 {
   142  	if x.Inc < 0 {
   143  		panic(negInc)
   144  	}
   145  	return blas32.Snrm2(x.N, x.Data, x.Inc)
   146  }
   147  
   148  // Asum computes the sum of the absolute values of the elements of x:
   149  //  \sum_i |x[i]|.
   150  //
   151  // Asum will panic if the vector increment is negative.
   152  func Asum(x Vector) float32 {
   153  	if x.Inc < 0 {
   154  		panic(negInc)
   155  	}
   156  	return blas32.Sasum(x.N, x.Data, x.Inc)
   157  }
   158  
   159  // Iamax returns the index of an element of x with the largest absolute value.
   160  // If there are multiple such indices the earliest is returned.
   161  // Iamax returns -1 if n == 0.
   162  //
   163  // Iamax will panic if the vector increment is negative.
   164  func Iamax(x Vector) int {
   165  	if x.Inc < 0 {
   166  		panic(negInc)
   167  	}
   168  	return blas32.Isamax(x.N, x.Data, x.Inc)
   169  }
   170  
   171  // Swap exchanges the elements of the two vectors:
   172  //  x[i], y[i] = y[i], x[i] for all i.
   173  // Swap will panic if the lengths of x and y do not match.
   174  func Swap(x, y Vector) {
   175  	if x.N != y.N {
   176  		panic(badLength)
   177  	}
   178  	blas32.Sswap(x.N, x.Data, x.Inc, y.Data, y.Inc)
   179  }
   180  
   181  // Copy copies the elements of x into the elements of y:
   182  //  y[i] = x[i] for all i.
   183  // Copy will panic if the lengths of x and y do not match.
   184  func Copy(x, y Vector) {
   185  	if x.N != y.N {
   186  		panic(badLength)
   187  	}
   188  	blas32.Scopy(x.N, x.Data, x.Inc, y.Data, y.Inc)
   189  }
   190  
   191  // Axpy adds x scaled by alpha to y:
   192  //  y[i] += alpha*x[i] for all i.
   193  // Axpy will panic if the lengths of x and y do not match.
   194  func Axpy(alpha float32, x, y Vector) {
   195  	if x.N != y.N {
   196  		panic(badLength)
   197  	}
   198  	blas32.Saxpy(x.N, alpha, x.Data, x.Inc, y.Data, y.Inc)
   199  }
   200  
   201  // Rotg computes the parameters of a Givens plane rotation so that
   202  //  ⎡ c s⎤   ⎡a⎤   ⎡r⎤
   203  //  ⎣-s c⎦ * ⎣b⎦ = ⎣0⎦
   204  // where a and b are the Cartesian coordinates of a given point.
   205  // c, s, and r are defined as
   206  //  r = ±Sqrt(a^2 + b^2),
   207  //  c = a/r, the cosine of the rotation angle,
   208  //  s = a/r, the sine of the rotation angle,
   209  // and z is defined such that
   210  //  if |a| > |b|,        z = s,
   211  //  otherwise if c != 0, z = 1/c,
   212  //  otherwise            z = 1.
   213  func Rotg(a, b float32) (c, s, r, z float32) {
   214  	return blas32.Srotg(a, b)
   215  }
   216  
   217  // Rotmg computes the modified Givens rotation. See
   218  // http://www.netlib.org/lapack/explore-html/df/deb/drotmg_8f.html
   219  // for more details.
   220  func Rotmg(d1, d2, b1, b2 float32) (p blas.SrotmParams, rd1, rd2, rb1 float32) {
   221  	return blas32.Srotmg(d1, d2, b1, b2)
   222  }
   223  
   224  // Rot applies a plane transformation to n points represented by the vectors x
   225  // and y:
   226  //  x[i] =  c*x[i] + s*y[i],
   227  //  y[i] = -s*x[i] + c*y[i], for all i.
   228  func Rot(n int, x, y Vector, c, s float32) {
   229  	blas32.Srot(n, x.Data, x.Inc, y.Data, y.Inc, c, s)
   230  }
   231  
   232  // Rotm applies the modified Givens rotation to n points represented by the
   233  // vectors x and y.
   234  func Rotm(n int, x, y Vector, p blas.SrotmParams) {
   235  	blas32.Srotm(n, x.Data, x.Inc, y.Data, y.Inc, p)
   236  }
   237  
   238  // Scal scales the vector x by alpha:
   239  //  x[i] *= alpha for all i.
   240  //
   241  // Scal will panic if the vector increment is negative.
   242  func Scal(alpha float32, x Vector) {
   243  	if x.Inc < 0 {
   244  		panic(negInc)
   245  	}
   246  	blas32.Sscal(x.N, alpha, x.Data, x.Inc)
   247  }
   248  
// Level 2 BLAS: matrix-vector operations.
   250  
   251  // Gemv computes
   252  //  y = alpha * A * x + beta * y   if t == blas.NoTrans,
   253  //  y = alpha * Aᵀ * x + beta * y  if t == blas.Trans or blas.ConjTrans,
   254  // where A is an m×n dense matrix, x and y are vectors, and alpha and beta are scalars.
   255  func Gemv(t blas.Transpose, alpha float32, a General, x Vector, beta float32, y Vector) {
   256  	blas32.Sgemv(t, a.Rows, a.Cols, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
   257  }
   258  
   259  // Gbmv computes
   260  //  y = alpha * A * x + beta * y   if t == blas.NoTrans,
   261  //  y = alpha * Aᵀ * x + beta * y  if t == blas.Trans or blas.ConjTrans,
   262  // where A is an m×n band matrix, x and y are vectors, and alpha and beta are scalars.
   263  func Gbmv(t blas.Transpose, alpha float32, a Band, x Vector, beta float32, y Vector) {
   264  	blas32.Sgbmv(t, a.Rows, a.Cols, a.KL, a.KU, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
   265  }
   266  
   267  // Trmv computes
   268  //  x = A * x   if t == blas.NoTrans,
   269  //  x = Aᵀ * x  if t == blas.Trans or blas.ConjTrans,
   270  // where A is an n×n triangular matrix, and x is a vector.
   271  func Trmv(t blas.Transpose, a Triangular, x Vector) {
   272  	blas32.Strmv(a.Uplo, t, a.Diag, a.N, a.Data, a.Stride, x.Data, x.Inc)
   273  }
   274  
   275  // Tbmv computes
   276  //  x = A * x   if t == blas.NoTrans,
   277  //  x = Aᵀ * x  if t == blas.Trans or blas.ConjTrans,
   278  // where A is an n×n triangular band matrix, and x is a vector.
   279  func Tbmv(t blas.Transpose, a TriangularBand, x Vector) {
   280  	blas32.Stbmv(a.Uplo, t, a.Diag, a.N, a.K, a.Data, a.Stride, x.Data, x.Inc)
   281  }
   282  
   283  // Tpmv computes
   284  //  x = A * x   if t == blas.NoTrans,
   285  //  x = Aᵀ * x  if t == blas.Trans or blas.ConjTrans,
   286  // where A is an n×n triangular matrix in packed format, and x is a vector.
   287  func Tpmv(t blas.Transpose, a TriangularPacked, x Vector) {
   288  	blas32.Stpmv(a.Uplo, t, a.Diag, a.N, a.Data, x.Data, x.Inc)
   289  }
   290  
   291  // Trsv solves
   292  //  A * x = b   if t == blas.NoTrans,
   293  //  Aᵀ * x = b  if t == blas.Trans or blas.ConjTrans,
   294  // where A is an n×n triangular matrix, and x and b are vectors.
   295  //
   296  // At entry to the function, x contains the values of b, and the result is
   297  // stored in-place into x.
   298  //
   299  // No test for singularity or near-singularity is included in this
   300  // routine. Such tests must be performed before calling this routine.
   301  func Trsv(t blas.Transpose, a Triangular, x Vector) {
   302  	blas32.Strsv(a.Uplo, t, a.Diag, a.N, a.Data, a.Stride, x.Data, x.Inc)
   303  }
   304  
   305  // Tbsv solves
   306  //  A * x = b   if t == blas.NoTrans,
   307  //  Aᵀ * x = b  if t == blas.Trans or blas.ConjTrans,
   308  // where A is an n×n triangular band matrix, and x and b are vectors.
   309  //
   310  // At entry to the function, x contains the values of b, and the result is
   311  // stored in place into x.
   312  //
   313  // No test for singularity or near-singularity is included in this
   314  // routine. Such tests must be performed before calling this routine.
   315  func Tbsv(t blas.Transpose, a TriangularBand, x Vector) {
   316  	blas32.Stbsv(a.Uplo, t, a.Diag, a.N, a.K, a.Data, a.Stride, x.Data, x.Inc)
   317  }
   318  
   319  // Tpsv solves
   320  //  A * x = b   if t == blas.NoTrans,
   321  //  Aᵀ * x = b  if t == blas.Trans or blas.ConjTrans,
   322  // where A is an n×n triangular matrix in packed format, and x and b are
   323  // vectors.
   324  //
   325  // At entry to the function, x contains the values of b, and the result is
   326  // stored in place into x.
   327  //
   328  // No test for singularity or near-singularity is included in this
   329  // routine. Such tests must be performed before calling this routine.
   330  func Tpsv(t blas.Transpose, a TriangularPacked, x Vector) {
   331  	blas32.Stpsv(a.Uplo, t, a.Diag, a.N, a.Data, x.Data, x.Inc)
   332  }
   333  
   334  // Symv computes
   335  //  y = alpha * A * x + beta * y,
   336  // where A is an n×n symmetric matrix, x and y are vectors, and alpha and
   337  // beta are scalars.
   338  func Symv(alpha float32, a Symmetric, x Vector, beta float32, y Vector) {
   339  	blas32.Ssymv(a.Uplo, a.N, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
   340  }
   341  
   342  // Sbmv performs
   343  //  y = alpha * A * x + beta * y,
   344  // where A is an n×n symmetric band matrix, x and y are vectors, and alpha
   345  // and beta are scalars.
   346  func Sbmv(alpha float32, a SymmetricBand, x Vector, beta float32, y Vector) {
   347  	blas32.Ssbmv(a.Uplo, a.N, a.K, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
   348  }
   349  
   350  // Spmv performs
   351  //  y = alpha * A * x + beta * y,
   352  // where A is an n×n symmetric matrix in packed format, x and y are vectors,
   353  // and alpha and beta are scalars.
   354  func Spmv(alpha float32, a SymmetricPacked, x Vector, beta float32, y Vector) {
   355  	blas32.Sspmv(a.Uplo, a.N, alpha, a.Data, x.Data, x.Inc, beta, y.Data, y.Inc)
   356  }
   357  
   358  // Ger performs a rank-1 update
   359  //  A += alpha * x * yᵀ,
   360  // where A is an m×n dense matrix, x and y are vectors, and alpha is a scalar.
   361  func Ger(alpha float32, x, y Vector, a General) {
   362  	blas32.Sger(a.Rows, a.Cols, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data, a.Stride)
   363  }
   364  
   365  // Syr performs a rank-1 update
   366  //  A += alpha * x * xᵀ,
   367  // where A is an n×n symmetric matrix, x is a vector, and alpha is a scalar.
   368  func Syr(alpha float32, x Vector, a Symmetric) {
   369  	blas32.Ssyr(a.Uplo, a.N, alpha, x.Data, x.Inc, a.Data, a.Stride)
   370  }
   371  
   372  // Spr performs the rank-1 update
   373  //  A += alpha * x * xᵀ,
   374  // where A is an n×n symmetric matrix in packed format, x is a vector, and
   375  // alpha is a scalar.
   376  func Spr(alpha float32, x Vector, a SymmetricPacked) {
   377  	blas32.Sspr(a.Uplo, a.N, alpha, x.Data, x.Inc, a.Data)
   378  }
   379  
   380  // Syr2 performs a rank-2 update
   381  //  A += alpha * x * yᵀ + alpha * y * xᵀ,
   382  // where A is a symmetric n×n matrix, x and y are vectors, and alpha is a scalar.
   383  func Syr2(alpha float32, x, y Vector, a Symmetric) {
   384  	blas32.Ssyr2(a.Uplo, a.N, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data, a.Stride)
   385  }
   386  
   387  // Spr2 performs a rank-2 update
   388  //  A += alpha * x * yᵀ + alpha * y * xᵀ,
   389  // where A is an n×n symmetric matrix in packed format, x and y are vectors,
   390  // and alpha is a scalar.
   391  func Spr2(alpha float32, x, y Vector, a SymmetricPacked) {
   392  	blas32.Sspr2(a.Uplo, a.N, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data)
   393  }
   394  
// Level 3 BLAS: matrix-matrix operations.
   396  
   397  // Gemm computes
   398  //  C = alpha * A * B + beta * C,
   399  // where A, B, and C are dense matrices, and alpha and beta are scalars.
   400  // tA and tB specify whether A or B are transposed.
   401  func Gemm(tA, tB blas.Transpose, alpha float32, a, b General, beta float32, c General) {
   402  	var m, n, k int
   403  	if tA == blas.NoTrans {
   404  		m, k = a.Rows, a.Cols
   405  	} else {
   406  		m, k = a.Cols, a.Rows
   407  	}
   408  	if tB == blas.NoTrans {
   409  		n = b.Cols
   410  	} else {
   411  		n = b.Rows
   412  	}
   413  	blas32.Sgemm(tA, tB, m, n, k, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride)
   414  }
   415  
   416  // Symm performs
   417  //  C = alpha * A * B + beta * C  if s == blas.Left,
   418  //  C = alpha * B * A + beta * C  if s == blas.Right,
   419  // where A is an n×n or m×m symmetric matrix, B and C are m×n matrices, and
   420  // alpha is a scalar.
   421  func Symm(s blas.Side, alpha float32, a Symmetric, b General, beta float32, c General) {
   422  	var m, n int
   423  	if s == blas.Left {
   424  		m, n = a.N, b.Cols
   425  	} else {
   426  		m, n = b.Rows, a.N
   427  	}
   428  	blas32.Ssymm(s, a.Uplo, m, n, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride)
   429  }
   430  
   431  // Syrk performs a symmetric rank-k update
   432  //  C = alpha * A * Aᵀ + beta * C  if t == blas.NoTrans,
   433  //  C = alpha * Aᵀ * A + beta * C  if t == blas.Trans or blas.ConjTrans,
   434  // where C is an n×n symmetric matrix, A is an n×k matrix if t == blas.NoTrans and
   435  // a k×n matrix otherwise, and alpha and beta are scalars.
   436  func Syrk(t blas.Transpose, alpha float32, a General, beta float32, c Symmetric) {
   437  	var n, k int
   438  	if t == blas.NoTrans {
   439  		n, k = a.Rows, a.Cols
   440  	} else {
   441  		n, k = a.Cols, a.Rows
   442  	}
   443  	blas32.Ssyrk(c.Uplo, t, n, k, alpha, a.Data, a.Stride, beta, c.Data, c.Stride)
   444  }
   445  
   446  // Syr2k performs a symmetric rank-2k update
   447  //  C = alpha * A * Bᵀ + alpha * B * Aᵀ + beta * C  if t == blas.NoTrans,
   448  //  C = alpha * Aᵀ * B + alpha * Bᵀ * A + beta * C  if t == blas.Trans or blas.ConjTrans,
   449  // where C is an n×n symmetric matrix, A and B are n×k matrices if t == NoTrans
   450  // and k×n matrices otherwise, and alpha and beta are scalars.
   451  func Syr2k(t blas.Transpose, alpha float32, a, b General, beta float32, c Symmetric) {
   452  	var n, k int
   453  	if t == blas.NoTrans {
   454  		n, k = a.Rows, a.Cols
   455  	} else {
   456  		n, k = a.Cols, a.Rows
   457  	}
   458  	blas32.Ssyr2k(c.Uplo, t, n, k, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride)
   459  }
   460  
   461  // Trmm performs
   462  //  B = alpha * A * B   if tA == blas.NoTrans and s == blas.Left,
   463  //  B = alpha * Aᵀ * B  if tA == blas.Trans or blas.ConjTrans, and s == blas.Left,
   464  //  B = alpha * B * A   if tA == blas.NoTrans and s == blas.Right,
   465  //  B = alpha * B * Aᵀ  if tA == blas.Trans or blas.ConjTrans, and s == blas.Right,
   466  // where A is an n×n or m×m triangular matrix, B is an m×n matrix, and alpha is
   467  // a scalar.
   468  func Trmm(s blas.Side, tA blas.Transpose, alpha float32, a Triangular, b General) {
   469  	blas32.Strmm(s, a.Uplo, tA, a.Diag, b.Rows, b.Cols, alpha, a.Data, a.Stride, b.Data, b.Stride)
   470  }
   471  
   472  // Trsm solves
   473  //  A * X = alpha * B   if tA == blas.NoTrans and s == blas.Left,
   474  //  Aᵀ * X = alpha * B  if tA == blas.Trans or blas.ConjTrans, and s == blas.Left,
   475  //  X * A = alpha * B   if tA == blas.NoTrans and s == blas.Right,
   476  //  X * Aᵀ = alpha * B  if tA == blas.Trans or blas.ConjTrans, and s == blas.Right,
   477  // where A is an n×n or m×m triangular matrix, X and B are m×n matrices, and
   478  // alpha is a scalar.
   479  //
   480  // At entry to the function, X contains the values of B, and the result is
   481  // stored in-place into X.
   482  //
   483  // No check is made that A is invertible.
   484  func Trsm(s blas.Side, tA blas.Transpose, alpha float32, a Triangular, b General) {
   485  	blas32.Strsm(s, a.Uplo, tA, a.Diag, b.Rows, b.Cols, alpha, a.Data, a.Stride, b.Data, b.Stride)
   486  }