gorgonia.org/gorgonia@v0.9.17/cuda/linalg.go (about) 1 package cuda 2 3 import ( 4 "github.com/pkg/errors" 5 "gonum.org/v1/gonum/blas" 6 "gorgonia.org/tensor" 7 ) 8 9 var ( 10 _ tensor.MatVecMuler = &Engine{} 11 _ tensor.MatMuler = &Engine{} 12 _ tensor.OuterProder = &Engine{} 13 ) 14 15 // this file implements all the tensor linalg engine interfaces 16 17 func (e *Engine) checkThreeFloat(a, b, ret tensor.Tensor) (ad, bd, retVal *tensor.Dense, err error) { 18 if /*a.IsNativelyAccessible() &&*/ !a.IsManuallyManaged() { 19 return nil, nil, nil, errors.New("CUDA Engine only takes non-natively accessible memory (memory on graphics cards). a isn't.") 20 } 21 22 if /* b.IsNativelyAccessible() && */ !b.IsManuallyManaged() { 23 return nil, nil, nil, errors.New("CUDA Engine only takes non-natively accessible memory (memory on graphics cards). b isn't") 24 } 25 26 if /* ret.IsNativelyAccessible() && */ !ret.IsManuallyManaged() { 27 return nil, nil, nil, errors.New("CUDA Engine only takes non-natively accessible memory (memory on graphics cards). ret isn't") 28 } 29 30 if a.Dtype() != b.Dtype() || b.Dtype() != ret.Dtype() { 31 return nil, nil, nil, errors.New("Expected a and b and retVal all to have the same Dtype") 32 } 33 var ok bool 34 if ad, ok = a.(*tensor.Dense); !ok { 35 return nil, nil, nil, errors.New("Expected a to be a *tensor.Dense") 36 } 37 if bd, ok = b.(*tensor.Dense); !ok { 38 return nil, nil, nil, errors.New("Expected b to be a *tensor.Dense") 39 } 40 if retVal, ok = ret.(*tensor.Dense); !ok { 41 return nil, nil, nil, errors.New("Expected ret to be a *tensor.Dense") 42 } 43 return 44 } 45 46 // MatVecMul performs matrix vector multiplication 47 func (e *Engine) MatVecMul(a, b, prealloc tensor.Tensor) (err error) { 48 var ad, bd, pd *tensor.Dense 49 if ad, bd, pd, err = e.checkThreeFloat(a, b, prealloc); err != nil { 50 return errors.Wrapf(err, "MatVecMul failed pre check") 51 } 52 53 tA := blas.Trans 54 do := a.DataOrder() 55 z := do.IsTransposed() 56 57 m := a.Shape()[0] 58 n := a.Shape()[1] 59 60 var lda int 61 switch { 62 case do.IsRowMajor() && z: 63 tA = blas.NoTrans 64 lda = m 65 case do.IsRowMajor() && !z: 66 lda = n 67 m, n = n, m 68 case do.IsColMajor() && z: 69 tA = blas.Trans 70 lda = n 71 m, n = n, m 72 case do.IsColMajor() && !z: 73 lda = m 74 tA = blas.NoTrans 75 } 76 77 e.c.DoWork() 78 incX, incY := 1, 1 // step size 79 80 // ASPIRATIONAL TODO: different incX and incY 81 // TECHNICAL DEBT. TECHDEBT. TECH DEBT 82 // Example use case: 83 // log.Printf("a %v %v", ad.Strides(), ad.ostrides()) 84 // log.Printf("b %v", b.Strides()) 85 // incX := a.Strides()[0] 86 // incY = b.Strides()[0] 87 88 switch ad.Dtype() { 89 case tensor.Float64: 90 A := ad.Float64s() 91 X := bd.Float64s() 92 Y := pd.Float64s() 93 alpha, beta := float64(1), float64(0) 94 e.c.DoWork() 95 e.c.Do(func() error { e.b.Dgemv(tA, m, n, alpha, A, lda, X, incX, beta, Y, incY); return e.b.Err() }) 96 case tensor.Float32: 97 A := ad.Float32s() 98 X := bd.Float32s() 99 Y := pd.Float32s() 100 alpha, beta := float32(1), float32(0) 101 e.c.DoWork() 102 e.c.Do(func() error { e.b.Sgemv(tA, m, n, alpha, A, lda, X, incX, beta, Y, incY); return e.b.Err() }) 103 default: 104 return errors.New("Unsupported Dtype") 105 } 106 return e.b.Err() 107 } 108 109 // MatMul performs matrix multiplication 110 func (e *Engine) MatMul(a, b, prealloc tensor.Tensor) (err error) { 111 var ad, bd, pd *tensor.Dense 112 if ad, bd, pd, err = e.checkThreeFloat(a, b, prealloc); err != nil { 113 return errors.Wrapf(err, "MatVecMul failed pre check") 114 } 115 116 ado := a.DataOrder() 117 bdo := b.DataOrder() 118 if !ado.HasSameOrder(bdo) { 119 return errors.Errorf("a does not have the same data order as b. a is %v. b is %v", a.DataOrder(), b.DataOrder()) 120 } 121 122 // get result shapes. k is the shared dimension 123 // a is (m, k) 124 // b is (k, n) 125 // c is (m, n) 126 var m, n, k int 127 m = ad.Shape()[0] 128 k = ad.Shape()[1] 129 n = bd.Shape()[1] 130 131 // // wrt the strides, we use the original strides, because that's what BLAS needs, instead of calling .Strides() 132 // // lda in colmajor = number of rows; 133 // // lda in row major = number of cols 134 var lda, ldb, ldc int 135 tA, tB := blas.Trans, blas.Trans 136 za := ado.IsTransposed() 137 zb := bdo.IsTransposed() 138 139 // swapping around the operands if they are row major (a becomes b, and b becomes a) 140 switch { 141 case ado.IsColMajor() && bdo.IsColMajor() && !za && !zb: 142 lda = m 143 ldb = k 144 ldc = prealloc.Shape()[0] 145 tA, tB = blas.NoTrans, blas.NoTrans 146 case ado.IsColMajor() && bdo.IsColMajor() && za && !zb: 147 lda = k 148 ldb = k 149 ldc = prealloc.Shape()[0] 150 tA, tB = blas.Trans, blas.NoTrans 151 case ado.IsColMajor() && bdo.IsColMajor() && za && zb: 152 lda = k 153 ldb = n 154 ldc = prealloc.Shape()[0] 155 tA, tB = blas.Trans, blas.Trans 156 case ado.IsColMajor() && bdo.IsColMajor() && !za && zb: 157 lda = m 158 ldb = n 159 ldc = prealloc.Shape()[0] 160 tA, tB = blas.NoTrans, blas.Trans 161 case ado.IsRowMajor() && bdo.IsRowMajor() && !za && !zb: 162 lda = k 163 ldb = n 164 ldc = prealloc.Shape()[1] 165 tA, tB = blas.NoTrans, blas.NoTrans 166 167 // magic swappy thingy 168 m, n = n, m 169 lda, ldb = ldb, lda 170 ad, bd = bd, ad 171 case ado.IsRowMajor() && bdo.IsRowMajor() && za && !zb: 172 lda = m 173 ldb = n 174 ldc = prealloc.Shape()[1] 175 tA, tB = blas.Trans, blas.NoTrans 176 177 // magic swappy thingy 178 m, n = n, m 179 lda, ldb = ldb, lda 180 tA, tB = tB, tA 181 ad, bd = bd, ad 182 case ado.IsRowMajor() && bdo.IsRowMajor() && za && zb: 183 lda = m 184 ldb = k 185 ldc = prealloc.Shape()[1] 186 tA, tB = blas.Trans, blas.Trans 187 188 // magic swappy thingy 189 m, n = n, m 190 lda, ldb = ldb, lda 191 ad, bd = bd, ad 192 case ado.IsRowMajor() && bdo.IsRowMajor() && !za && zb: 193 lda = k 194 ldb = k 195 ldc = prealloc.Shape()[1] 196 tA, tB = blas.NoTrans, blas.Trans 197 198 // magic swappy thingy 199 m, n = n, m 200 lda, ldb = ldb, lda 201 tA, tB = tB, tA 202 ad, bd = bd, ad 203 204 default: 205 panic("Unreachable") 206 } 207 208 e.c.DoWork() 209 switch ad.Dtype() { 210 case tensor.Float64: 211 A := ad.Float64s() 212 B := bd.Float64s() 213 C := pd.Float64s() 214 alpha, beta := float64(1), float64(0) 215 216 e.c.Do(func() error { e.b.Dgemm(tA, tB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); return nil }) 217 218 case tensor.Float32: 219 A := ad.Float32s() 220 B := bd.Float32s() 221 C := pd.Float32s() 222 alpha, beta := float32(1), float32(0) 223 e.c.Do(func() error { e.b.Sgemm(tA, tB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); return nil }) 224 default: 225 return errors.Errorf("Unsupported Dtype %v", ad.Dtype()) 226 } 227 228 return e.b.Err() 229 } 230 231 // Outer performs outer product (kronecker) multiplication 232 func (e *Engine) Outer(a, b, prealloc tensor.Tensor) (err error) { 233 var ad, bd, pd *tensor.Dense 234 if ad, bd, pd, err = e.checkThreeFloat(a, b, prealloc); err != nil { 235 return errors.Wrapf(err, "MatVecMul failed pre check") 236 } 237 m := ad.Size() 238 n := bd.Size() 239 pdo := pd.DataOrder() 240 241 var lda int 242 switch { 243 case pdo.IsColMajor(): 244 lda = pd.Shape()[0] 245 case pdo.IsRowMajor(): 246 aShape := a.Shape().Clone() 247 bShape := b.Shape().Clone() 248 if err = a.Reshape(aShape[0], 1); err != nil { 249 return err 250 } 251 if err = b.Reshape(1, bShape[0]); err != nil { 252 return err 253 } 254 255 if err = e.MatMul(a, b, prealloc); err != nil { 256 return err 257 } 258 259 if err = b.Reshape(bShape...); err != nil { 260 return 261 } 262 if err = a.Reshape(aShape...); err != nil { 263 return 264 } 265 return nil 266 } 267 268 e.c.DoWork() 269 incX, incY := 1, 1 270 switch ad.Dtype() { 271 case tensor.Float64: 272 x := ad.Float64s() 273 y := bd.Float64s() 274 A := pd.Float64s() 275 alpha := float64(1) 276 e.c.Do(func() error { e.b.Dger(m, n, alpha, x, incX, y, incY, A, lda); return nil }) 277 case tensor.Float32: 278 x := ad.Float32s() 279 y := bd.Float32s() 280 A := pd.Float32s() 281 alpha := float32(1) 282 e.c.Do(func() error { e.b.Sger(m, n, alpha, x, incX, y, incY, A, lda); return nil }) 283 } 284 return e.b.Err() 285 }