gorgonia.org/gorgonia@v0.9.17/cuda.go

// +build cuda

package gorgonia

// for non-cuda builds, look at noextern.go

import (
	"log"
	"sync"

	"github.com/pkg/errors"
	"gorgonia.org/cu"
	cudnn "gorgonia.org/cu/dnn"
	"gorgonia.org/gorgonia/cuda"
	"gorgonia.org/tensor"
)

// CUDA tells the package that CUDA is used
const CUDA = true

var (
	_ External    = &ExternMetadata{}
	_ CUDAMachine = &tapeMachine{}
	_ CUDAMachine = &lispMachine{}
)

const (
	// Any address of a variable residing in global memory or returned by one of the
	// memory allocation routines from the driver or runtime API is always aligned to at
	// least 256 bytes.
	memalign    = 32
	scalarAlign = 8
)

//go:generate cudagen -same-module

var cudaStdLib []cudaLib

type cudaLib struct {
	name  string
	data  string
	funcs []string
}

// CUDAMachine is the interface implemented by CUDA-capable VMs.
type CUDAMachine interface {
	External
	Engines() []cuda.Engine
	Contexts() []*cu.BatchedContext
	CUDNNContexts() []*cudnn.Context

	ElemGridSize(n, dev int) (gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ int)
}
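
// A hedged usage sketch: code that only holds an External (for example, inside a
// custom op) can type-assert it to CUDAMachine to reach the per-device engines
// and contexts. The function and variable names below are illustrative only.
//
//	func useCUDA(extern External, dev Device) error {
//		cm, ok := extern.(CUDAMachine)
//		if !ok {
//			return errors.New("expected a CUDA-capable VM")
//		}
//		ctx := cm.Contexts()[int(dev)] // *cu.BatchedContext for that device
//		_ = ctx                        // launch kernels, copy memory, etc.
//		return nil
//	}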

// ExternMetadata holds the metadata for the CUDA backend.
// The slices within are indexed by device ID.
type ExternMetadata struct {
	tensor.Engine
	sync.Mutex

	// operational stuff
	u cu.Device   // device currently in use
	b batchedBLAS // UNUSED

	engines       []cuda.Engine
	workAvailable chan bool
	syncChan      chan struct{}
	initialized   bool
}

// ElemGridSize calculates the grid and block sizes for elementwise operations.
func (m *ExternMetadata) ElemGridSize(n, dev int) (gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ int) {
	if dev >= len(m.engines) {
		// no engine for this device; the indexing below will panic with an out-of-range error
	}
	return m.engines[dev].ElemGridSize(n)
}
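
// A hedged sketch of how the returned dimensions are meant to be used: they
// describe a launch configuration whose total thread count covers n elements.
// The launch call below is illustrative pseudocode, not a specific
// gorgonia.org/cu method.
//
//	gx, gy, gz, bx, by, bz := m.ElemGridSize(n, dev)
//	// launch(kernel, gx, gy, gz, bx, by, bz, args...) such that
//	// gx*gy*gz*bx*by*bz >= n, with each thread handling one element.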

// WorkAvailable returns a channel of bool, which is used to signal to the VM when there is work available. The VM will then call the DoWork method.
func (m *ExternMetadata) WorkAvailable() <-chan bool { return m.workAvailable }

// Sync returns the channel used to synchronize with a synchronous Signal.
func (m *ExternMetadata) Sync() chan struct{} { return m.syncChan }

// DoWork flushes any batched cgo calls. In this build it flushes any batched CUDA calls.
func (m *ExternMetadata) DoWork() error {
	for _, e := range m.engines {
		if err := e.DoWork(); err != nil {
			return err
		}
	}
	return nil
}
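
// A hedged sketch of the signalling protocol between this metadata and a VM.
// The run loop below is illustrative only, not the actual VM code: a false value
// comes from collectWork (asynchronous), a true value from Signal (synchronous),
// in which case the caller is blocked on the sync channel until released.
//
//	go func() {
//		for synchronous := range m.WorkAvailable() {
//			if err := m.DoWork(); err != nil {
//				// handle the error
//			}
//			if synchronous {
//				m.Sync() <- struct{}{} // release the caller blocked in Signal()
//			}
//		}
//	}()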

// Engines returns the slice of CUDA engines, one per device.
func (m *ExternMetadata) Engines() []cuda.Engine { return m.engines }

// Contexts returns a slice of contexts that are being used by this CUDAMachine
func (m *ExternMetadata) Contexts() []*cu.BatchedContext {
	retVal := make([]*cu.BatchedContext, 0, len(m.engines))
	for _, e := range m.engines {
		retVal = append(retVal, e.Context())
	}
	return retVal
}

// CUDNNContexts returns the cuDNN contexts, one per device.
func (m *ExternMetadata) CUDNNContexts() []*cudnn.Context {
	retVal := make([]*cudnn.Context, 0, len(m.engines))
	for _, e := range m.engines {
		retVal = append(retVal, e.CUDNNContext())
	}
	return retVal
}

// Get gets a previously allocated memory slab of the provided size. If no memory of that size exists,
// it returns a NoOpError. The caller is then responsible for allocating the memory themselves.
func (m *ExternMetadata) Get(dev Device, size int64) (tensor.Memory, error) {
	d := int(dev)
	if d >= len(m.engines) {
		return nil, noopError{} // this should not be a noopError
	}
	return m.engines[dev].Get(size)
}

// GetFromValue allocates memory on the GPU and then copies the data over. v MUST be on the CPU.
func (m *ExternMetadata) GetFromValue(dev Device, v Value) (tensor.Memory, error) {
	d := int(dev)
	if d >= len(m.engines) {
		return nil, noopError{}
	}
	memsize := calcMemSize(v.Dtype(), v.Shape())

	mem, err := m.engines[dev].Get(memsize)
	if err != nil {
		return nil, err
	}
	ptr := cu.DevicePtr(mem.Uintptr())
	ctx := m.engines[dev].Context()
	ctx.MemcpyHtoD(ptr, v.Pointer(), memsize)
	return ptr, nil
}

// Put puts a previously allocated memory slab of the provided size back into the pool.
func (m *ExternMetadata) Put(dev Device, mem tensor.Memory, size int64) {
	d := int(dev)
	if d >= len(m.engines) {
		return // unknown device: there is no pool to put the memory back into
	}

	m.engines[dev].Put(mem, size)
}

// PutValue puts a previously allocated memory slab back into the pool.
func (m *ExternMetadata) PutValue(dev Device, v Value) {
	d := int(dev)
	if d >= len(m.engines) {
		return
	}
	memsize := calcMemSize(v.Dtype(), v.Shape())
	m.engines[dev].Put(v, memsize)
}
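
// A hedged sketch of the intended pool round trip for scratch device memory;
// dev and size are assumed to refer to an initialized engine:
//
//	mem, err := m.Get(dev, size)
//	if err != nil {
//		// a NoOpError means the pool has nothing of this size; allocate it yourself
//	}
//	// ... use mem as device memory ...
//	m.Put(dev, mem, size) // return the slab to the pool for reuse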

// Transfer transfers data from device to device.
func (m *ExternMetadata) Transfer(toDev, fromDev Device, v Value, synchronous bool) (retVal Value, err error) {
	defer func() {
		if synchronous {
			m.Signal()
		}
	}()

	memsize := calcMemSize(v.Dtype(), v.Shape())
	switch {
	case fromDev == CPU && toDev != CPU:
		d := int(toDev)
		if d >= len(m.engines) {
			return nil, errors.Errorf("No context for device %v", toDev)
		}

		ctx := m.engines[d].Context()
		var mem tensor.Memory
		if mem, err = m.Get(toDev, memsize); err != nil {
			return
		}
		ctx.MemcpyHtoD(cu.DevicePtr(mem.Uintptr()), v.Pointer(), memsize)
		return makeValueFromMem(TypeOf(v), v.Shape(), mem)

	case fromDev != CPU && toDev == CPU:
		d := int(fromDev)
		if d >= len(m.engines) {
			return nil, errors.Errorf("No context for device %v", fromDev)
		}

		ctx := m.engines[d].Context()
		if retVal, err = makeValue(TypeOf(v), v.Shape()); err != nil {
			return
		}
		ctx.MemcpyDtoH(retVal.Pointer(), cu.DevicePtr(v.Uintptr()), memsize)
		return
	case fromDev == toDev:
		return v, nil
	case fromDev != toDev && fromDev != CPU && toDev != CPU:
		// device-to-device transfers are not yet supported; fall through to the panic below
	}
	panic("Device-to-device transfers are not yet supported")
}
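
// A hedged example of a synchronous host-to-device round trip, assuming device 0
// has been initialized and v lives on the CPU:
//
//	onGPU, err := m.Transfer(Device(0), CPU, v, true) // copy to device 0, then Signal()
//	if err != nil {
//		// handle the error
//	}
//	backOnCPU, err := m.Transfer(CPU, Device(0), onGPU, true)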

// Signal sends a signal down the workAvailable channel, telling the VM to call the DoWork method. Signal is a synchronous method.
func (m *ExternMetadata) Signal() {
	if m.workAvailable != nil {
		m.signal()
		<-m.syncChan
	}
}

// Reset frees all allocated memory and resets the allocators of the engines.
func (m *ExternMetadata) Reset() {
	for i := range m.engines {
		m.engines[i].ResetAllocator()
	}
}

func (m *ExternMetadata) init(sizes []int64) (err error) {
	m.Lock()
	initialized := m.initialized
	m.Unlock()
	if initialized {
		return nil
	}
	devices, err := cu.NumDevices()
	if err != nil {
		return errors.Wrapf(err, "Failed to get number of devices")
	}

	if devices == 0 {
		return errors.New("No Devices Found")
	}

	cudaLogf("Creating Engines")
	m.Lock()
	defer m.Unlock()
	m.engines = make([]cuda.Engine, len(sizes))
	for i := range m.engines {
		e := &m.engines[i]
		dev, err := cu.GetDevice(i)
		if err != nil {
			return errors.Wrapf(err, "Failed to get device %d", i)
		}

		if err = e.Init(dev, sizes[i]); err != nil {
			return err
		}
		ctx := e.Context()
		go m.collectWork(i, ctx.WorkAvailable())
	}

	m.initialized = true
	cudaLogf("CUDA initialized. Engines: %v", m.engines)
	return nil
}

func (m *ExternMetadata) initFail() {
	cudaLogf("Cleanup")
	m.engines = nil

	if m.workAvailable != nil {
		close(m.workAvailable)
	}
	m.workAvailable = nil
}

// cleanup cleans up the ancillary allocations made during the calling of batched CUDA functions.
func (m *ExternMetadata) cleanup() {
	for _, e := range m.engines {
		e.Close()
	}
}

// collectWork is a muxer for all the channels for the different devices
func (m *ExternMetadata) collectWork(devID int, workAvailable <-chan struct{}) {
	for range workAvailable {
		m.workAvailable <- false // false: asynchronous work; no caller is waiting on syncChan
	}
}

// collectBLASWork is a muxer for CBLAS/CuBLAS (if any) and the devices. It is unused in this build.
func (m *ExternMetadata) collectBLASWork() {}

func (m *ExternMetadata) signal() { m.workAvailable <- true } // true: a synchronous Signal() is waiting on syncChan

func (m *ExternMetadata) setEngine(e tensor.Engine) {}

// AddToStdLib allows custom ops to be included in the "stdlib" of CUDA functions, so that when the VMs are created,
// the functions are loaded automatically without any extra loading steps.
func AddToStdLib(name, data string, funcs []string) {
	cudaStdLib = append(cudaStdLib, cudaLib{
		name:  name,
		data:  data,
		funcs: funcs,
	})
}
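
// A hedged usage sketch; "mykernels", mykernelsPTX and the function names below
// are placeholders for a module generated alongside a custom op:
//
//	AddToStdLib("mykernels", mykernelsPTX, []string{"myop_f32", "myop_f64"})
//
// The registered functions become available to the VMs' engines once the
// devices are initialized.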

func init() {
	log.Println("Using CUDA build")
}

// ValueOnDevice gets the value of the node as a Value on the desired device. If the node's value is not on the same device
// as the desired device, a copy will be made.
func (n *Node) ValueOnDevice(toDev Device, extern External) (retVal Value, allocOnExtern bool, err error) {
	if n.dataOn == toDev {
		return n.Value(), false, nil
	}
	v := n.Value()
	fromDev := n.Device()

	var synchronous bool
	if toDev == CPU {
		synchronous = true
	}
	if toDev != fromDev && toDev != CPU {
		allocOnExtern = true
	}
	retVal, err = extern.Transfer(toDev, fromDev, v, synchronous)
	return
}
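
// A hedged sketch of pulling a node's data back to the host after a run, where
// vm is a CUDA-capable VM (and therefore implements External) and n is a node
// whose data lives on a GPU:
//
//	val, _, err := n.ValueOnDevice(CPU, vm)
//	if err != nil {
//		// handle the error
//	}
//	_ = val // val is now backed by host memory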

// GradOnDevice gets the gradient value of the node as a Value on the desired device. If the node's value is not on the same device
// as the desired device, a copy will be made.
func (n *Node) GradOnDevice(toDev Device, extern External) (retVal Value, allocOnExtern bool, err error) {
	if n.dataOn == toDev {
		retVal, err = n.Grad()
		return
	}

	var d Value
	if dv, ok := n.boundTo.(*dualValue); ok {
		d = dv.d
	} else if n.deriv != nil {
		return n.deriv.ValueOnDevice(toDev, extern)
	} else {
		return nil, false, errors.Errorf("No gradient node/value found for %v", n)
	}
	if d == nil {
		return nil, false, errors.Errorf("No gradient node/value found for %v", n)
	}

	fromDev := n.Device()

	var synchronous bool
	if toDev == CPU {
		synchronous = true
	}
	if toDev != CPU && toDev != fromDev {
		allocOnExtern = true
	}
	retVal, err = extern.Transfer(toDev, fromDev, d, synchronous)
	return
}