gorgonia.org/gorgonia@v0.9.17/blase/blas.go (about)

     1  // Package blase is a thin wrapper over Gonum's BLAS interface that provides a queue
     2  // so that cgo calls are batched. This package was created so MKL usage can be improved.
     3  //
// Any cblas function that this package does not handle falls through to the regular, blocking BLAS call.
     5  package blase
     6  
/*
#include <stdint.h>
#include <stdio.h>
#include "cblas.h"
#include "work.h"

// useful to help print stuff to see if things are correct
//
// prrintfnargs is a debugging aid: it prints the scalar fields of a
// struct fnargs (declared in work.h) to stdout -- the function id, the
// order and transpose flags, the integer arguments i0-i5 and the double
// arguments d0-d3. It is not called by any live code in this file.
//
// The pointer arguments a0-a3 are not printed. The commented-out lines
// below cast them to a double pointer and hand that to a %f conversion,
// which expects a double and is therefore undefined behaviour; switch them
// to a pointer-appropriate conversion such as %p before re-enabling them.
void prrintfnargs(struct fnargs* args){
	printf("HELLO\n");
	printf("fn: %d\n", args->fn);
	printf("o: %d\n", args->order);
	printf("tA: %d\n",args->tA);
	printf("tB: %d\n",args->tB);
	printf("----\n");
	// printf("a0: %f\n", (double*)args->a0);
	// printf("a1: %f\n", (double*)args->a1);
	// printf("a2: %f\n", (double*)args->a2);
	// printf("a3: %f\n", (double*)args->a3);
	printf("----\n");
	printf("i0: %d\n", args->i0);
	printf("i1: %d\n", args->i1);
	printf("i2: %d\n", args->i2);
	printf("i3: %d\n", args->i3);
	printf("i4: %d\n", args->i4);
	printf("i5: %d\n", args->i5);
	printf("----\n");
	printf("d0: %f\n", args->d0);
	printf("d1: %f\n", args->d1);
	printf("d2: %f\n", args->d2);
	printf("d3: %f\n", args->d3);
	printf("=========\n");
}
*/
import "C"
    41  
    42  import (
    43  	"unsafe"
    44  
    45  	"gonum.org/v1/gonum/blas"
    46  )
    47  
    48  const rowMajor = 101 // rowMajor and rowMajor ONLY
    49  
    50  func (ctx *context) Dgemm(tA blas.Transpose, tB blas.Transpose, m int, n int, k int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int) {
    51  	fn := &fnargs{
    52  		fn:    C.cblasFn(fn_cblas_dgemm),
    53  		order: C.cblas_order(rowMajor),
    54  		tA:    C.cblas_transpose(tA),
    55  		tB:    C.cblas_transpose(tB),
    56  		i0:    C.int(m),
    57  		i1:    C.int(n),
    58  		i2:    C.int(k),
    59  		d0:    C.double(alpha),
    60  		a0:    uintptr(unsafe.Pointer(&a[0])),
    61  		i3:    C.int(lda),
    62  		a1:    uintptr(unsafe.Pointer(&b[0])),
    63  		i4:    C.int(ldb),
    64  		d1:    C.double(beta),
    65  		a2:    uintptr(unsafe.Pointer(&c[0])),
    66  		i5:    C.int(ldc),
    67  	}
    68  	call := call{args: fn, blocking: false}
    69  	ctx.enqueue(call)
    70  }
    71  
    72  func (ctx *context) Dgemv(tA blas.Transpose, m int, n int, alpha float64, a []float64, lda int, x []float64, incX int, beta float64, y []float64, incY int) {
    73  	fn := &fnargs{
    74  		fn:    C.cblasFn(fn_cblas_dgemv),
    75  		order: C.cblas_order(rowMajor),
    76  		tA:    C.cblas_transpose(tA),
    77  		i0:    C.int(m),
    78  		i1:    C.int(m),
    79  		d0:    C.double(alpha),
    80  		a0:    uintptr(unsafe.Pointer(&a[0])),
    81  		i2:    C.int(lda),
    82  		a1:    uintptr(unsafe.Pointer(&x[0])),
    83  		i3:    C.int(incX),
    84  		d1:    C.double(beta),
    85  		a2:    uintptr(unsafe.Pointer(&y[0])),
    86  		i4:    C.int(incY),
    87  	}
    88  
    89  	// Cs := fn.toCStruct()
    90  	// C.prrintfnargs(&Cs)
    91  	// fmt.Println("Sleeping")
    92  	// time.Sleep(10)
    93  	// fmt.Println("Slept")
    94  	// ctx.Implementation.Dgemv(tA, m, n, alpha, a, lda, x, incX, beta, y, incY)
    95  
    96  	call := call{args: fn, blocking: false}
    97  	ctx.enqueue(call)
    98  }