gorgonia.org/gorgonia@v0.9.17/blas_test.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"testing"
     5  
     6  	"gonum.org/v1/gonum/blas"
     7  	"gonum.org/v1/gonum/blas/gonum"
     8  	"gorgonia.org/tensor"
     9  )
    10  
    11  var gonumImpl = gonum.Implementation{}
    12  
    13  // testBLASImplementation of the interface
    14  type testBLASImplementation struct {
    15  	gonum.Implementation
    16  	used bool
    17  }
    18  
    19  // Sdsdot computes the dot product of the two vectors plus a constant
    20  //  alpha + ∑_i x[i]*y[i]
    21  //
    22  // Float32 implementations are autogenerated and not directly tested.
    23  // Sdsdot ...
    24  func (*testBLASImplementation) Sdsdot(n int, alpha float32, x []float32, incX int, y []float32, incY int) float32 {
    25  	return gonumImpl.Sdsdot(n, alpha, x, incX, y, incY)
    26  }
    27  
    28  // Dsdot computes the dot product of the two vectors
    29  //  ∑_i x[i]*y[i]
    30  //
    31  // Float32 implementations are autogenerated and not directly tested.
    32  // Dsdot ...
    33  func (*testBLASImplementation) Dsdot(n int, x []float32, incX int, y []float32, incY int) float64 {
    34  	return gonumImpl.Dsdot(n, x, incX, y, incY)
    35  }
    36  
    37  // Sdot ...
    38  func (*testBLASImplementation) Sdot(n int, x []float32, incX int, y []float32, incY int) float32 {
    39  	return gonumImpl.Sdot(n, x, incX, y, incY)
    40  }
    41  
    42  // Snrm2 ...
    43  func (*testBLASImplementation) Snrm2(n int, x []float32, incX int) float32 {
    44  	return gonumImpl.Snrm2(n, x, incX)
    45  }
    46  
    47  // Sasum ...
    48  func (*testBLASImplementation) Sasum(n int, x []float32, incX int) float32 {
    49  	return gonumImpl.Sasum(n, x, incX)
    50  }
    51  
    52  // Isamax ...
    53  func (*testBLASImplementation) Isamax(n int, x []float32, incX int) int {
    54  	return gonumImpl.Isamax(n, x, incX)
    55  }
    56  
    57  // Sswap ...
    58  func (*testBLASImplementation) Sswap(n int, x []float32, incX int, y []float32, incY int) {
    59  	gonumImpl.Sswap(n, x, incX, y, incY)
    60  }
    61  
    62  // Scopy ...
    63  func (*testBLASImplementation) Scopy(n int, x []float32, incX int, y []float32, incY int) {
    64  	gonumImpl.Scopy(n, x, incX, y, incY)
    65  }
    66  
    67  // Saxpy ...
    68  func (*testBLASImplementation) Saxpy(n int, alpha float32, x []float32, incX int, y []float32, incY int) {
    69  	gonumImpl.Saxpy(n, alpha, x, incX, y, incY)
    70  }
    71  
    72  // Srotg ...
    73  func (*testBLASImplementation) Srotg(a float32, b float32) (c float32, s float32, r float32, z float32) {
    74  	return gonumImpl.Srotg(a, b)
    75  }
    76  
    77  // Srotmg ...
    78  func (*testBLASImplementation) Srotmg(d1 float32, d2 float32, b1 float32, b2 float32) (p blas.SrotmParams, rd1 float32, rd2 float32, rb1 float32) {
    79  	return gonumImpl.Srotmg(d1, d2, b1, b2)
    80  }
    81  
    82  // Srot ...
    83  func (*testBLASImplementation) Srot(n int, x []float32, incX int, y []float32, incY int, c float32, s float32) {
    84  	gonumImpl.Srot(n, x, incX, y, incY, c, s)
    85  }
    86  
    87  // Srotm ...
    88  func (*testBLASImplementation) Srotm(n int, x []float32, incX int, y []float32, incY int, p blas.SrotmParams) {
    89  	gonumImpl.Srotm(n, x, incX, y, incY, p)
    90  }
    91  
    92  // Sscal ...
    93  func (*testBLASImplementation) Sscal(n int, alpha float32, x []float32, incX int) {
    94  	gonumImpl.Sscal(n, alpha, x, incX)
    95  }
    96  
    97  // Sgemv ...
    98  func (*testBLASImplementation) Sgemv(tA blas.Transpose, m int, n int, alpha float32, a []float32, lda int, x []float32, incX int, beta float32, y []float32, incY int) {
    99  	gonumImpl.Sgemv(tA, m, n, alpha, a, lda, x, incX, beta, y, incY)
   100  }
   101  
   102  // Sgbmv ...
   103  func (*testBLASImplementation) Sgbmv(tA blas.Transpose, m int, n int, kL int, kU int, alpha float32, a []float32, lda int, x []float32, incX int, beta float32, y []float32, incY int) {
   104  	gonumImpl.Sgbmv(tA, m, n, kL, kU, alpha, a, lda, x, incX, beta, y, incY)
   105  }
   106  
   107  // Strmv ...
   108  func (*testBLASImplementation) Strmv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, a []float32, lda int, x []float32, incX int) {
   109  	gonumImpl.Strmv(ul, tA, d, n, a, lda, x, incX)
   110  }
   111  
   112  // Stbmv ...
   113  func (*testBLASImplementation) Stbmv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, k int, a []float32, lda int, x []float32, incX int) {
   114  	gonumImpl.Stbmv(ul, tA, d, n, k, a, lda, x, incX)
   115  }
   116  
   117  // Stpmv ...
   118  func (*testBLASImplementation) Stpmv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, ap []float32, x []float32, incX int) {
   119  	gonumImpl.Stpmv(ul, tA, d, n, ap, x, incX)
   120  }
   121  
   122  // Strsv ...
   123  func (*testBLASImplementation) Strsv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, a []float32, lda int, x []float32, incX int) {
   124  	gonumImpl.Strsv(ul, tA, d, n, a, lda, x, incX)
   125  }
   126  
   127  // Stbsv ...
   128  func (*testBLASImplementation) Stbsv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, k int, a []float32, lda int, x []float32, incX int) {
   129  	gonumImpl.Stbsv(ul, tA, d, n, k, a, lda, x, incX)
   130  }
   131  
   132  // Stpsv ...
   133  func (*testBLASImplementation) Stpsv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, ap []float32, x []float32, incX int) {
   134  	gonumImpl.Stpsv(ul, tA, d, n, ap, x, incX)
   135  
   136  }
   137  
   138  // Ssymv ...
   139  func (*testBLASImplementation) Ssymv(ul blas.Uplo, n int, alpha float32, a []float32, lda int, x []float32, incX int, beta float32, y []float32, incY int) {
   140  	gonumImpl.Ssymv(ul, n, alpha, a, lda, x, incX, beta, y, incY)
   141  
   142  }
   143  
   144  // Ssbmv ...
   145  func (*testBLASImplementation) Ssbmv(ul blas.Uplo, n int, k int, alpha float32, a []float32, lda int, x []float32, incX int, beta float32, y []float32, incY int) {
   146  	gonumImpl.Ssbmv(ul, n, k, alpha, a, lda, x, incX, beta, y, incY)
   147  
   148  }
   149  
   150  // Sspmv ...
   151  func (*testBLASImplementation) Sspmv(ul blas.Uplo, n int, alpha float32, ap []float32, x []float32, incX int, beta float32, y []float32, incY int) {
   152  	gonumImpl.Sspmv(ul, n, alpha, ap, x, incX, beta, y, incY)
   153  
   154  }
   155  
   156  // Sger ...
   157  func (*testBLASImplementation) Sger(m int, n int, alpha float32, x []float32, incX int, y []float32, incY int, a []float32, lda int) {
   158  	gonumImpl.Sger(m, n, alpha, x, incX, y, incY, a, lda)
   159  
   160  }
   161  
   162  // Ssyr ...
   163  func (*testBLASImplementation) Ssyr(ul blas.Uplo, n int, alpha float32, x []float32, incX int, a []float32, lda int) {
   164  	gonumImpl.Ssyr(ul, n, alpha, x, incX, a, lda)
   165  
   166  }
   167  
   168  // Sspr ...
   169  func (*testBLASImplementation) Sspr(ul blas.Uplo, n int, alpha float32, x []float32, incX int, ap []float32) {
   170  	gonumImpl.Sspr(ul, n, alpha, x, incX, ap)
   171  
   172  }
   173  
   174  // Ssyr2 ...
   175  func (*testBLASImplementation) Ssyr2(ul blas.Uplo, n int, alpha float32, x []float32, incX int, y []float32, incY int, a []float32, lda int) {
   176  	gonumImpl.Ssyr2(ul, n, alpha, x, incX, y, incY, a, lda)
   177  
   178  }
   179  
   180  // Sspr2 ...
   181  func (*testBLASImplementation) Sspr2(ul blas.Uplo, n int, alpha float32, x []float32, incX int, y []float32, incY int, a []float32) {
   182  	gonumImpl.Sspr2(ul, n, alpha, x, incX, y, incY, a)
   183  
   184  }
   185  
   186  // Ssymm ...
   187  func (*testBLASImplementation) Ssymm(s blas.Side, ul blas.Uplo, m int, n int, alpha float32, a []float32, lda int, b []float32, ldb int, beta float32, c []float32, ldc int) {
   188  	gonumImpl.Ssymm(s, ul, m, n, alpha, a, lda, b, ldb, beta, c, ldc)
   189  
   190  }
   191  
   192  // Ssyrk ...
   193  func (*testBLASImplementation) Ssyrk(ul blas.Uplo, t blas.Transpose, n int, k int, alpha float32, a []float32, lda int, beta float32, c []float32, ldc int) {
   194  	gonumImpl.Ssyrk(ul, t, n, k, alpha, a, lda, beta, c, ldc)
   195  
   196  }
   197  
   198  // Ssyr2k ...
   199  func (*testBLASImplementation) Ssyr2k(ul blas.Uplo, t blas.Transpose, n int, k int, alpha float32, a []float32, lda int, b []float32, ldb int, beta float32, c []float32, ldc int) {
   200  	gonumImpl.Ssyr2k(ul, t, n, k, alpha, a, lda, b, ldb, beta, c, ldc)
   201  
   202  }
   203  
   204  // Strmm ...
   205  func (*testBLASImplementation) Strmm(s blas.Side, ul blas.Uplo, tA blas.Transpose, d blas.Diag, m int, n int, alpha float32, a []float32, lda int, b []float32, ldb int) {
   206  	gonumImpl.Strmm(s, ul, tA, d, m, n, alpha, a, lda, b, ldb)
   207  
   208  }
   209  
   210  // Strsm ...
   211  func (*testBLASImplementation) Strsm(s blas.Side, ul blas.Uplo, tA blas.Transpose, d blas.Diag, m int, n int, alpha float32, a []float32, lda int, b []float32, ldb int) {
   212  	gonumImpl.Strsm(s, ul, tA, d, m, n, alpha, a, lda, b, ldb)
   213  
   214  }
   215  
   216  // Ddot ...
   217  func (*testBLASImplementation) Ddot(n int, x []float64, incX int, y []float64, incY int) float64 {
   218  	return gonumImpl.Ddot(n, x, incX, y, incY)
   219  
   220  }
   221  
   222  // Dnrm2 ...
   223  func (*testBLASImplementation) Dnrm2(n int, x []float64, incX int) float64 {
   224  	return gonumImpl.Dnrm2(n, x, incX)
   225  
   226  }
   227  
   228  // Dasum ...
   229  func (*testBLASImplementation) Dasum(n int, x []float64, incX int) float64 {
   230  	return gonumImpl.Dasum(n, x, incX)
   231  
   232  }
   233  
   234  // Idamax ...
   235  func (*testBLASImplementation) Idamax(n int, x []float64, incX int) int {
   236  	return gonumImpl.Idamax(n, x, incX)
   237  
   238  }
   239  
   240  // Dswap ...
   241  func (*testBLASImplementation) Dswap(n int, x []float64, incX int, y []float64, incY int) {
   242  	gonumImpl.Dswap(n, x, incX, y, incY)
   243  
   244  }
   245  
   246  // Dcopy ...
   247  func (*testBLASImplementation) Dcopy(n int, x []float64, incX int, y []float64, incY int) {
   248  	gonumImpl.Dcopy(n, x, incX, y, incY)
   249  
   250  }
   251  
   252  // Daxpy ...
   253  func (*testBLASImplementation) Daxpy(n int, alpha float64, x []float64, incX int, y []float64, incY int) {
   254  	gonumImpl.Daxpy(n, alpha, x, incX, y, incY)
   255  
   256  }
   257  
   258  // Drotg ...
   259  func (*testBLASImplementation) Drotg(a float64, b float64) (c float64, s float64, r float64, z float64) {
   260  	return gonumImpl.Drotg(a, b)
   261  
   262  }
   263  
   264  // Drotmg ...
   265  func (*testBLASImplementation) Drotmg(d1 float64, d2 float64, b1 float64, b2 float64) (p blas.DrotmParams, rd1 float64, rd2 float64, rb1 float64) {
   266  	return gonumImpl.Drotmg(d1, d2, b1, b2)
   267  
   268  }
   269  
   270  // Drot ...
   271  func (*testBLASImplementation) Drot(n int, x []float64, incX int, y []float64, incY int, c float64, s float64) {
   272  	gonumImpl.Drot(n, x, incX, y, incY, c, s)
   273  
   274  }
   275  
   276  // Drotm ...
   277  func (*testBLASImplementation) Drotm(n int, x []float64, incX int, y []float64, incY int, p blas.DrotmParams) {
   278  	gonumImpl.Drotm(n, x, incX, y, incY, p)
   279  
   280  }
   281  
   282  // Dscal ...
   283  func (*testBLASImplementation) Dscal(n int, alpha float64, x []float64, incX int) {
   284  	gonumImpl.Dscal(n, alpha, x, incX)
   285  
   286  }
   287  
   288  // Dgemv ...
   289  func (*testBLASImplementation) Dgemv(tA blas.Transpose, m int, n int, alpha float64, a []float64, lda int, x []float64, incX int, beta float64, y []float64, incY int) {
   290  	gonumImpl.Dgemv(tA, m, n, alpha, a, lda, x, incX, beta, y, incY)
   291  
   292  }
   293  
   294  // Dgbmv ...
   295  func (*testBLASImplementation) Dgbmv(tA blas.Transpose, m int, n int, kL int, kU int, alpha float64, a []float64, lda int, x []float64, incX int, beta float64, y []float64, incY int) {
   296  	gonumImpl.Dgbmv(tA, m, n, kL, kU, alpha, a, lda, x, incX, beta, y, incY)
   297  
   298  }
   299  
   300  // Dtrmv ...
   301  func (*testBLASImplementation) Dtrmv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, a []float64, lda int, x []float64, incX int) {
   302  	gonumImpl.Dtrmv(ul, tA, d, n, a, lda, x, incX)
   303  
   304  }
   305  
   306  // Dtbmv ...
   307  func (*testBLASImplementation) Dtbmv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, k int, a []float64, lda int, x []float64, incX int) {
   308  	gonumImpl.Dtbmv(ul, tA, d, n, k, a, lda, x, incX)
   309  
   310  }
   311  
   312  // Dtpmv ...
   313  func (*testBLASImplementation) Dtpmv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, ap []float64, x []float64, incX int) {
   314  	gonumImpl.Dtpmv(ul, tA, d, n, ap, x, incX)
   315  
   316  }
   317  
   318  // Dtrsv ...
   319  func (*testBLASImplementation) Dtrsv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, a []float64, lda int, x []float64, incX int) {
   320  	gonumImpl.Dtrsv(ul, tA, d, n, a, lda, x, incX)
   321  
   322  }
   323  
   324  // Dtbsv ...
   325  func (*testBLASImplementation) Dtbsv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, k int, a []float64, lda int, x []float64, incX int) {
   326  	gonumImpl.Dtbsv(ul, tA, d, n, k, a, lda, x, incX)
   327  
   328  }
   329  
   330  // Dtpsv ...
   331  func (*testBLASImplementation) Dtpsv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, ap []float64, x []float64, incX int) {
   332  	gonumImpl.Dtpsv(ul, tA, d, n, ap, x, incX)
   333  
   334  }
   335  
   336  // Dsymv ...
   337  func (*testBLASImplementation) Dsymv(ul blas.Uplo, n int, alpha float64, a []float64, lda int, x []float64, incX int, beta float64, y []float64, incY int) {
   338  	gonumImpl.Dsymv(ul, n, alpha, a, lda, x, incX, beta, y, incY)
   339  
   340  }
   341  
   342  // Dsbmv ...
   343  func (*testBLASImplementation) Dsbmv(ul blas.Uplo, n int, k int, alpha float64, a []float64, lda int, x []float64, incX int, beta float64, y []float64, incY int) {
   344  	gonumImpl.Dsbmv(ul, n, k, alpha, a, lda, x, incX, beta, y, incY)
   345  
   346  }
   347  
   348  // Dspmv ...
   349  func (*testBLASImplementation) Dspmv(ul blas.Uplo, n int, alpha float64, ap []float64, x []float64, incX int, beta float64, y []float64, incY int) {
   350  	gonumImpl.Dspmv(ul, n, alpha, ap, x, incX, beta, y, incY)
   351  
   352  }
   353  
   354  // Dger ...
   355  func (*testBLASImplementation) Dger(m int, n int, alpha float64, x []float64, incX int, y []float64, incY int, a []float64, lda int) {
   356  	gonumImpl.Dger(m, n, alpha, x, incX, y, incY, a, lda)
   357  
   358  }
   359  
   360  // Dsyr ...
   361  func (*testBLASImplementation) Dsyr(ul blas.Uplo, n int, alpha float64, x []float64, incX int, a []float64, lda int) {
   362  	gonumImpl.Dsyr(ul, n, alpha, x, incX, a, lda)
   363  
   364  }
   365  
   366  // Dspr ...
   367  func (*testBLASImplementation) Dspr(ul blas.Uplo, n int, alpha float64, x []float64, incX int, ap []float64) {
   368  	gonumImpl.Dspr(ul, n, alpha, x, incX, ap)
   369  
   370  }
   371  
   372  // Dsyr2 ...
   373  func (*testBLASImplementation) Dsyr2(ul blas.Uplo, n int, alpha float64, x []float64, incX int, y []float64, incY int, a []float64, lda int) {
   374  	gonumImpl.Dsyr2(ul, n, alpha, x, incX, y, incY, a, lda)
   375  
   376  }
   377  
   378  // Dspr2 ...
   379  func (*testBLASImplementation) Dspr2(ul blas.Uplo, n int, alpha float64, x []float64, incX int, y []float64, incY int, a []float64) {
   380  	gonumImpl.Dspr2(ul, n, alpha, x, incX, y, incY, a)
   381  
   382  }
   383  
   384  // Dsymm ...
   385  func (*testBLASImplementation) Dsymm(s blas.Side, ul blas.Uplo, m int, n int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int) {
   386  	gonumImpl.Dsymm(s, ul, m, n, alpha, a, lda, b, ldb, beta, c, ldc)
   387  
   388  }
   389  
   390  // Dsyrk ...
   391  func (*testBLASImplementation) Dsyrk(ul blas.Uplo, t blas.Transpose, n int, k int, alpha float64, a []float64, lda int, beta float64, c []float64, ldc int) {
   392  	gonumImpl.Dsyrk(ul, t, n, k, alpha, a, lda, beta, c, ldc)
   393  
   394  }
   395  
   396  // Dsyr2k ...
   397  func (*testBLASImplementation) Dsyr2k(ul blas.Uplo, t blas.Transpose, n int, k int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int) {
   398  	gonumImpl.Dsyr2k(ul, t, n, k, alpha, a, lda, b, ldb, beta, c, ldc)
   399  
   400  }
   401  
   402  // Dtrmm ...
   403  func (*testBLASImplementation) Dtrmm(s blas.Side, ul blas.Uplo, tA blas.Transpose, d blas.Diag, m int, n int, alpha float64, a []float64, lda int, b []float64, ldb int) {
   404  	gonumImpl.Dtrmm(s, ul, tA, d, m, n, alpha, a, lda, b, ldb)
   405  
   406  }
   407  
   408  // Dtrsm ...
   409  func (*testBLASImplementation) Dtrsm(s blas.Side, ul blas.Uplo, tA blas.Transpose, d blas.Diag, m int, n int, alpha float64, a []float64, lda int, b []float64, ldb int) {
   410  	gonumImpl.Dtrsm(s, ul, tA, d, m, n, alpha, a, lda, b, ldb)
   411  
   412  }
   413  
   414  // Sgemm ...
   415  func (t *testBLASImplementation) Sgemm(tA blas.Transpose, tB blas.Transpose, m int, n int, k int, alpha float32, a []float32, lda int, b []float32, ldb int, beta float32, c []float32, ldc int) {
   416  	t.used = true
   417  	gonumImpl.Sgemm(tA, tB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)
   418  
   419  }
   420  
   421  // Dgemm ...
   422  func (*testBLASImplementation) Dgemm(tA blas.Transpose, tB blas.Transpose, m int, n int, k int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int) {
   423  	gonumImpl.Dgemm(tA, tB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)
   424  
   425  }
   426  func TestUse(t *testing.T) {
   427  	blasI := &testBLASImplementation{}
   428  	Use(blasI)
   429  	g := NewGraph()
   430  	x := NodeFromAny(g, tensor.New(
   431  		tensor.WithShape(1, 1, 7, 5),
   432  		tensor.WithBacking([]float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34})))
   433  	filter := NodeFromAny(g, tensor.New(
   434  		tensor.WithShape(1, 1, 3, 3),
   435  		tensor.WithBacking([]float32{1, 1, 1, 1, 1, 1, 1, 1, 1})))
   436  	y := Must(Conv2d(x, filter, []int{3, 3}, []int{0, 0}, []int{2, 2}, []int{1, 1}))
   437  	m := NewTapeMachine(g)
   438  	if err := m.RunAll(); err != nil {
   439  		t.Fatal(err)
   440  	}
   441  	//54 72 144 162 234 252
   442  	output := y.Value().Data().([]float32)
   443  	if output[0] != 54 ||
   444  		output[1] != 72 ||
   445  		output[2] != 144 ||
   446  		output[3] != 162 ||
   447  		output[4] != 234 ||
   448  		output[5] != 252 {
   449  		t.Fatal("wrong computation value")
   450  	}
   451  
   452  	if !blasI.used {
   453  		t.Fail()
   454  	}
   455  	if WhichBLAS() != blasI {
   456  		t.Fail()
   457  	}
   458  }