gorgonia.org/gorgonia@v0.9.17/cuda/linalg.go (about)

     1  package cuda
     2  
     3  import (
     4  	"github.com/pkg/errors"
     5  	"gonum.org/v1/gonum/blas"
     6  	"gorgonia.org/tensor"
     7  )
     8  
// Compile-time assertions that *Engine satisfies the tensor linear-algebra
// engine interfaces implemented in this file.
var (
	_ tensor.MatVecMuler = &Engine{}
	_ tensor.MatMuler    = &Engine{}
	_ tensor.OuterProder = &Engine{}
)
    14  
    15  // this file implements all the tensor linalg engine interfaces
    16  
    17  func (e *Engine) checkThreeFloat(a, b, ret tensor.Tensor) (ad, bd, retVal *tensor.Dense, err error) {
    18  	if /*a.IsNativelyAccessible() &&*/ !a.IsManuallyManaged() {
    19  		return nil, nil, nil, errors.New("CUDA Engine only takes non-natively accessible memory (memory on graphics cards). a isn't.")
    20  	}
    21  
    22  	if /* b.IsNativelyAccessible() && */ !b.IsManuallyManaged() {
    23  		return nil, nil, nil, errors.New("CUDA Engine only takes non-natively accessible memory (memory on graphics cards). b isn't")
    24  	}
    25  
    26  	if /* ret.IsNativelyAccessible() && */ !ret.IsManuallyManaged() {
    27  		return nil, nil, nil, errors.New("CUDA Engine only takes non-natively accessible memory (memory on graphics cards). ret isn't")
    28  	}
    29  
    30  	if a.Dtype() != b.Dtype() || b.Dtype() != ret.Dtype() {
    31  		return nil, nil, nil, errors.New("Expected a and b and retVal all to have the same Dtype")
    32  	}
    33  	var ok bool
    34  	if ad, ok = a.(*tensor.Dense); !ok {
    35  		return nil, nil, nil, errors.New("Expected a to be a *tensor.Dense")
    36  	}
    37  	if bd, ok = b.(*tensor.Dense); !ok {
    38  		return nil, nil, nil, errors.New("Expected b to be a *tensor.Dense")
    39  	}
    40  	if retVal, ok = ret.(*tensor.Dense); !ok {
    41  		return nil, nil, nil, errors.New("Expected ret to be a *tensor.Dense")
    42  	}
    43  	return
    44  }
    45  
    46  // MatVecMul performs matrix vector multiplication
