gonum.org/v1/gonum@v0.15.1-0.20240517103525-f853624cb1bb/blas/blas32/blas32.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package blas32
     6  
     7  import (
     8  	"gonum.org/v1/gonum/blas"
     9  	"gonum.org/v1/gonum/blas/gonum"
    10  )
    11  
// blas32 is the package-level BLAS float32 implementation that the wrappers
// in this file dispatch to. It is initialized to gonum.Implementation{} and
// is replaced by Use.
var blas32 blas.Float32 = gonum.Implementation{}
    13  
// Use sets the BLAS float32 implementation to be used by subsequent BLAS calls.
// The default implementation is
// gonum.org/v1/gonum/blas/gonum.Implementation.
//
// Use performs a plain assignment to a package-level variable: no locking and
// no nil check are done here.
func Use(b blas.Float32) {
	blas32 = b
}
    20  
// Implementation returns the current BLAS float32 implementation.
//
// Implementation allows direct calls to the current BLAS float32
// implementation, giving finer control of parameters.
func Implementation() blas.Float32 {
	return blas32
}
    28  
// Vector represents a vector with an associated element increment.
type Vector struct {
	N    int       // number of elements; passed as the length argument of the BLAS routines
	Inc  int       // element increment; passed as the increment argument of the BLAS routines
	Data []float32 // backing storage handed to the BLAS implementation
}
    35  
    36  // General represents a matrix using the conventional storage scheme.
    37  type General struct {
    38  	Rows, Cols int
    39  	Stride     int
    40  	Data       []float32
    41  }
    42  
    43  // Band represents a band matrix using the band storage scheme.
    44  type Band struct {
    45  	Rows, Cols int
    46  	KL, KU     int
    47  	Stride     int
    48  	Data       []float32
    49  }
    50  
// Triangular represents a triangular matrix using the conventional storage scheme.
type Triangular struct {
	N      int       // order of the n×n matrix
	Stride int       // handed to the BLAS implementation together with Data
	Data   []float32 // backing storage for the matrix elements
	Uplo   blas.Uplo // passed as the uplo argument (presumably selects the stored triangle — confirm)
	Diag   blas.Diag // passed as the diag argument (presumably unit vs. non-unit diagonal — confirm)
}
    59  
    60  // TriangularBand represents a triangular matrix using the band storage scheme.
    61  type TriangularBand struct {
    62  	N, K   int
    63  	Stride int
    64  	Data   []float32
    65  	Uplo   blas.Uplo
    66  	Diag   blas.Diag
    67  }
    68  
// TriangularPacked represents a triangular matrix using the packed storage scheme.
type TriangularPacked struct {
	N    int       // order of the n×n matrix
	Data []float32 // packed backing storage (no stride is passed for packed matrices)
	Uplo blas.Uplo // passed as the uplo argument (presumably selects the stored triangle — confirm)
	Diag blas.Diag // passed as the diag argument (presumably unit vs. non-unit diagonal — confirm)
}
    76  
// Symmetric represents a symmetric matrix using the conventional storage scheme.
type Symmetric struct {
	N      int       // order of the n×n matrix
	Stride int       // handed to the BLAS implementation together with Data
	Data   []float32 // backing storage for the matrix elements
	Uplo   blas.Uplo // passed as the uplo argument (presumably selects the stored triangle — confirm)
}
    84  
    85  // SymmetricBand represents a symmetric matrix using the band storage scheme.
    86  type SymmetricBand struct {
    87  	N, K   int
    88  	Stride int
    89  	Data   []float32
    90  	Uplo   blas.Uplo
    91  }
    92  
// SymmetricPacked represents a symmetric matrix using the packed storage scheme.
type SymmetricPacked struct {
	N    int       // order of the n×n matrix
	Data []float32 // packed backing storage (no stride is passed for packed matrices)
	Uplo blas.Uplo // passed as the uplo argument (presumably selects the stored triangle — confirm)
}
    99  
   100  // Level 1
   101  
