gorgonia.org/gorgonia@v0.9.17/blas_test.go (about) 1 package gorgonia 2 3 import ( 4 "testing" 5 6 "gonum.org/v1/gonum/blas" 7 "gonum.org/v1/gonum/blas/gonum" 8 "gorgonia.org/tensor" 9 ) 10 11 var gonumImpl = gonum.Implementation{} 12 13 // testBLASImplementation of the interface 14 type testBLASImplementation struct { 15 gonum.Implementation 16 used bool 17 } 18 19 // Sdsdot computes the dot product of the two vectors plus a constant 20 // alpha + ∑_i x[i]*y[i] 21 // 22 // Float32 implementations are autogenerated and not directly tested. 23 // Sdsdot ... 24 func (*testBLASImplementation) Sdsdot(n int, alpha float32, x []float32, incX int, y []float32, incY int) float32 { 25 return gonumImpl.Sdsdot(n, alpha, x, incX, y, incY) 26 } 27 28 // Dsdot computes the dot product of the two vectors 29 // ∑_i x[i]*y[i] 30 // 31 // Float32 implementations are autogenerated and not directly tested. 32 // Dsdot ... 33 func (*testBLASImplementation) Dsdot(n int, x []float32, incX int, y []float32, incY int) float64 { 34 return gonumImpl.Dsdot(n, x, incX, y, incY) 35 } 36 37 // Sdot ... 38 func (*testBLASImplementation) Sdot(n int, x []float32, incX int, y []float32, incY int) float32 { 39 return gonumImpl.Sdot(n, x, incX, y, incY) 40 } 41 42 // Snrm2 ... 43 func (*testBLASImplementation) Snrm2(n int, x []float32, incX int) float32 { 44 return gonumImpl.Snrm2(n, x, incX) 45 } 46 47 // Sasum ... 48 func (*testBLASImplementation) Sasum(n int, x []float32, incX int) float32 { 49 return gonumImpl.Sasum(n, x, incX) 50 } 51 52 // Isamax ... 53 func (*testBLASImplementation) Isamax(n int, x []float32, incX int) int { 54 return gonumImpl.Isamax(n, x, incX) 55 } 56 57 // Sswap ... 58 func (*testBLASImplementation) Sswap(n int, x []float32, incX int, y []float32, incY int) { 59 gonumImpl.Sswap(n, x, incX, y, incY) 60 } 61 62 // Scopy ... 63 func (*testBLASImplementation) Scopy(n int, x []float32, incX int, y []float32, incY int) { 64 gonumImpl.Scopy(n, x, incX, y, incY) 65 } 66 67 // Saxpy ... 68 func (*testBLASImplementation) Saxpy(n int, alpha float32, x []float32, incX int, y []float32, incY int) { 69 gonumImpl.Saxpy(n, alpha, x, incX, y, incY) 70 } 71 72 // Srotg ... 73 func (*testBLASImplementation) Srotg(a float32, b float32) (c float32, s float32, r float32, z float32) { 74 return gonumImpl.Srotg(a, b) 75 } 76 77 // Srotmg ... 78 func (*testBLASImplementation) Srotmg(d1 float32, d2 float32, b1 float32, b2 float32) (p blas.SrotmParams, rd1 float32, rd2 float32, rb1 float32) { 79 return gonumImpl.Srotmg(d1, d2, b1, b2) 80 } 81 82 // Srot ... 83 func (*testBLASImplementation) Srot(n int, x []float32, incX int, y []float32, incY int, c float32, s float32) { 84 gonumImpl.Srot(n, x, incX, y, incY, c, s) 85 } 86 87 // Srotm ... 88 func (*testBLASImplementation) Srotm(n int, x []float32, incX int, y []float32, incY int, p blas.SrotmParams) { 89 gonumImpl.Srotm(n, x, incX, y, incY, p) 90 } 91 92 // Sscal ... 93 func (*testBLASImplementation) Sscal(n int, alpha float32, x []float32, incX int) { 94 gonumImpl.Sscal(n, alpha, x, incX) 95 } 96 97 // Sgemv ... 98 func (*testBLASImplementation) Sgemv(tA blas.Transpose, m int, n int, alpha float32, a []float32, lda int, x []float32, incX int, beta float32, y []float32, incY int) { 99 gonumImpl.Sgemv(tA, m, n, alpha, a, lda, x, incX, beta, y, incY) 100 } 101 102 // Sgbmv ... 103 func (*testBLASImplementation) Sgbmv(tA blas.Transpose, m int, n int, kL int, kU int, alpha float32, a []float32, lda int, x []float32, incX int, beta float32, y []float32, incY int) { 104 gonumImpl.Sgbmv(tA, m, n, kL, kU, alpha, a, lda, x, incX, beta, y, incY) 105 } 106 107 // Strmv ... 108 func (*testBLASImplementation) Strmv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, a []float32, lda int, x []float32, incX int) { 109 gonumImpl.Strmv(ul, tA, d, n, a, lda, x, incX) 110 } 111 112 // Stbmv ... 113 func (*testBLASImplementation) Stbmv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, k int, a []float32, lda int, x []float32, incX int) { 114 gonumImpl.Stbmv(ul, tA, d, n, k, a, lda, x, incX) 115 } 116 117 // Stpmv ... 118 func (*testBLASImplementation) Stpmv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, ap []float32, x []float32, incX int) { 119 gonumImpl.Stpmv(ul, tA, d, n, ap, x, incX) 120 } 121 122 // Strsv ... 123 func (*testBLASImplementation) Strsv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, a []float32, lda int, x []float32, incX int) { 124 gonumImpl.Strsv(ul, tA, d, n, a, lda, x, incX) 125 } 126 127 // Stbsv ... 128 func (*testBLASImplementation) Stbsv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, k int, a []float32, lda int, x []float32, incX int) { 129 gonumImpl.Stbsv(ul, tA, d, n, k, a, lda, x, incX) 130 } 131 132 // Stpsv ... 133 func (*testBLASImplementation) Stpsv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, ap []float32, x []float32, incX int) { 134 gonumImpl.Stpsv(ul, tA, d, n, ap, x, incX) 135 136 } 137 138 // Ssymv ... 139 func (*testBLASImplementation) Ssymv(ul blas.Uplo, n int, alpha float32, a []float32, lda int, x []float32, incX int, beta float32, y []float32, incY int) { 140 gonumImpl.Ssymv(ul, n, alpha, a, lda, x, incX, beta, y, incY) 141 142 } 143 144 // Ssbmv ... 145 func (*testBLASImplementation) Ssbmv(ul blas.Uplo, n int, k int, alpha float32, a []float32, lda int, x []float32, incX int, beta float32, y []float32, incY int) { 146 gonumImpl.Ssbmv(ul, n, k, alpha, a, lda, x, incX, beta, y, incY) 147 148 } 149 150 // Sspmv ... 151 func (*testBLASImplementation) Sspmv(ul blas.Uplo, n int, alpha float32, ap []float32, x []float32, incX int, beta float32, y []float32, incY int) { 152 gonumImpl.Sspmv(ul, n, alpha, ap, x, incX, beta, y, incY) 153 154 } 155 156 // Sger ... 157 func (*testBLASImplementation) Sger(m int, n int, alpha float32, x []float32, incX int, y []float32, incY int, a []float32, lda int) { 158 gonumImpl.Sger(m, n, alpha, x, incX, y, incY, a, lda) 159 160 } 161 162 // Ssyr ... 163 func (*testBLASImplementation) Ssyr(ul blas.Uplo, n int, alpha float32, x []float32, incX int, a []float32, lda int) { 164 gonumImpl.Ssyr(ul, n, alpha, x, incX, a, lda) 165 166 } 167 168 // Sspr ... 169 func (*testBLASImplementation) Sspr(ul blas.Uplo, n int, alpha float32, x []float32, incX int, ap []float32) { 170 gonumImpl.Sspr(ul, n, alpha, x, incX, ap) 171 172 } 173 174 // Ssyr2 ... 175 func (*testBLASImplementation) Ssyr2(ul blas.Uplo, n int, alpha float32, x []float32, incX int, y []float32, incY int, a []float32, lda int) { 176 gonumImpl.Ssyr2(ul, n, alpha, x, incX, y, incY, a, lda) 177 178 } 179 180 // Sspr2 ... 181 func (*testBLASImplementation) Sspr2(ul blas.Uplo, n int, alpha float32, x []float32, incX int, y []float32, incY int, a []float32) { 182 gonumImpl.Sspr2(ul, n, alpha, x, incX, y, incY, a) 183 184 } 185 186 // Ssymm ... 187 func (*testBLASImplementation) Ssymm(s blas.Side, ul blas.Uplo, m int, n int, alpha float32, a []float32, lda int, b []float32, ldb int, beta float32, c []float32, ldc int) { 188 gonumImpl.Ssymm(s, ul, m, n, alpha, a, lda, b, ldb, beta, c, ldc) 189 190 } 191 192 // Ssyrk ... 193 func (*testBLASImplementation) Ssyrk(ul blas.Uplo, t blas.Transpose, n int, k int, alpha float32, a []float32, lda int, beta float32, c []float32, ldc int) { 194 gonumImpl.Ssyrk(ul, t, n, k, alpha, a, lda, beta, c, ldc) 195 196 } 197 198 // Ssyr2k ... 199 func (*testBLASImplementation) Ssyr2k(ul blas.Uplo, t blas.Transpose, n int, k int, alpha float32, a []float32, lda int, b []float32, ldb int, beta float32, c []float32, ldc int) { 200 gonumImpl.Ssyr2k(ul, t, n, k, alpha, a, lda, b, ldb, beta, c, ldc) 201 202 } 203 204 // Strmm ... 205 func (*testBLASImplementation) Strmm(s blas.Side, ul blas.Uplo, tA blas.Transpose, d blas.Diag, m int, n int, alpha float32, a []float32, lda int, b []float32, ldb int) { 206 gonumImpl.Strmm(s, ul, tA, d, m, n, alpha, a, lda, b, ldb) 207 208 } 209 210 // Strsm ... 211 func (*testBLASImplementation) Strsm(s blas.Side, ul blas.Uplo, tA blas.Transpose, d blas.Diag, m int, n int, alpha float32, a []float32, lda int, b []float32, ldb int) { 212 gonumImpl.Strsm(s, ul, tA, d, m, n, alpha, a, lda, b, ldb) 213 214 } 215 216 // Ddot ... 217 func (*testBLASImplementation) Ddot(n int, x []float64, incX int, y []float64, incY int) float64 { 218 return gonumImpl.Ddot(n, x, incX, y, incY) 219 220 } 221 222 // Dnrm2 ... 223 func (*testBLASImplementation) Dnrm2(n int, x []float64, incX int) float64 { 224 return gonumImpl.Dnrm2(n, x, incX) 225 226 } 227 228 // Dasum ... 229 func (*testBLASImplementation) Dasum(n int, x []float64, incX int) float64 { 230 return gonumImpl.Dasum(n, x, incX) 231 232 } 233 234 // Idamax ... 235 func (*testBLASImplementation) Idamax(n int, x []float64, incX int) int { 236 return gonumImpl.Idamax(n, x, incX) 237 238 } 239 240 // Dswap ... 241 func (*testBLASImplementation) Dswap(n int, x []float64, incX int, y []float64, incY int) { 242 gonumImpl.Dswap(n, x, incX, y, incY) 243 244 } 245 246 // Dcopy ... 247 func (*testBLASImplementation) Dcopy(n int, x []float64, incX int, y []float64, incY int) { 248 gonumImpl.Dcopy(n, x, incX, y, incY) 249 250 } 251 252 // Daxpy ... 253 func (*testBLASImplementation) Daxpy(n int, alpha float64, x []float64, incX int, y []float64, incY int) { 254 gonumImpl.Daxpy(n, alpha, x, incX, y, incY) 255 256 } 257 258 // Drotg ... 259 func (*testBLASImplementation) Drotg(a float64, b float64) (c float64, s float64, r float64, z float64) { 260 return gonumImpl.Drotg(a, b) 261 262 } 263 264 // Drotmg ... 265 func (*testBLASImplementation) Drotmg(d1 float64, d2 float64, b1 float64, b2 float64) (p blas.DrotmParams, rd1 float64, rd2 float64, rb1 float64) { 266 return gonumImpl.Drotmg(d1, d2, b1, b2) 267 268 } 269 270 // Drot ... 271 func (*testBLASImplementation) Drot(n int, x []float64, incX int, y []float64, incY int, c float64, s float64) { 272 gonumImpl.Drot(n, x, incX, y, incY, c, s) 273 274 } 275 276 // Drotm ... 277 func (*testBLASImplementation) Drotm(n int, x []float64, incX int, y []float64, incY int, p blas.DrotmParams) { 278 gonumImpl.Drotm(n, x, incX, y, incY, p) 279 280 } 281 282 // Dscal ... 283 func (*testBLASImplementation) Dscal(n int, alpha float64, x []float64, incX int) { 284 gonumImpl.Dscal(n, alpha, x, incX) 285 286 } 287 288 // Dgemv ... 289 func (*testBLASImplementation) Dgemv(tA blas.Transpose, m int, n int, alpha float64, a []float64, lda int, x []float64, incX int, beta float64, y []float64, incY int) { 290 gonumImpl.Dgemv(tA, m, n, alpha, a, lda, x, incX, beta, y, incY) 291 292 } 293 294 // Dgbmv ... 295 func (*testBLASImplementation) Dgbmv(tA blas.Transpose, m int, n int, kL int, kU int, alpha float64, a []float64, lda int, x []float64, incX int, beta float64, y []float64, incY int) { 296 gonumImpl.Dgbmv(tA, m, n, kL, kU, alpha, a, lda, x, incX, beta, y, incY) 297 298 } 299 300 // Dtrmv ... 301 func (*testBLASImplementation) Dtrmv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, a []float64, lda int, x []float64, incX int) { 302 gonumImpl.Dtrmv(ul, tA, d, n, a, lda, x, incX) 303 304 } 305 306 // Dtbmv ... 307 func (*testBLASImplementation) Dtbmv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, k int, a []float64, lda int, x []float64, incX int) { 308 gonumImpl.Dtbmv(ul, tA, d, n, k, a, lda, x, incX) 309 310 } 311 312 // Dtpmv ... 313 func (*testBLASImplementation) Dtpmv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, ap []float64, x []float64, incX int) { 314 gonumImpl.Dtpmv(ul, tA, d, n, ap, x, incX) 315 316 } 317 318 // Dtrsv ... 319 func (*testBLASImplementation) Dtrsv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, a []float64, lda int, x []float64, incX int) { 320 gonumImpl.Dtrsv(ul, tA, d, n, a, lda, x, incX) 321 322 } 323 324 // Dtbsv ... 325 func (*testBLASImplementation) Dtbsv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, k int, a []float64, lda int, x []float64, incX int) { 326 gonumImpl.Dtbsv(ul, tA, d, n, k, a, lda, x, incX) 327 328 } 329 330 // Dtpsv ... 331 func (*testBLASImplementation) Dtpsv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, ap []float64, x []float64, incX int) { 332 gonumImpl.Dtpsv(ul, tA, d, n, ap, x, incX) 333 334 } 335 336 // Dsymv ... 337 func (*testBLASImplementation) Dsymv(ul blas.Uplo, n int, alpha float64, a []float64, lda int, x []float64, incX int, beta float64, y []float64, incY int) { 338 gonumImpl.Dsymv(ul, n, alpha, a, lda, x, incX, beta, y, incY) 339 340 } 341 342 // Dsbmv ... 343 func (*testBLASImplementation) Dsbmv(ul blas.Uplo, n int, k int, alpha float64, a []float64, lda int, x []float64, incX int, beta float64, y []float64, incY int) { 344 gonumImpl.Dsbmv(ul, n, k, alpha, a, lda, x, incX, beta, y, incY) 345 346 } 347 348 // Dspmv ... 349 func (*testBLASImplementation) Dspmv(ul blas.Uplo, n int, alpha float64, ap []float64, x []float64, incX int, beta float64, y []float64, incY int) { 350 gonumImpl.Dspmv(ul, n, alpha, ap, x, incX, beta, y, incY) 351 352 } 353 354 // Dger ... 355 func (*testBLASImplementation) Dger(m int, n int, alpha float64, x []float64, incX int, y []float64, incY int, a []float64, lda int) { 356 gonumImpl.Dger(m, n, alpha, x, incX, y, incY, a, lda) 357 358 } 359 360 // Dsyr ... 361 func (*testBLASImplementation) Dsyr(ul blas.Uplo, n int, alpha float64, x []float64, incX int, a []float64, lda int) { 362 gonumImpl.Dsyr(ul, n, alpha, x, incX, a, lda) 363 364 } 365 366 // Dspr ... 367 func (*testBLASImplementation) Dspr(ul blas.Uplo, n int, alpha float64, x []float64, incX int, ap []float64) { 368 gonumImpl.Dspr(ul, n, alpha, x, incX, ap) 369 370 } 371 372 // Dsyr2 ... 373 func (*testBLASImplementation) Dsyr2(ul blas.Uplo, n int, alpha float64, x []float64, incX int, y []float64, incY int, a []float64, lda int) { 374 gonumImpl.Dsyr2(ul, n, alpha, x, incX, y, incY, a, lda) 375 376 } 377 378 // Dspr2 ... 379 func (*testBLASImplementation) Dspr2(ul blas.Uplo, n int, alpha float64, x []float64, incX int, y []float64, incY int, a []float64) { 380 gonumImpl.Dspr2(ul, n, alpha, x, incX, y, incY, a) 381 382 } 383 384 // Dsymm ... 385 func (*testBLASImplementation) Dsymm(s blas.Side, ul blas.Uplo, m int, n int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int) { 386 gonumImpl.Dsymm(s, ul, m, n, alpha, a, lda, b, ldb, beta, c, ldc) 387 388 } 389 390 // Dsyrk ... 391 func (*testBLASImplementation) Dsyrk(ul blas.Uplo, t blas.Transpose, n int, k int, alpha float64, a []float64, lda int, beta float64, c []float64, ldc int) { 392 gonumImpl.Dsyrk(ul, t, n, k, alpha, a, lda, beta, c, ldc) 393 394 } 395 396 // Dsyr2k ... 397 func (*testBLASImplementation) Dsyr2k(ul blas.Uplo, t blas.Transpose, n int, k int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int) { 398 gonumImpl.Dsyr2k(ul, t, n, k, alpha, a, lda, b, ldb, beta, c, ldc) 399 400 } 401 402 // Dtrmm ... 403 func (*testBLASImplementation) Dtrmm(s blas.Side, ul blas.Uplo, tA blas.Transpose, d blas.Diag, m int, n int, alpha float64, a []float64, lda int, b []float64, ldb int) { 404 gonumImpl.Dtrmm(s, ul, tA, d, m, n, alpha, a, lda, b, ldb) 405 406 } 407 408 // Dtrsm ... 409 func (*testBLASImplementation) Dtrsm(s blas.Side, ul blas.Uplo, tA blas.Transpose, d blas.Diag, m int, n int, alpha float64, a []float64, lda int, b []float64, ldb int) { 410 gonumImpl.Dtrsm(s, ul, tA, d, m, n, alpha, a, lda, b, ldb) 411 412 } 413 414 // Sgemm ... 415 func (t *testBLASImplementation) Sgemm(tA blas.Transpose, tB blas.Transpose, m int, n int, k int, alpha float32, a []float32, lda int, b []float32, ldb int, beta float32, c []float32, ldc int) { 416 t.used = true 417 gonumImpl.Sgemm(tA, tB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc) 418 419 } 420 421 // Dgemm ... 422 func (*testBLASImplementation) Dgemm(tA blas.Transpose, tB blas.Transpose, m int, n int, k int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int) { 423 gonumImpl.Dgemm(tA, tB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc) 424 425 } 426 func TestUse(t *testing.T) { 427 blasI := &testBLASImplementation{} 428 Use(blasI) 429 g := NewGraph() 430 x := NodeFromAny(g, tensor.New( 431 tensor.WithShape(1, 1, 7, 5), 432 tensor.WithBacking([]float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34}))) 433 filter := NodeFromAny(g, tensor.New( 434 tensor.WithShape(1, 1, 3, 3), 435 tensor.WithBacking([]float32{1, 1, 1, 1, 1, 1, 1, 1, 1}))) 436 y := Must(Conv2d(x, filter, []int{3, 3}, []int{0, 0}, []int{2, 2}, []int{1, 1})) 437 m := NewTapeMachine(g) 438 if err := m.RunAll(); err != nil { 439 t.Fatal(err) 440 } 441 //54 72 144 162 234 252 442 output := y.Value().Data().([]float32) 443 if output[0] != 54 || 444 output[1] != 72 || 445 output[2] != 144 || 446 output[3] != 162 || 447 output[4] != 234 || 448 output[5] != 252 { 449 t.Fatal("wrong computation value") 450 } 451 452 if !blasI.used { 453 t.Fail() 454 } 455 if WhichBLAS() != blasI { 456 t.Fail() 457 } 458 }