// MatVecMul performs matrix vector multiplication, writing the result into
// prealloc: prealloc = a × b, where a is a matrix and b a vector, dispatched
// to cuBLAS GEMV via the engine's work queue. Only Float64 and Float32 are
// supported; a, b and prealloc must be device-resident *tensor.Dense values
// of the same Dtype (enforced by checkThreeFloat).
func (e *Engine) MatVecMul(a, b, prealloc tensor.Tensor) (err error) {
	var ad, bd, pd *tensor.Dense
	if ad, bd, pd, err = e.checkThreeFloat(a, b, prealloc); err != nil {
		return errors.Wrapf(err, "MatVecMul failed pre check")
	}

	// The transpose flag, leading dimension and (m, n) swap below translate
	// the tensor's data order into the column-major convention GEMV expects:
	// a row-major, untransposed matrix is presented to BLAS as the transpose
	// of its column-major view. NOTE(review): the exact case table mirrors
	// upstream gorgonia; confirm against the gonum/cuBLAS GEMV contract
	// before reordering.
	tA := blas.Trans
	do := a.DataOrder()
	z := do.IsTransposed()

	m := a.Shape()[0]
	n := a.Shape()[1]

	var lda int
	switch {
	case do.IsRowMajor() && z:
		tA = blas.NoTrans
		lda = m
	case do.IsRowMajor() && !z:
		lda = n
		m, n = n, m
	case do.IsColMajor() && z:
		tA = blas.Trans
		lda = n
		m, n = n, m
	case do.IsColMajor() && !z:
		lda = m
		tA = blas.NoTrans
	}

	// Drain any queued device work before enqueueing the GEMV.
	e.c.DoWork()
	incX, incY := 1, 1 // step size

	// ASPIRATIONAL TODO: different incX and incY
	// TECHNICAL DEBT. TECHDEBT. TECH DEBT
	// Example use case:
	// log.Printf("a %v %v", ad.Strides(), ad.ostrides())
	// log.Printf("b %v", b.Strides())
	// incX := a.Strides()[0]
	// incY = b.Strides()[0]

	switch ad.Dtype() {
	case tensor.Float64:
		A := ad.Float64s()
		X := bd.Float64s()
		Y := pd.Float64s()
		// alpha=1, beta=0: plain y = A·x with no accumulation into prealloc.
		alpha, beta := float64(1), float64(0)
		e.c.DoWork()
		e.c.Do(func() error { e.b.Dgemv(tA, m, n, alpha, A, lda, X, incX, beta, Y, incY); return e.b.Err() })
	case tensor.Float32:
		A := ad.Float32s()
		X := bd.Float32s()
		Y := pd.Float32s()
		alpha, beta := float32(1), float32(0)
		e.c.DoWork()
		e.c.Do(func() error { e.b.Sgemv(tA, m, n, alpha, A, lda, X, incX, beta, Y, incY); return e.b.Err() })
	default:
		return errors.New("Unsupported Dtype")
	}
	return e.b.Err()
}
   108  
   109  // MatMul performs matrix multiplication
   110  func (e *Engine) MatMul(a, b, prealloc tensor.Tensor) (err error) {
   111  	var ad, bd, pd *tensor.Dense
   112  	if ad, bd, pd, err = e.checkThreeFloat(a, b, prealloc); err != nil {
   113  		return errors.Wrapf(err, "MatVecMul failed pre check")
   114  	}
   115  
   116  	ado := a.DataOrder()
   117  	bdo := b.DataOrder()
   118  	if !ado.HasSameOrder(bdo) {
   119  		return errors.Errorf("a does not have the same data order as b. a is %v. b is %v", a.DataOrder(), b.DataOrder())
   120  	}
   121  
   122  	// get result shapes. k is the shared dimension
   123  	// a is (m, k)
   124  	// b is (k, n)
   125  	// c is (m, n)
   126  	var m, n, k int
   127  	m = ad.Shape()[0]
   128  	k = ad.Shape()[1]
   129  	n = bd.Shape()[1]
   130  
   131  	// // wrt the strides, we use the original strides, because that's what BLAS needs, instead of calling .Strides()
   132  	// // lda in colmajor = number of rows;
   133  	// // lda in row major = number of cols
   134  	var lda, ldb, ldc int
   135  	tA, tB := blas.Trans, blas.Trans
   136  	za := ado.IsTransposed()
   137  	zb := bdo.IsTransposed()
   138  
   139  	// swapping around the operands if they are row major (a becomes b, and b becomes a)
   140  	switch {
   141  	case ado.IsColMajor() && bdo.IsColMajor() && !za && !zb:
   142  		lda = m
   143  		ldb = k
   144  		ldc = prealloc.Shape()[0]
   145  		tA, tB = blas.NoTrans, blas.NoTrans
   146  	case ado.IsColMajor() && bdo.IsColMajor() && za && !zb:
   147  		lda = k
   148  		ldb = k
   149  		ldc = prealloc.Shape()[0]
   150  		tA, tB = blas.Trans, blas.NoTrans
   151  	case ado.IsColMajor() && bdo.IsColMajor() && za && zb:
   152  		lda = k
   153  		ldb = n
   154  		ldc = prealloc.Shape()[0]
   155  		tA, tB = blas.Trans, blas.Trans
   156  	case ado.IsColMajor() && bdo.IsColMajor() && !za && zb:
   157  		lda = m
   158  		ldb = n
   159  		ldc = prealloc.Shape()[0]
   160  		tA, tB = blas.NoTrans, blas.Trans
   161  	case ado.IsRowMajor() && bdo.IsRowMajor() && !za && !zb:
   162  		lda = k
   163  		ldb = n
   164  		ldc = prealloc.Shape()[1]
   165  		tA, tB = blas.NoTrans, blas.NoTrans
   166  
   167  		// magic swappy thingy
   168  		m, n = n, m
   169  		lda, ldb = ldb, lda
   170  		ad, bd = bd, ad
   171  	case ado.IsRowMajor() && bdo.IsRowMajor() && za && !zb:
   172  		lda = m
   173  		ldb = n
   174  		ldc = prealloc.Shape()[1]
   175  		tA, tB = blas.Trans, blas.NoTrans
   176  
   177  		// magic swappy thingy
   178  		m, n = n, m
   179  		lda, ldb = ldb, lda
   180  		tA, tB = tB, tA
   181  		ad, bd = bd, ad
   182  	case ado.IsRowMajor() && bdo.IsRowMajor() && za && zb:
   183  		lda = m
   184  		ldb = k
   185  		ldc = prealloc.Shape()[1]
   186  		tA, tB = blas.Trans, blas.Trans
   187  
   188  		// magic swappy thingy
   189  		m, n = n, m
   190  		lda, ldb = ldb, lda
   191  		ad, bd = bd, ad
   192  	case ado.IsRowMajor() && bdo.IsRowMajor() && !za && zb:
   193  		lda = k
   194  		ldb = k
   195  		ldc = prealloc.Shape()[1]
   196  		tA, tB = blas.NoTrans, blas.Trans
   197  
   198  		// magic swappy thingy
   199  		m, n = n, m
   200  		lda, ldb = ldb, lda
   201  		tA, tB = tB, tA
   202  		ad, bd = bd, ad
   203  
   204  	default:
   205  		panic("Unreachable")
   206  	}
   207  
   208  	e.c.DoWork()
   209  	switch ad.Dtype() {
   210  	case tensor.Float64:
   211  		A := ad.Float64s()
   212  		B := bd.Float64s()
   213  		C := pd.Float64s()
   214  		alpha, beta := float64(1), float64(0)
   215  
   216  		e.c.Do(func() error { e.b.Dgemm(tA, tB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); return nil })
   217  
   218  	case tensor.Float32:
   219  		A := ad.Float32s()
   220  		B := bd.Float32s()
   221  		C := pd.Float32s()
   222  		alpha, beta := float32(1), float32(0)
   223  		e.c.Do(func() error { e.b.Sgemm(tA, tB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); return nil })
   224  	default:
   225  		return errors.Errorf("Unsupported Dtype %v", ad.Dtype())
   226  	}
   227  
   228  	return e.b.Err()
   229  }
   230  
   231  // Outer performs outer product (kronecker) multiplication
   232  func (e *Engine) Outer(a, b, prealloc tensor.Tensor) (err error) {
   233  	var ad, bd, pd *tensor.Dense
   234  	if ad, bd, pd, err = e.checkThreeFloat(a, b, prealloc); err != nil {
   235  		return errors.Wrapf(err, "MatVecMul failed pre check")
   236  	}
   237  	m := ad.Size()
   238  	n := bd.Size()
   239  	pdo := pd.DataOrder()
   240  
   241  	var lda int
   242  	switch {
   243  	case pdo.IsColMajor():
   244  		lda = pd.Shape()[0]
   245  	case pdo.IsRowMajor():
   246  		aShape := a.Shape().Clone()
   247  		bShape := b.Shape().Clone()
   248  		if err = a.Reshape(aShape[0], 1); err != nil {
   249  			return err
   250  		}
   251  		if err = b.Reshape(1, bShape[0]); err != nil {
   252  			return err
   253  		}
   254  
   255  		if err = e.MatMul(a, b, prealloc); err != nil {
   256  			return err
   257  		}
   258  
   259  		if err = b.Reshape(bShape...); err != nil {
   260  			return
   261  		}
   262  		if err = a.Reshape(aShape...); err != nil {
   263  			return
   264  		}
   265  		return nil
   266  	}
   267  
   268  	e.c.DoWork()
   269  	incX, incY := 1, 1
   270  	switch ad.Dtype() {
   271  	case tensor.Float64:
   272  		x := ad.Float64s()
   273  		y := bd.Float64s()
   274  		A := pd.Float64s()
   275  		alpha := float64(1)
   276  		e.c.Do(func() error { e.b.Dger(m, n, alpha, x, incX, y, incY, A, lda); return nil })
   277  	case tensor.Float32:
   278  		x := ad.Float32s()
   279  		y := bd.Float32s()
   280  		A := pd.Float32s()
   281  		alpha := float32(1)
   282  		e.c.Do(func() error { e.b.Sger(m, n, alpha, x, incX, y, incY, A, lda); return nil })
   283  	}
   284  	return e.b.Err()
   285  }