// Panic messages raised by the Level 1 wrappers when argument validation fails.
const (
	negInc    = "blas32: negative vector increment" // used when x.Inc < 0
	badLength = "blas32: vector length mismatch"    // used when x.N != y.N
)
   106  
   107  // Dot computes the dot product of the two vectors:
   108  //
   109  //	\sum_i x[i]*y[i].
   110  //
   111  // Dot will panic if the lengths of x and y do not match.
   112  func Dot(x, y Vector) float32 {
   113  	if x.N != y.N {
   114  		panic(badLength)
   115  	}
   116  	return blas32.Sdot(x.N, x.Data, x.Inc, y.Data, y.Inc)
   117  }
   118  
   119  // DDot computes the dot product of the two vectors:
   120  //
   121  //	\sum_i x[i]*y[i].
   122  //
   123  // DDot will panic if the lengths of x and y do not match.
   124  func DDot(x, y Vector) float64 {
   125  	if x.N != y.N {
   126  		panic(badLength)
   127  	}
   128  	return blas32.Dsdot(x.N, x.Data, x.Inc, y.Data, y.Inc)
   129  }
   130  
   131  // SDDot computes the dot product of the two vectors adding a constant:
   132  //
   133  //	alpha + \sum_i x[i]*y[i].
   134  //
   135  // SDDot will panic if the lengths of x and y do not match.
   136  func SDDot(alpha float32, x, y Vector) float32 {
   137  	if x.N != y.N {
   138  		panic(badLength)
   139  	}
   140  	return blas32.Sdsdot(x.N, alpha, x.Data, x.Inc, y.Data, y.Inc)
   141  }
   142  
   143  // Nrm2 computes the Euclidean norm of the vector x:
   144  //
   145  //	sqrt(\sum_i x[i]*x[i]).
   146  //
   147  // Nrm2 will panic if the vector increment is negative.
   148  func Nrm2(x Vector) float32 {
   149  	if x.Inc < 0 {
   150  		panic(negInc)
   151  	}
   152  	return blas32.Snrm2(x.N, x.Data, x.Inc)
   153  }
   154  
   155  // Asum computes the sum of the absolute values of the elements of x:
   156  //
   157  //	\sum_i |x[i]|.
   158  //
   159  // Asum will panic if the vector increment is negative.
   160  func Asum(x Vector) float32 {
   161  	if x.Inc < 0 {
   162  		panic(negInc)
   163  	}
   164  	return blas32.Sasum(x.N, x.Data, x.Inc)
   165  }
   166  
   167  // Iamax returns the index of an element of x with the largest absolute value.
   168  // If there are multiple such indices the earliest is returned.
   169  // Iamax returns -1 if n == 0.
   170  //
   171  // Iamax will panic if the vector increment is negative.
   172  func Iamax(x Vector) int {
   173  	if x.Inc < 0 {
   174  		panic(negInc)
   175  	}
   176  	return blas32.Isamax(x.N, x.Data, x.Inc)
   177  }
   178  
   179  // Swap exchanges the elements of the two vectors:
   180  //
   181  //	x[i], y[i] = y[i], x[i] for all i.
   182  //
   183  // Swap will panic if the lengths of x and y do not match.
   184  func Swap(x, y Vector) {
   185  	if x.N != y.N {
   186  		panic(badLength)
   187  	}
   188  	blas32.Sswap(x.N, x.Data, x.Inc, y.Data, y.Inc)
   189  }
   190  
   191  // Copy copies the elements of x into the elements of y:
   192  //
   193  //	y[i] = x[i] for all i.
   194  //
   195  // Copy will panic if the lengths of x and y do not match.
   196  func Copy(x, y Vector) {
   197  	if x.N != y.N {
   198  		panic(badLength)
   199  	}
   200  	blas32.Scopy(x.N, x.Data, x.Inc, y.Data, y.Inc)
   201  }
   202  
   203  // Axpy adds x scaled by alpha to y:
   204  //
   205  //	y[i] += alpha*x[i] for all i.
   206  //
   207  // Axpy will panic if the lengths of x and y do not match.
   208  func Axpy(alpha float32, x, y Vector) {
   209  	if x.N != y.N {
   210  		panic(badLength)
   211  	}
   212  	blas32.Saxpy(x.N, alpha, x.Data, x.Inc, y.Data, y.Inc)
   213  }
   214  
   215  // Rotg computes the parameters of a Givens plane rotation so that
   216  //
   217  //	⎡ c s⎤   ⎡a⎤   ⎡r⎤
   218  //	⎣-s c⎦ * ⎣b⎦ = ⎣0⎦
   219  //
   220  // where a and b are the Cartesian coordinates of a given point.
   221  // c, s, and r are defined as
   222  //
   223  //	r = ±Sqrt(a^2 + b^2),
   224  //	c = a/r, the cosine of the rotation angle,
   225  //	s = a/r, the sine of the rotation angle,
   226  //
   227  // and z is defined such that
   228  //
   229  //	if |a| > |b|,        z = s,
   230  //	otherwise if c != 0, z = 1/c,
   231  //	otherwise            z = 1.
   232  func Rotg(a, b float32) (c, s, r, z float32) {
   233  	return blas32.Srotg(a, b)
   234  }
   235  
   236  // Rotmg computes the modified Givens rotation. See
   237  // http://www.netlib.org/lapack/explore-html/df/deb/drotmg_8f.html
   238  // for more details.
   239  func Rotmg(d1, d2, b1, b2 float32) (p blas.SrotmParams, rd1, rd2, rb1 float32) {
   240  	return blas32.Srotmg(d1, d2, b1, b2)
   241  }
   242  
   243  // Rot applies a plane transformation to n points represented by the vectors x
   244  // and y:
   245  //
   246  //	x[i] =  c*x[i] + s*y[i],
   247  //	y[i] = -s*x[i] + c*y[i], for all i.
   248  func Rot(n int, x, y Vector, c, s float32) {
   249  	blas32.Srot(n, x.Data, x.Inc, y.Data, y.Inc, c, s)
   250  }
   251  
   252  // Rotm applies the modified Givens rotation to n points represented by the
   253  // vectors x and y.
   254  func Rotm(n int, x, y Vector, p blas.SrotmParams) {
   255  	blas32.Srotm(n, x.Data, x.Inc, y.Data, y.Inc, p)
   256  }
   257  
   258  // Scal scales the vector x by alpha:
   259  //
   260  //	x[i] *= alpha for all i.
   261  //
   262  // Scal will panic if the vector increment is negative.
   263  func Scal(alpha float32, x Vector) {
   264  	if x.Inc < 0 {
   265  		panic(negInc)
   266  	}
   267  	blas32.Sscal(x.N, alpha, x.Data, x.Inc)
   268  }
   269  
   270  // Level 2
   271  
   272  // Gemv computes
   273  //
   274  //	y = alpha * A * x + beta * y   if t == blas.NoTrans,
   275  //	y = alpha * Aᵀ * x + beta * y  if t == blas.Trans or blas.ConjTrans,
   276  //
   277  // where A is an m×n dense matrix, x and y are vectors, and alpha and beta are scalars.
   278  func Gemv(t blas.Transpose, alpha float32, a General, x Vector, beta float32, y Vector) {
   279  	blas32.Sgemv(t, a.Rows, a.Cols, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
   280  }
   281  
   282  // Gbmv computes
   283  //
   284  //	y = alpha * A * x + beta * y   if t == blas.NoTrans,
   285  //	y = alpha * Aᵀ * x + beta * y  if t == blas.Trans or blas.ConjTrans,
   286  //
   287  // where A is an m×n band matrix, x and y are vectors, and alpha and beta are scalars.
   288  func Gbmv(t blas.Transpose, alpha float32, a Band, x Vector, beta float32, y Vector) {
   289  	blas32.Sgbmv(t, a.Rows, a.Cols, a.KL, a.KU, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
   290  }
   291  
   292  // Trmv computes
   293  //
   294  //	x = A * x   if t == blas.NoTrans,
   295  //	x = Aᵀ * x  if t == blas.Trans or blas.ConjTrans,
   296  //
   297  // where A is an n×n triangular matrix, and x is a vector.
   298  func Trmv(t blas.Transpose, a Triangular, x Vector) {
   299  	blas32.Strmv(a.Uplo, t, a.Diag, a.N, a.Data, a.Stride, x.Data, x.Inc)
   300  }
   301  
   302  // Tbmv computes
   303  //
   304  //	x = A * x   if t == blas.NoTrans,
   305  //	x = Aᵀ * x  if t == blas.Trans or blas.ConjTrans,
   306  //
   307  // where A is an n×n triangular band matrix, and x is a vector.
   308  func Tbmv(t blas.Transpose, a TriangularBand, x Vector) {
   309  	blas32.Stbmv(a.Uplo, t, a.Diag, a.N, a.K, a.Data, a.Stride, x.Data, x.Inc)
   310  }
   311  
   312  // Tpmv computes
   313  //
   314  //	x = A * x   if t == blas.NoTrans,
   315  //	x = Aᵀ * x  if t == blas.Trans or blas.ConjTrans,
   316  //
   317  // where A is an n×n triangular matrix in packed format, and x is a vector.
   318  func Tpmv(t blas.Transpose, a TriangularPacked, x Vector) {
   319  	blas32.Stpmv(a.Uplo, t, a.Diag, a.N, a.Data, x.Data, x.Inc)
   320  }
   321  
   322  // Trsv solves
   323  //
   324  //	A * x = b   if t == blas.NoTrans,
   325  //	Aᵀ * x = b  if t == blas.Trans or blas.ConjTrans,
   326  //
   327  // where A is an n×n triangular matrix, and x and b are vectors.
   328  //
   329  // At entry to the function, x contains the values of b, and the result is
   330  // stored in-place into x.
   331  //
   332  // No test for singularity or near-singularity is included in this
   333  // routine. Such tests must be performed before calling this routine.
   334  func Trsv(t blas.Transpose, a Triangular, x Vector) {
   335  	blas32.Strsv(a.Uplo, t, a.Diag, a.N, a.Data, a.Stride, x.Data, x.Inc)
   336  }
   337  
   338  // Tbsv solves
   339  //
   340  //	A * x = b   if t == blas.NoTrans,
   341  //	Aᵀ * x = b  if t == blas.Trans or blas.ConjTrans,
   342  //
   343  // where A is an n×n triangular band matrix, and x and b are vectors.
   344  //
   345  // At entry to the function, x contains the values of b, and the result is
   346  // stored in place into x.
   347  //
   348  // No test for singularity or near-singularity is included in this
   349  // routine. Such tests must be performed before calling this routine.
   350  func Tbsv(t blas.Transpose, a TriangularBand, x Vector) {
   351  	blas32.Stbsv(a.Uplo, t, a.Diag, a.N, a.K, a.Data, a.Stride, x.Data, x.Inc)
   352  }
   353  
   354  // Tpsv solves
   355  //
   356  //	A * x = b   if t == blas.NoTrans,
   357  //	Aᵀ * x = b  if t == blas.Trans or blas.ConjTrans,
   358  //
   359  // where A is an n×n triangular matrix in packed format, and x and b are
   360  // vectors.
   361  //
   362  // At entry to the function, x contains the values of b, and the result is
   363  // stored in place into x.
   364  //
   365  // No test for singularity or near-singularity is included in this
   366  // routine. Such tests must be performed before calling this routine.
   367  func Tpsv(t blas.Transpose, a TriangularPacked, x Vector) {
   368  	blas32.Stpsv(a.Uplo, t, a.Diag, a.N, a.Data, x.Data, x.Inc)
   369  }
   370  
   371  // Symv computes
   372  //
   373  //	y = alpha * A * x + beta * y,
   374  //
   375  // where A is an n×n symmetric matrix, x and y are vectors, and alpha and
   376  // beta are scalars.
   377  func Symv(alpha float32, a Symmetric, x Vector, beta float32, y Vector) {
   378  	blas32.Ssymv(a.Uplo, a.N, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
   379  }
   380  
   381  // Sbmv performs
   382  //
   383  //	y = alpha * A * x + beta * y,
   384  //
   385  // where A is an n×n symmetric band matrix, x and y are vectors, and alpha
   386  // and beta are scalars.
   387  func Sbmv(alpha float32, a SymmetricBand, x Vector, beta float32, y Vector) {
   388  	blas32.Ssbmv(a.Uplo, a.N, a.K, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
   389  }
   390  
   391  // Spmv performs
   392  //
   393  //	y = alpha * A * x + beta * y,
   394  //
   395  // where A is an n×n symmetric matrix in packed format, x and y are vectors,
   396  // and alpha and beta are scalars.
   397  func Spmv(alpha float32, a SymmetricPacked, x Vector, beta float32, y Vector) {
   398  	blas32.Sspmv(a.Uplo, a.N, alpha, a.Data, x.Data, x.Inc, beta, y.Data, y.Inc)
   399  }
   400  
   401  // Ger performs a rank-1 update
   402  //
   403  //	A += alpha * x * yᵀ,
   404  //
   405  // where A is an m×n dense matrix, x and y are vectors, and alpha is a scalar.
   406  func Ger(alpha float32, x, y Vector, a General) {
   407  	blas32.Sger(a.Rows, a.Cols, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data, a.Stride)
   408  }
   409  
   410  // Syr performs a rank-1 update
   411  //
   412  //	A += alpha * x * xᵀ,
   413  //
   414  // where A is an n×n symmetric matrix, x is a vector, and alpha is a scalar.
   415  func Syr(alpha float32, x Vector, a Symmetric) {
   416  	blas32.Ssyr(a.Uplo, a.N, alpha, x.Data, x.Inc, a.Data, a.Stride)
   417  }
   418  
   419  // Spr performs the rank-1 update
   420  //
   421  //	A += alpha * x * xᵀ,
   422  //
   423  // where A is an n×n symmetric matrix in packed format, x is a vector, and
   424  // alpha is a scalar.
   425  func Spr(alpha float32, x Vector, a SymmetricPacked) {
   426  	blas32.Sspr(a.Uplo, a.N, alpha, x.Data, x.Inc, a.Data)
   427  }
   428  
   429  // Syr2 performs a rank-2 update
   430  //
   431  //	A += alpha * x * yᵀ + alpha * y * xᵀ,
   432  //
   433  // where A is a symmetric n×n matrix, x and y are vectors, and alpha is a scalar.
   434  func Syr2(alpha float32, x, y Vector, a Symmetric) {
   435  	blas32.Ssyr2(a.Uplo, a.N, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data, a.Stride)
   436  }
   437  
   438  // Spr2 performs a rank-2 update
   439  //
   440  //	A += alpha * x * yᵀ + alpha * y * xᵀ,
   441  //
   442  // where A is an n×n symmetric matrix in packed format, x and y are vectors,
   443  // and alpha is a scalar.
   444  func Spr2(alpha float32, x, y Vector, a SymmetricPacked) {
   445  	blas32.Sspr2(a.Uplo, a.N, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data)
   446  }
   447  
   448  // Level 3
   449  
   450  // Gemm computes
   451  //
   452  //	C = alpha * A * B + beta * C,
   453  //
   454  // where A, B, and C are dense matrices, and alpha and beta are scalars.
   455  // tA and tB specify whether A or B are transposed.
   456  func Gemm(tA, tB blas.Transpose, alpha float32, a, b General, beta float32, c General) {
   457  	var m, n, k int
   458  	if tA == blas.NoTrans {
   459  		m, k = a.Rows, a.Cols
   460  	} else {
   461  		m, k = a.Cols, a.Rows
   462  	}
   463  	if tB == blas.NoTrans {
   464  		n = b.Cols
   465  	} else {
   466  		n = b.Rows
   467  	}
   468  	blas32.Sgemm(tA, tB, m, n, k, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride)
   469  }
   470  
   471  // Symm performs
   472  //
   473  //	C = alpha * A * B + beta * C  if s == blas.Left,
   474  //	C = alpha * B * A + beta * C  if s == blas.Right,
   475  //
   476  // where A is an n×n or m×m symmetric matrix, B and C are m×n matrices, and
   477  // alpha is a scalar.
   478  func Symm(s blas.Side, alpha float32, a Symmetric, b General, beta float32, c General) {
   479  	var m, n int
   480  	if s == blas.Left {
   481  		m, n = a.N, b.Cols
   482  	} else {
   483  		m, n = b.Rows, a.N
   484  	}
   485  	blas32.Ssymm(s, a.Uplo, m, n, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride)
   486  }
   487  
   488  // Syrk performs a symmetric rank-k update
   489  //
   490  //	C = alpha * A * Aᵀ + beta * C  if t == blas.NoTrans,
   491  //	C = alpha * Aᵀ * A + beta * C  if t == blas.Trans or blas.ConjTrans,
   492  //
   493  // where C is an n×n symmetric matrix, A is an n×k matrix if t == blas.NoTrans and
   494  // a k×n matrix otherwise, and alpha and beta are scalars.
   495  func Syrk(t blas.Transpose, alpha float32, a General, beta float32, c Symmetric) {
   496  	var n, k int
   497  	if t == blas.NoTrans {
   498  		n, k = a.Rows, a.Cols
   499  	} else {
   500  		n, k = a.Cols, a.Rows
   501  	}
   502  	blas32.Ssyrk(c.Uplo, t, n, k, alpha, a.Data, a.Stride, beta, c.Data, c.Stride)
   503  }
   504  
   505  // Syr2k performs a symmetric rank-2k update
   506  //
   507  //	C = alpha * A * Bᵀ + alpha * B * Aᵀ + beta * C  if t == blas.NoTrans,
   508  //	C = alpha * Aᵀ * B + alpha * Bᵀ * A + beta * C  if t == blas.Trans or blas.ConjTrans,
   509  //
   510  // where C is an n×n symmetric matrix, A and B are n×k matrices if t == NoTrans
   511  // and k×n matrices otherwise, and alpha and beta are scalars.
   512  func Syr2k(t blas.Transpose, alpha float32, a, b General, beta float32, c Symmetric) {
   513  	var n, k int
   514  	if t == blas.NoTrans {
   515  		n, k = a.Rows, a.Cols
   516  	} else {
   517  		n, k = a.Cols, a.Rows
   518  	}
   519  	blas32.Ssyr2k(c.Uplo, t, n, k, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride)
   520  }
   521  
   522  // Trmm performs
   523  //
   524  //	B = alpha * A * B   if tA == blas.NoTrans and s == blas.Left,
   525  //	B = alpha * Aᵀ * B  if tA == blas.Trans or blas.ConjTrans, and s == blas.Left,
   526  //	B = alpha * B * A   if tA == blas.NoTrans and s == blas.Right,
   527  //	B = alpha * B * Aᵀ  if tA == blas.Trans or blas.ConjTrans, and s == blas.Right,
   528  //
   529  // where A is an n×n or m×m triangular matrix, B is an m×n matrix, and alpha is
   530  // a scalar.
   531  func Trmm(s blas.Side, tA blas.Transpose, alpha float32, a Triangular, b General) {
   532  	blas32.Strmm(s, a.Uplo, tA, a.Diag, b.Rows, b.Cols, alpha, a.Data, a.Stride, b.Data, b.Stride)
   533  }
   534  
   535  // Trsm solves
   536  //
   537  //	A * X = alpha * B   if tA == blas.NoTrans and s == blas.Left,
   538  //	Aᵀ * X = alpha * B  if tA == blas.Trans or blas.ConjTrans, and s == blas.Left,
   539  //	X * A = alpha * B   if tA == blas.NoTrans and s == blas.Right,
   540  //	X * Aᵀ = alpha * B  if tA == blas.Trans or blas.ConjTrans, and s == blas.Right,
   541  //
   542  // where A is an n×n or m×m triangular matrix, X and B are m×n matrices, and
   543  // alpha is a scalar.
   544  //
   545  // At entry to the function, X contains the values of B, and the result is
   546  // stored in-place into X.
   547  //
   548  // No check is made that A is invertible.
   549  func Trsm(s blas.Side, tA blas.Transpose, alpha float32, a Triangular, b General) {
   550  	blas32.Strsm(s, a.Uplo, tA, a.Diag, b.Rows, b.Cols, alpha, a.Data, a.Stride, b.Data, b.Stride)
   551  }