gorgonia.org/gorgonia@v0.9.17/blase/work.go (about)

     1  package blase
     2  
     3  /*
     4  #cgo CFLAGS: -g -O3 -std=gnu99
     5  
     6  #include <stdio.h>
     7  #include <stdint.h>
     8  #include "work.h"
     9  #include "cblas.h"
    10  
    11  uintptr_t process(struct fnargs* fa, int count) {
    12  	uintptr_t ret;
    13  
    14  	// printf("How much work: %d\n", count);
    15  
    16  	ret = processFn(&fa[0]);
    17  	if (count > 1) {
    18  		ret = processFn(&fa[1]);
    19  	}
    20  	if (count > 2) {
    21  		ret = processFn(&fa[2]);
    22  	}
    23  
    24  	return ret;
    25  }
    26  
    27  */
    28  import "C"
    29  import (
    30  	"unsafe"
    31  
    32  	"gonum.org/v1/gonum/blas"
    33  	"gonum.org/v1/netlib/blas/netlib"
    34  )
    35  
    36  var impl = newContext()
    37  
    38  var (
    39  	_ blas.Float32    = impl
    40  	_ blas.Float64    = impl
    41  	_ blas.Complex64  = impl
    42  	_ blas.Complex128 = impl
    43  )
    44  
    45  // Implementation returns a BLAS implementation that implements Float32, Float64, Complex64 and Complex128
    46  func Implementation() *context { return impl }
    47  
// workbufLen is the size newContext uses for the work channel buffer, the
// length of the fns slice and the capacity of the queue slice.
const workbufLen int = 3
    49  
// A Worker is a BLAS implementation that reports, through WorkAvailable,
// whether there is anything in its queue, and that offers DoWork as the way
// to flush that queue.
type Worker interface {
	WorkAvailable() <-chan struct{}
	DoWork()
}
    56  
// call is a single queued BLAS invocation: a pointer to its argument record
// plus a flag that makes enqueue invoke DoWork immediately.
type call struct {
	// args points at the argument record that DoWork reinterprets as a
	// C.struct_fnargs and copies into the context's fns buffer.
	args *fnargs

	/*
		The blocking flag only applies to BLAS functions that have a return value:
			cblas_sdsdot
			cblas_dsdot
			cblas_sdot
			cblas_ddot
			cblas_cdotu_sub
			cblas_snrm2
			cblas_sasum
			cblas_dnrm2
			cblas_dasum
			cblas_scnrm2
			cblas_scasum
			cblas_dznrm2
			cblas_dzasum

		These are routines that are recast as functions:
			cblas_cdotc_sub
			cblas_zdotu_sub
			cblas_zdotc_sub

		Not sure about these (they return CBLAS_INDEX):
			cblas_isamax
			cblas_idamax
			cblas_icamax
			cblas_izamax

		For the rest of the BLAS routines (i.e. those that return void), leave
		blocking unset.
	*/
	blocking bool
}
    91  
// context is the BLAS implementation handed out by this package. It embeds
// netlib.Implementation and adds the state used to batch calls before they
// are passed to C:
//
//   - workAvailable receives a token from enqueue (non-blocking send) each
//     time a call is queued; it is exposed through WorkAvailable.
//   - work is the buffered channel onto which enqueue places calls and from
//     which DoWork receives them.
//   - fns is the buffer of C argument records that DoWork fills and passes
//     to C.process.
//   - queue holds the calls DoWork has received from work.
type context struct {
	netlib.Implementation

	workAvailable chan struct{}
	work          chan call

	fns   []C.struct_fnargs
	queue []call
}
   101  
   102  func newContext() *context {
   103  	return &context{
   104  		workAvailable: make(chan struct{}, 1),
   105  		work:          make(chan call, workbufLen),
   106  
   107  		fns:   make([]C.struct_fnargs, workbufLen, workbufLen),
   108  		queue: make([]call, 0, workbufLen),
   109  	}
   110  }
   111  
   112  func (ctx *context) enqueue(c call) {
   113  	ctx.work <- c
   114  	select {
   115  	case ctx.workAvailable <- struct{}{}:
   116  	default:
   117  	}
   118  	if c.blocking {
   119  		// do something
   120  		ctx.DoWork()
   121  	}
   122  }
   123  
   124  // DoWork retrieves as many work items as possible, puts them into a queue, and then processes the queue.
   125  // The function may return without doing any work.
   126  func (ctx *context) DoWork() {
   127  	for {
   128  		select {
   129  		case w := <-ctx.work:
   130  			ctx.queue = append(ctx.queue, w)
   131  		default:
   132  			return
   133  		}
   134  
   135  		blocking := ctx.queue[len(ctx.queue)-1].blocking
   136  	enqueue:
   137  		for len(ctx.queue) < cap(ctx.queue) && !blocking {
   138  			select {
   139  			case w := <-ctx.work:
   140  				ctx.queue = append(ctx.queue, w)
   141  				blocking = ctx.queue[len(ctx.queue)-1].blocking
   142  			default:
   143  				break enqueue
   144  
   145  			}
   146  
   147  			for i, c := range ctx.queue {
   148  				ctx.fns[i] = *(*C.struct_fnargs)(unsafe.Pointer(c.args))
   149  			}
   150  			C.process(&ctx.fns[0], C.int(len(ctx.queue)))
   151  
   152  			// clear queue
   153  			ctx.queue = ctx.queue[:0]
   154  		}
   155  	}
   156  }
   157  
   158  // WorkAvailable is the channel which users should subscribe to to know if there is work incoming.
   159  func (ctx *context) WorkAvailable() <-chan struct{} { return ctx.workAvailable }
   160  
   161  // String implements runtime.Stringer and fmt.Stringer. It returns the name of the BLAS implementation.
   162  func (*context) String() string { return "Blase" }