gorgonia.org/gorgonia@v0.9.17/vm_tape_cuda.go

// +build cuda

package gorgonia

import (
	"github.com/pkg/errors"
	"gorgonia.org/cu"
	"gorgonia.org/tensor"
)

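// finalizeTapeMachine tears down a tape machine's CUDA state: it runs the
// machine's cleanup and then calls initFail to destroy the device contexts.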
func finalizeTapeMachine(m *tapeMachine) {
	cudaLogf("Finalizing tape machine %p", m)
	m.cleanup()
	m.initFail() // not really a failure; just called to destroy all the CUDA contexts and associated resources
}

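// init scans the compiled program for CUDA-capable ops. Only if at least one
// is found does it initialize the device metadata/contexts and load the CUDA
// standard library onto every engine.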
func (m *tapeMachine) init() {
	var initCUDA bool
	cudaLogf("instructions %v", len(m.p.instructions))
	for _, instr := range m.p.instructions {
		if eo, ok := instr.(*execOp); ok {
			if _, ok := eo.op.(CUDADoer); ok {
				initCUDA = true
				break
			}
		}
	}

	// don't bother initializing contexts if no instructions were CUDA-based
	if !initCUDA {
		cudaLogf("No CUDA ops")
		return
	}

	if err := m.ExternMetadata.init(m.p.gpumem); err != nil {
		m.ExternMetadata.initFail()
		panic(err)
	}
	m.loadStdLib()
}

// loadStdLib loads the CUDA standard library modules (cudaStdLib) into every engine.
func (m *tapeMachine) loadStdLib() {
	if cudaStdLib == nil {
		return
	}

	for _, lib := range cudaStdLib {
		for i := range m.engines {
			e := &m.engines[i]
			if err := e.LoadCUDAFunc(lib.name, lib.data, lib.funcs); err != nil {
				panic(err)
			}
		}
	}
}

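// getEngine returns the tensor.Engine to use for the given device: the
// machine's default engine for the CPU, otherwise the CUDA engine for that
// device.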
func (m *tapeMachine) getEngine(dev Device) tensor.Engine {
	if dev == CPU {
		return m.Engine
	}
	return &m.Engines()[int(dev)]
}

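// exec executes a single op instruction: it gathers the input values from the
// read registers, dispatches to the CUDA, preallocated, unsafe or plain Do
// path, writes the result to the destination register and binds it to the
// node, and, when dual values are being tracked, accumulates the result into
// the dualValues of the nodes it is a derivative of.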
func (instr *execOp) exec(m *tapeMachine) (err error) {
	m.logf("Executing %v. Node is: %x", instr, instr.id)
	m.enterLogScope()
	defer m.leaveLogScope()

	enterLogScope()
	defer leaveLogScope()

	m.watchedLogf("Inputs:")
	m.enterLogScope()
	var inputs []Value
	for _, reg := range instr.readFrom {
		v := m.getValue(reg)
		inputs = append(inputs, v)
		m.watchedLogf(m.valueFmt, v.Uintptr())
	}
	m.leaveLogScope()

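	// Dispatch on the op's capabilities: CUDA-capable ops run on the
	// destination device, CLDoer ops are a no-op here, and everything else
	// falls back to the CPU execution paths below.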
	toDev := instr.writeTo.device
	var v Value
	switch op := instr.op.(type) {
	case CUDADoer:
		prealloc := m.getValue(instr.writeTo)
		if v, err = op.CUDADo(m, toDev, prealloc, inputs...); err != nil {
			return errors.Wrapf(err, "Happened while attempting to use CUDA to execute %v. Node is %x. Register was %v", instr, instr.id, instr.writeTo.id)
		}
		e := &m.Engines()[int(toDev)]
		setEngine(v, e)
	case CLDoer:
	default:
		switch {
		case instr.preAllocated:
			if pd, ok := instr.op.(UsePreallocDoer); ok {
				p := m.cpumem[instr.writeTo.id]
				if v, err = pd.UsePreallocDo(p, inputs...); err != nil {
					return errors.Wrapf(err, "Happened while attempting to execute %v. Node is %x. Register was: %v ", instr, instr.id, instr.writeTo.id)
				}
			} else {
				// TODO: maybe warn?
				if v, err = instr.op.Do(inputs...); err != nil {
					return errors.Wrap(err, opDoFail)
				}
			}
		case instr.useUnsafe:
			if ud, ok := instr.op.(UnsafeDoer); ok {
				if v, err = ud.UnsafeDo(inputs...); err != nil {
					return errors.Wrap(err, "Failed to carry out UnsafeDo()")
				}
			} else {
				// TODO: warn?
				if v, err = instr.op.Do(inputs...); err != nil {
					return errors.Wrap(err, opDoFail)
				}
			}
		default:
			if v, err = instr.op.Do(inputs...); err != nil {
				return errors.Wrap(err, opDoFail)
			}
		}
		setEngine(v, m.Engine)
	}
	m.watchedLogf("Result E:")
	m.enterLogScope()
	if vt, ok := v.(tensor.Tensor); ok {
		m.watchedLogf("%x | %T", v.Uintptr(), vt.Engine())
	} else {
		m.watchedLogf("%x", v.Uintptr())
	}
	m.leaveLogScope()
	// TODO: type and shape checks

	// Write
	m.writeValue(instr.writeTo, v)
	node := m.p.g.Node(instr.id).(*Node)

	if m.trace() && (len(m.watchNodes) == 0 || m.watchNodes.Contains(node)) {
		m.Signal()
		if err = node.bindCopy(v); err != nil {
			return errors.Wrapf(err, "TraceExec failed to bind copy")
		}
	} else {
		node.bind(v)
	}

	// if this node is a gradient (derivative) of other nodes, we should also bind the value to those nodes' dualValues
	if m.bindDV() && node.derivOf != nil {
		for _, src := range node.derivOf {
			if len(m.bindNodesDV) > 0 && !m.bindNodesDV.Contains(src) {
				continue
			}

			if src.boundTo != nil {
				dv := dvUnit(src.boundTo)
				cudaLogf("dv.d 0x%x v 0x%x | writeTo: %v", dv.d.Uintptr(), v.Uintptr(), instr.writeTo)
				dev := instr.writeTo.device
				add := newEBOByType(addOpType, TypeOf(dv.d), TypeOf(v))
				switch dev {
				case CPU:
					if d, err := add.UnsafeDo(dv.d, v); err == nil {
						dv.SetDeriv(d)
						src.bind(dv)
					} else {
						return err
					}
				default:
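					// Accumulate on the device: stage dv.d in temporary device
					// memory, add v to it with the CUDA add op, then copy the
					// result back into dv.d and release the temporary memory.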
					// temporarily allocate a value on the device
					ctx := m.Contexts()[int(dev)]

					dt := dv.d.Dtype()
					shp := dv.d.Shape()
					memsize := calcMemSize(dt, shp)

					var mem tensor.Memory
					if mem, err = m.Get(dev, memsize); err != nil {
						return errors.Wrapf(err, "Unable to allocate %v bytes from %v", memsize, dev)
					}

					var d Value
					if d, err = makeValueFromMem(dt, shp, mem); err != nil {
						return
					}

					// copy dv.d to d
					ctx.MemcpyHtoD(mem.(cu.DevicePtr), dv.d.Pointer(), memsize)

					// perform the op
					if _, err = add.CUDADo(m, dev, d, d, v); err != nil {
						return
					}
					// copy the value back into dv.d
					ctx.MemcpyDtoH(dv.d.Pointer(), mem.(cu.DevicePtr), memsize)
					m.Put(dev, mem, memsize) // then free it

					src.bind(dv)
					// the CPU method is correct. This method is correct for MOST cases, but will not be correct under some other circumstances
					// ctx.MemcpyDtoH(dv.d.Pointer(), cu.DevicePtr(v.Uintptr()), instr.size)
				}
			}
		}
	}

	m.watchedLogf("Written To: %v", instr.writeTo)
	m.enterLogScope()
	m.watchedLogf(m.valueFmt, v.Uintptr())
	m.leaveLogScope()

	return nil
}

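// exec copies a value between host and device memory according to the
// transport instruction, using the batched CUDA context of the device
// involved.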
func (instr deviceTransport) exec(m *tapeMachine) (err error) {
	m.logf("Executing %v", instr)
	from := m.getValue(instr.from)
	to := m.getValue(instr.to)

	var ctx *cu.BatchedContext
	switch {
	case instr.from.device == CPU && instr.to.device != CPU:
		memsize := int64(from.MemSize())
		ctx = m.Contexts()[int(instr.to.device)]
		ctx.MemcpyHtoD(cu.DevicePtr(to.Uintptr()), from.Pointer(), memsize)
	case instr.from.device != CPU && instr.to.device == CPU:
		dt := from.Dtype()
		memsize := calcMemSize(dt, from.Shape())
		ctx = m.Contexts()[int(instr.from.device)]
		ctx.MemcpyDtoH(to.Pointer(), cu.DevicePtr(from.Uintptr()), memsize)

		// when copying from device to host, it's assumed that the host will want to use the value immediately,
		// so signal the machine to do the queued work
		m.Signal()
	}

	return nil
}