gorgonia.org/gorgonia@v0.9.17/op_math_cuda.go

// +build cuda

package gorgonia

import (
	"fmt"
	"unsafe"

	"github.com/pkg/errors"
	"gorgonia.org/cu"
	"gorgonia.org/gorgonia/cuda"
	"gorgonia.org/tensor"
)

// module names
const (
	elemBinOpMod   = "elembinop"
	elemUnaryOpMod = "elemunaryop"
)
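
// Kernels are looked up in the modules above by a name built from the module,
// the op, and the input dtype: "<module>.<op>[_ss]_f<bits>" (see the
// fmt.Sprintf calls below). A float64 scalar-scalar add, for example, would
// resolve to something like "elembinop.add_ss_f64".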

// CallsExtern returns true: the op is executed by an external (CUDA) engine.
func (op elemUnaryOp) CallsExtern() bool { return true }

// CUDADo executes the unary op as a CUDA kernel on the given device. If the
// kernel for this op/dtype has not been loaded, the op falls back to the CPU
// implementation. The result is written into prealloc when one is provided;
// otherwise the input is overwritten in place.
func (op elemUnaryOp) CUDADo(extern External, dev Device, prealloc Value, inputs ...Value) (retVal Value, err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}

	cudaLogf("CUDADoing %v | prealloc %x | %x", op, prealloc.Uintptr(), inputs[0].Uintptr())
	enterLogScope()
	defer leaveLogScope()

	// extract the input and its dtype
	cudaLogf("checking if input is scalar")
	a := inputs[0]
	dt := a.Dtype()

	// build the kernel name: <module>.<op>_f<bits>
	name := fmt.Sprintf("%v.%v_f%d", elemUnaryOpMod, op.unaryOpType(), int(dt.Size())*8)

	machine := extern.(CUDAMachine)
	eng := machine.Engines()[int(dev)]
	if !eng.HasFunc(name) {
		cudaLogf("extern does not have func %q", name)
		extern.Signal()

		// no kernel for this op/dtype: fall back to the CPU implementation
		if retVal, err = op.do(a); err != nil {
			return
		}
		if prealloc == nil {
			return
		}
		return Copy(prealloc, retVal)
	}
	fn := eng.Functions()[name]
	ctx := machine.Contexts()[int(dev)]

	retVal = prealloc
	if prealloc == nil {
		prealloc = a
		retVal = a
	}

	// when the op is done in place (same memory, same shape) the kernel works
	// directly on the input; otherwise the input is copied into prealloc first
	var mem cu.DevicePtr
	if prealloc.Uintptr() == a.Uintptr() && a.Shape().Eq(prealloc.Shape()) {
		mem = cu.DevicePtr(a.Uintptr())
	} else {
		mem = cu.DevicePtr(prealloc.Uintptr())
		memSize := int64(a.MemSize())
		memA := cu.DevicePtr(a.Uintptr())
		ctx.Memcpy(mem, memA, memSize)
	}
	size := logicalSize(a.Shape())

	// work out the launch configuration for an elementwise kernel over `size` elements
	// blocks, threads := machine.(*tapeMachine).blockThread(int(size), int(dev))
	gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ := machine.ElemGridSize(int(size), int(dev))
	args := []unsafe.Pointer{
		unsafe.Pointer(&mem),
		unsafe.Pointer(&size),
	}
	cudaLogf("gx %d, gy %d, gz %d | bx %d by %d, bz %d", gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ)
	cudaLogf("CUDADO %q, Mem: %v size %v, args %v", name, mem, size, args)
	cudaLogf("LaunchKernel Params. mem: %v. Size %v", mem, size)
	ctx.LaunchAndSync(fn, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, 0, cu.NoStream, args)
	return
}

// CallsExtern returns true: the op is executed by an external (CUDA) engine.
func (op elemBinOp) CallsExtern() bool { return true }

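// CUDADo executes the binary op on the given device. Scalar-scalar inputs are
// dispatched to a hand-written CUDA kernel via ssop; anything involving a
// tensor is handed off to the CUDA tensor engine by switching the operands'
// engines and calling the corresponding tensor function.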
func (op elemBinOp) CUDADo(extern External, dev Device, prealloc Value, inputs ...Value) (retVal Value, err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}
	cudaLogf("CUDADoing %v", op)
	enterLogScope()
	defer leaveLogScope()

	a := inputs[0]
	b := inputs[1]
	as := a.Shape()
	bs := b.Shape()

	m := extern.(CUDAMachine)
	e := &m.Engines()[int(dev)]

	// scalar-scalar pairs are handled by a dedicated kernel
	if as.IsScalar() && bs.IsScalar() {
		return op.ssop(a, b, prealloc, e)
	}

	// ensure the operands (and the preallocated result, if any) use the CUDA engine
	if aT, ok := a.(tensor.Tensor); ok {
		tensor.WithEngine(e)(aT)
	}
	if bT, ok := b.(tensor.Tensor); ok {
		tensor.WithEngine(e)(bT)
	}

	pT, toReuse := prealloc.(tensor.Tensor)
	if toReuse {
		tensor.WithEngine(e)(pT)
	}

	// arithmetic ops and comparison ops are dispatched through the tensor package
	boType := op.binOpType()
	if fn := binOps[boType]; fn != nil {
		if toReuse {
			return (*fn)(a, b, tensor.WithReuse(pT))
		}
		return (*fn)(a, b, tensor.UseUnsafe())
	}

	if fn := cmpOps[boType]; fn != nil {
		if toReuse {
			return (*fn)(a, b, tensor.WithReuse(pT))
		}
		return (*fn)(a, b, tensor.UseUnsafe())
	}

	return nil, errors.Errorf("op %v cannot be done by CUDA", op)
}

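// ssop launches the scalar-scalar kernel for the binary op. The kernel name is
// built as "<module>.<op>_ss_f<bits>". The result is written into prealloc when
// one is provided; otherwise a is overwritten in place.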
func (op elemBinOp) ssop(a, b, prealloc Value, e *cuda.Engine) (retVal Value, err error) {
	dt := a.Dtype()
	ctx := e.Context()
	opName := ʘBinOpNames[op.binOpType()]
	name := fmt.Sprintf("%v.%v_ss_f%d", elemBinOpMod, opName, int(dt.Size())*8)
	var mem, memB cu.DevicePtr
	var size int64
	if prealloc == nil {
		// in place: the kernel writes over a
		mem = cu.DevicePtr(a.Uintptr())
		retVal = a
		size = int64(logicalSize(a.Shape()))
	} else {
		// copy a into prealloc; the kernel then combines it with b there
		mem = cu.DevicePtr(prealloc.Uintptr())
		memA := cu.DevicePtr(a.Uintptr())
		memSize := int64(a.MemSize())
		ctx.Memcpy(mem, memA, memSize)

		size = int64(logicalSize(prealloc.Shape()))
		retVal = prealloc
	}
	memB = cu.DevicePtr(b.Uintptr())
	fn := e.Functions()[name]

	var args []unsafe.Pointer
	cudaLogf("%v mem %v, memB %v", op, mem, memB)
	gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ := e.ElemGridSize(int(size))
	args = []unsafe.Pointer{
		unsafe.Pointer(&mem),
		unsafe.Pointer(&memB),
		unsafe.Pointer(&size),
	}

	cudaLogf("CUDADO %q, size %v", name, size)
	cudaLogf("LaunchKernel params. mem: %v memB: %v size: %v", mem, memB, size)
	cudaLogf("%d, %d, %d, %d, %d, %d", gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ)
	ctx.LaunchAndSync(fn, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, 0, cu.NoStream, args)
	return
}

/* LINEAR ALGEBRA STUFF */

func (op linAlgBinOp) CallsExtern() bool { return true }

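// CUDADo executes the linear algebra op on the given device by delegating to
// the tensor package, with the operands' engines switched to the CUDA engine.
// Transposition flags are applied (and undone) around the call, except for the
// batched matrix multiplication, which handles them itself.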
func (op linAlgBinOp) CUDADo(extern External, dev Device, prealloc Value, inputs ...Value) (retVal Value, err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}

	m := extern.(CUDAMachine)
	e := &m.Engines()[int(dev)]

	a := inputs[0]
	b := inputs[1]

	aT, ok := a.(tensor.Tensor)
	if !ok {
		return nil, errors.Errorf("Expected a to be a Tensor. Got %T instead", a)
	}
	bT, ok := b.(tensor.Tensor)
	if !ok {
		return nil, errors.Errorf("Expected b to be a Tensor. Got %T instead", b)
	}

	pT, ok := prealloc.(tensor.Tensor)
	if !ok {
		return nil, errors.Errorf("Expected prealloc to be a Tensor. Got %T instead", prealloc)
	}
	tensor.WithEngine(e)(aT)
	tensor.WithEngine(e)(bT)
	tensor.WithEngine(e)(pT)

	if op.transA && op.āBinaryOperator != batchedMatMulOperator {
		if err = aT.T(); err != nil {
			return nil, errors.Wrap(err, tFail)
		}
		// untranspose
		defer aT.T()
	}

	if op.transB && op.āBinaryOperator != batchedMatMulOperator {
		if err = bT.T(); err != nil {
			return nil, errors.Wrap(err, tFail)
		}
		// untranspose
		defer bT.T()
	}

	switch op.āBinaryOperator {
	case matMulOperator:
		return tensor.MatMul(aT, bT, tensor.WithReuse(pT))
	case matVecMulOperator:
		return tensor.MatVecMul(aT, bT, tensor.WithReuse(pT))
	case vecDotOperator:
		return nil, errors.New("NYI")
	case outerProdOperator:
		return tensor.Outer(aT, bT, tensor.WithReuse(pT))
	case batchedMatMulOperator:
		// checks were done when the op was created;
		// note that batchedMatMul allocates its own result rather than reusing prealloc
		return batchedMatMul(aT, bT, nil, op.transA, op.transB, false)
	}
	panic("Unreachable")
}

/* API stuff */
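
// The constructors below place the resulting op on the CPU only when both
// inputs live on the CPU; otherwise the op is placed on the first non-CPU
// device found among its inputs.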

// NewAddOp creates a new *ExternalOp that wraps an add op.
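//
// A minimal sketch of intended use (the node names and the surrounding graph
// are assumptions, not prescribed by this file):
//
//	op := NewAddOp(a, b, ctx)
//	sum, err := ApplyOp(op, a, b)
//
// Whether the op then runs on the CPU or on a CUDA device is determined by the
// Device field set below.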
func NewAddOp(a, b *Node, ctx ExecutionContext) *ExternalOp {
	add := newElemBinOp(addOpType, a, b)
	op := NewExternalOp(add, ctx, nil)
	if a.Device() == CPU && b.Device() == CPU {
		op.Device = CPU
		return op
	}

	if a.Device() != CPU {
		op.Device = a.Device()
		return op
	}

	if b.Device() != CPU {
		op.Device = b.Device()
		return op
	}

	return op
}

// NewSubOp creates a new *ExternalOp that wraps a sub op.
func NewSubOp(a, b *Node, ctx ExecutionContext) *ExternalOp {
	sub := newEBOByType(subOpType, a.t, b.t)
	op := NewExternalOp(sub, ctx, nil)

	if a.Device() == CPU && b.Device() == CPU {
		op.Device = CPU
		return op
	}

	if a.Device() != CPU {
		op.Device = a.Device()
		return op
	}

	if b.Device() != CPU {
		op.Device = b.Device()
		return op
	}
	return op
}

// NewHadamardProdOp creates a new *ExternalOp that wraps a mul (elementwise product) op.
func NewHadamardProdOp(a, b *Node, ctx ExecutionContext) *ExternalOp {
	mul := newEBOByType(mulOpType, a.t, b.t)
	op := NewExternalOp(mul, ctx, nil)

	if a.Device() == CPU && b.Device() == CPU {
		op.Device = CPU
		return op
	}

	if a.Device() != CPU {
		op.Device = a.Device()
		return op
	}

	if b.Device() != CPU {
		op.Device = b.Device()
		return op
	}
	return op
}