gorgonia.org/gorgonia@v0.9.17/cuda/engine.go (about)

     1  package cuda
     2  
     3  import "C"
     4  
     5  import (
     6  	"fmt"
     7  	"sync"
     8  	"unsafe"
     9  
    10  	"github.com/pkg/errors"
    11  	"gorgonia.org/cu"
    12  	"gorgonia.org/cu/blas"
    13  	"gorgonia.org/cu/dnn"
    14  	"gorgonia.org/tensor"
    15  )
    16  
    17  var (
    18  	_ tensor.Adder = &Engine{}
    19  	_ tensor.Suber = &Engine{}
    20  	_ tensor.Muler = &Engine{}
    21  	_ tensor.Diver = &Engine{}
    22  	_ tensor.Power = &Engine{}
    23  	_ tensor.Moder = &Engine{}
    24  	// _ tensor.FMAer       = &Engine{}
    25  	_ tensor.MatMuler    = &Engine{}
    26  	_ tensor.MatVecMuler = &Engine{}
    27  	_ tensor.OuterProder = &Engine{}
    28  	// _ tensor.Dotter      = &Engine{}
    29  	// _ tensor.SVDer       = &Engine{}
    30  	_ tensor.Lter   = &Engine{}
    31  	_ tensor.Lteer  = &Engine{}
    32  	_ tensor.Gter   = &Engine{}
    33  	_ tensor.Gteer  = &Engine{}
    34  	_ tensor.ElEqer = &Engine{}
    35  )
    36  
    37  // Engine is a CUDA engine
    38  type Engine struct {
    39  	tensor.Engine
    40  	sync.Mutex
    41  
    42  	a bfc
    43  	b cublas.Standard
    44  	c cu.BatchedContext
    45  	d cu.Device
    46  	f map[string]cu.Function
    47  	m map[string]cu.Module
    48  	n cudnn.Context
    49  
    50  	warp int
    51  	mtpb int
    52  	mgdx int
    53  	mgdy int
    54  	mgdz int
    55  	mbdx int
    56  	mbdy int
    57  	mbdz int
    58  
    59  	freeMem  int64
    60  	totalMem int64
    61  
    62  	syncChan      chan struct{}
    63  	finishChan    chan struct{}
    64  	finishChan2   chan struct{}
    65  	workAvailable chan bool
    66  	err           error
    67  	initialized   bool
    68  	running       bool
    69  }
    70  
    71  // AllocAccessible returns true because the engine return Go-accessible memory pointers
    72  func (e *Engine) AllocAccessible() bool { return true }
    73  
    74  // Alloc allocates a chunk of certain size from engine memory
    75  func (e *Engine) Alloc(size int64) (tensor.Memory, error) {
    76  	// return e.c.MemAllocManaged(size, cu.AttachGlobal)
    77  	return e.Get(size)
    78  }
    79  
    80  // AllocFlags returns allocation flags
    81  func (e *Engine) AllocFlags() (tensor.MemoryFlag, tensor.DataOrder) {
    82  	return tensor.MakeMemoryFlag(tensor.ManuallyManaged), tensor.ColMajor
    83  }
    84  
    85  // Free rees memory
    86  func (e *Engine) Free(mem tensor.Memory, size int64) error {
    87  	// e.c.MemFree(mem.(cu.DevicePtr))
    88  	// return e.c.Error()
    89  	e.Put(mem, size)
    90  	return nil
    91  }
    92  
    93  func (e *Engine) Memset(mem tensor.Memory, val interface{}) error {
    94  	panic("not implemented")
    95  }
    96  
    97  func (e *Engine) Memclr(mem tensor.Memory) {
    98  	panic("not implemented")
    99  }
   100  
   101  // Memcpy is part of the implementation of tensor.Engine. It is eager, and will signal the context to actually do work.
   102  // The memory that will be copied is up to the smallest of sizes between dst and src.
   103  // i.e. if dst is 8 bytes and src is 16 bytes, only the first 8 bytes of src will be copied.
   104  // Likewise, if dst is 20 bytes and src is 3 bytes, only 3 bytes will be copied.
   105  func (e *Engine) Memcpy(dst tensor.Memory, src tensor.Memory) error {
   106  	sSize := src.MemSize()
   107  	dSize := dst.MemSize()
   108  
   109  	var size int64
   110  	switch {
   111  	case dSize < sSize:
   112  		size = int64(dSize)
   113  	case sSize < dSize:
   114  		size = int64(sSize)
   115  	default:
   116  		size = int64(dSize)
   117  	}
   118  	d := cu.DevicePtr(dst.Uintptr())
   119  	s := cu.DevicePtr(src.Uintptr())
   120  	e.c.Memcpy(d, s, size)
   121  	e.Signal()
   122  	<-e.syncChan
   123  	return e.c.Error()
   124  }
   125  
   126  func (e *Engine) memcpy(dst cu.DevicePtr, src cu.DevicePtr, size int64) {
   127  	e.c.Memcpy(dst, src, size)
   128  }
   129  
   130  func (e *Engine) Accessible(mem tensor.Memory) (tensor.Memory, error) {
   131  	panic("not implemented")
   132  }
   133  
   134  // WorksWith returns true because the data order can be directly worked with
   135  func (e *Engine) WorksWith(order tensor.DataOrder) bool { return true }
   136  
   137  // NonStdAlloc nothing instead of running the default built in allocator
   138  func (e *Engine) NonStdAlloc() {}
   139  
   140  // Errors returns an error message
   141  func (e *Engine) Errors() error { return e.c.Errors() }
   142  
   143  // NaNChecker checks that the tensor contains a NaN
   144  func (e *Engine) HasNaN(a tensor.Tensor) (bool, error) {
   145  	dt := a.Dtype()
   146  	name := fmt.Sprintf("misc.hasNaN_f%v", int(dt.Size()*8))
   147  
   148  	if !e.HasFunc(name) {
   149  		return false, errors.Errorf("Unable to perform HasNaN(). The tensor engine does not have the function %q", name)
   150  	}
   151  
   152  	mem := cu.DevicePtr(a.Uintptr())
   153  	size := int64(logicalSize(a.Shape()))
   154  	fn := e.f[name]
   155  
   156  	var retVal C.int
   157  	gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ := e.ElemGridSize(int(size))
   158  	args := []unsafe.Pointer{
   159  		unsafe.Pointer(&mem),
   160  		unsafe.Pointer(&size),
   161  		unsafe.Pointer(&retVal),
   162  	}
   163  	e.c.LaunchAndSync(fn, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, 0, cu.NoStream, args)
   164  	e.DoWork()
   165  	return int(retVal) > 0, e.c.Error()
   166  }
   167  
   168  // InfChecker checks that the tensor contains a Inf
   169  func (e *Engine) HasInf(a tensor.Tensor) (bool, error) {
   170  	dt := a.Dtype()
   171  	name := fmt.Sprintf("misc.hasInf_f%v", int(dt.Size()*8))
   172  
   173  	if !e.HasFunc(name) {
   174  		return false, errors.Errorf("Unable to perform HasInf(). The tensor engine does not have the function %q", name)
   175  	}
   176  
   177  	mem := cu.DevicePtr(a.Uintptr())
   178  	size := int64(logicalSize(a.Shape()))
   179  	fn := e.f[name]
   180  
   181  	var retVal C.int
   182  	gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ := e.ElemGridSize(int(size))
   183  	args := []unsafe.Pointer{
   184  		unsafe.Pointer(&mem),
   185  		unsafe.Pointer(&size),
   186  		unsafe.Pointer(&retVal),
   187  	}
   188  	e.c.LaunchAndSync(fn, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, 0, cu.NoStream, args)
   189  	e.DoWork()
   190  	return int(retVal) > 0, e.c.Error()
   191  }