gorgonia.org/gorgonia@v0.9.17/operatorLinAlg.go (about) 1 package gorgonia 2 3 import ( 4 "github.com/chewxy/hm" 5 "github.com/pkg/errors" 6 "gorgonia.org/tensor" 7 ) 8 9 // ā and Ā are used to denote that it's a matrix/vector type. 10 // if you want to type it, it's Latin Letter A with Macron (lowercase and capital) 11 // Codepoints : U+101 for the small one, and U+100 for the capital one 12 13 type āBinaryOperator byte 14 15 const ( 16 matMulOperator āBinaryOperator = iota // emits S/DGEMM BLAS calls 17 matVecMulOperator // emits S/DGEMV BLAS calls 18 vecDotOperator // emits S/DDOT BLAS calls 19 outerProdOperator // emits S/DGER BLAS calls 20 batchedMatMulOperator // just S/GEMM BLAS calls in a loop 21 22 maxĀBinaryOperator // delimits all possible linalg operators. Add above this line 23 ) 24 25 func (op āBinaryOperator) String() string { 26 if op >= maxĀBinaryOperator { 27 return "UNSUPPORTED LINEAR ALGEBRA OPERATOR" 28 } 29 return āBinOpStrs[op] 30 } 31 32 func (op āBinaryOperator) Type() hm.Type { 33 if op >= maxĀBinaryOperator { 34 panic("UNSUPPORTED LINEAR ALGEBRA OPERATOR") 35 } 36 return āBinOpTypes[op]() 37 } 38 39 func (op āBinaryOperator) DiffWRT(inputs int) []bool { 40 if inputs != 2 { 41 panic("binary linear algebra operator only supports two and only two inputs") 42 } 43 44 if op >= maxĀBinaryOperator { 45 panic("Unsupported unary operator is not differentiable") 46 } 47 return []bool{true, true} 48 } 49 50 // todo: write explanation. 51 func matMulDiffExpr(transA, transB bool, x, y, z, gradZ *Node) (retVal Nodes, err error) { 52 var dzdx, dzdy *Node 53 op := linAlgBinOp{ 54 āBinaryOperator: matMulOperator, 55 } 56 57 switch { 58 case transA && transB: 59 op.transA = transA 60 op.transB = transB 61 if dzdx, err = binOpNode(op, y, gradZ); err != nil { 62 return nil, errors.Wrapf(err, binOpNodeFail, op) 63 } 64 if dzdy, err = binOpNode(op, gradZ, x); err != nil { 65 return nil, errors.Wrapf(err, binOpNodeFail, op) 66 } 67 case !transA && transB: 68 if dzdx, err = binOpNode(op, gradZ, y); err != nil { 69 return nil, errors.Wrapf(err, binOpNodeFail, op) 70 } 71 72 op.transA = true 73 if dzdy, err = binOpNode(op, gradZ, x); err != nil { 74 return nil, errors.Wrapf(err, binOpNodeFail, op) 75 } 76 case transA && !transB: 77 op.transB = true 78 if dzdx, err = binOpNode(op, y, gradZ); err != nil { 79 return nil, errors.Wrapf(err, binOpNodeFail, op) 80 } 81 82 op.transB = false 83 if dzdy, err = binOpNode(op, x, gradZ); err != nil { 84 return nil, errors.Wrapf(err, binOpNodeFail, op) 85 } 86 case !transA && !transB: 87 // dzdy 88 op.transA = false 89 op.transB = true 90 if dzdx, err = binOpNode(op, gradZ, y); err != nil { 91 return nil, errors.Wrapf(err, binOpNodeFail, op) 92 } 93 // do dzdx 94 op.transA = true 95 op.transB = false 96 if dzdy, err = binOpNode(op, x, gradZ); err != nil { 97 return nil, errors.Wrapf(err, binOpNodeFail, op) 98 } 99 } 100 retVal = Nodes{dzdx, dzdy} 101 return 102 } 103 104 func matMulDiff(ctx ExecutionContext, transA, transB bool, x, y, z *Node) (err error) { 105 xdv, ydv, zdv := getDV3(x, y, z) 106 107 op := linAlgBinOp{ 108 āBinaryOperator: matMulOperator, 109 } 110 111 switch { 112 case transA && transB: 113 op.transA = transA 114 op.transB = transB 115 116 // dzdx 117 err = op.IncrDo(xdv.d, ydv.Value, zdv.d) 118 if err = checkErrSetDeriv(err, xdv); err != nil { 119 return errors.Wrapf(err, autodiffFail, x) 120 } 121 122 // dzdy 123 err = op.IncrDo(ydv.d, zdv.d, xdv.Value) 124 if err = checkErrSetDeriv(err, ydv); err != nil { 125 return errors.Wrapf(err, autodiffFail, y) 126 } 127 128 return 129 130 case !transA && transB: 131 // dzdx 132 err = op.IncrDo(xdv.d, zdv.d, ydv.Value) 133 if err = checkErrSetDeriv(err, xdv); err != nil { 134 return errors.Wrapf(err, autodiffFail, x) 135 } 136 137 // dzdy 138 op.transA = true 139 err = op.IncrDo(ydv.d, zdv.d, xdv.Value) 140 if err = checkErrSetDeriv(err, ydv); err != nil { 141 return errors.Wrapf(err, autodiffFail, x) 142 } 143 144 return 145 146 case transA && !transB: 147 // dzdx 148 op.transB = true 149 err = op.IncrDo(xdv.d, ydv.Value, zdv.d) 150 if err = checkErrSetDeriv(err, xdv); err != nil { 151 return errors.Wrapf(err, autodiffFail, x) 152 } 153 154 // dzdy 155 op.transA = false 156 op.transB = false 157 err = op.IncrDo(ydv.d, xdv.Value, zdv.d) 158 if err = checkErrSetDeriv(err, ydv); err != nil { 159 return errors.Wrapf(err, autodiffFail, x) 160 } 161 return 162 case !transA && !transB: 163 op.transB = true 164 err = op.IncrDo(xdv.d, zdv.d, ydv.Value) 165 if err = checkErrSetDeriv(err, xdv); err != nil { 166 return errors.Wrapf(err, autodiffFail, x) 167 } 168 169 op.transA = true 170 op.transB = false 171 err = op.IncrDo(ydv.d, xdv.Value, zdv.d) 172 if err = checkErrSetDeriv(err, ydv); err != nil { 173 return errors.Wrapf(err, autodiffFail, x) 174 } 175 return 176 } 177 178 panic("unreachable") 179 } 180 181 func matVecMulDiffExpr(transA, transB bool, x, y, z, gradZ *Node) (retVal Nodes, err error) { 182 var dzdx, dzdy *Node 183 if transA { 184 dzdx, err = OuterProd(y, gradZ) 185 } else { 186 dzdx, err = OuterProd(gradZ, y) 187 } 188 189 if err != nil { 190 return nil, errors.Wrap(err, "Failed to carry outper product") 191 } 192 193 op := linAlgBinOp{ 194 āBinaryOperator: matVecMulOperator, 195 transA: !transA, 196 } 197 198 if dzdy, err = binOpNode(op, x, gradZ); err != nil { 199 return nil, errors.Wrapf(err, binOpNodeFail, op) 200 } 201 return Nodes{dzdx, dzdy}, nil 202 } 203 204 func matVecMulDiff(ctx ExecutionContext, transA, transB bool, x, y, z *Node) (err error) { 205 xdv, ydv, zdv := getDV3(x, y, z) 206 207 op := linAlgBinOp{ 208 āBinaryOperator: outerProdOperator, 209 } 210 211 if transA { 212 err = op.IncrDo(xdv.d, ydv.Value, zdv.d) 213 } else { 214 err = op.IncrDo(xdv.d, zdv.d, ydv.Value) 215 } 216 if err = checkErrSetDeriv(err, xdv); err != nil { 217 return errors.Wrapf(err, autodiffFail, x) 218 } 219 220 op = linAlgBinOp{ 221 āBinaryOperator: matVecMulOperator, 222 transA: !transA, 223 } 224 225 err = op.IncrDo(ydv.d, xdv.Value, zdv.d) 226 if err = checkErrSetDeriv(err, ydv); err != nil { 227 return errors.Wrapf(err, autodiffFail, x) 228 } 229 return 230 } 231 232 func vecDotDiffExpr(transA, transB bool, x, y, z, gradZ *Node) (retVal Nodes, err error) { 233 var dzdx, dzdy *Node 234 if dzdx, err = HadamardProd(y, gradZ); err == nil { 235 if dzdy, err = HadamardProd(x, gradZ); err == nil { 236 retVal = Nodes{dzdx, dzdy} 237 } else { 238 return nil, errors.Wrap(err, "Failed to carry HadamardProd()") 239 } 240 } else { 241 return nil, errors.Wrap(err, "Failed to carry HadamardProd()") 242 } 243 return 244 } 245 246 func vecDotDiff(ctx ExecutionContext, transA, transB bool, x, y, z *Node) (err error) { 247 xdv, ydv, zdv := getDV3(x, y, z) 248 249 mul := newElemBinOp(mulOpType, x, z) 250 err = mul.IncrDo(xdv.d, ydv.Value, zdv.d) 251 if err = checkErrSetDeriv(err, xdv); err != nil { 252 return errors.Wrapf(err, autodiffFail, x) 253 } 254 255 err = mul.IncrDo(ydv.d, xdv.Value, zdv.d) 256 if err = checkErrSetDeriv(err, ydv); err != nil { 257 return errors.Wrapf(err, autodiffFail, x) 258 } 259 return 260 } 261 262 func outerProdDiffExpr(transA, transB bool, x, y, z, gradZ *Node) (retVal Nodes, err error) { 263 var dzdx, dzdy *Node 264 if dzdx, err = Mul(x, gradZ); err == nil { 265 if dzdy, err = Mul(y, gradZ); err == nil { 266 retVal = Nodes{dzdx, dzdy} 267 } else { 268 return nil, errors.Wrap(err, "Failed to carry Mul()") 269 } 270 } else { 271 return nil, errors.Wrap(err, "Failed to carry Mul()") 272 } 273 return 274 } 275 276 func outerProdDiff(ctx ExecutionContext, transA, transB bool, x, y, z *Node) (err error) { 277 xdv, ydv, zdv := getDV3(x, y, z) 278 279 mul := newElemBinOp(mulOpType, x, z) 280 err = mul.IncrDo(xdv.d, xdv.Value, zdv.d) 281 err = mul.IncrDo(xdv.d, ydv.Value, zdv.d) 282 if err = checkErrSetDeriv(err, xdv); err != nil { 283 return errors.Wrapf(err, autodiffFail, x) 284 } 285 286 err = mul.IncrDo(ydv.d, ydv.Value, zdv.d) 287 if err = checkErrSetDeriv(err, ydv); err != nil { 288 return errors.Wrapf(err, autodiffFail, x) 289 } 290 return 291 } 292 293 func batchedMatMulDiffExpr(transA, transB bool, x, y, z, gradZ *Node) (retVal Nodes, err error) { 294 var dzdx, dzdy *Node 295 op := linAlgBinOp{ 296 āBinaryOperator: batchedMatMulOperator, 297 } 298 299 switch { 300 case transA && transB: 301 op.transA = transA 302 op.transB = transB 303 if dzdx, err = binOpNode(op, y, gradZ); err != nil { 304 return nil, errors.Wrapf(err, binOpNodeFail, op) 305 } 306 if dzdy, err = binOpNode(op, gradZ, x); err != nil { 307 return nil, errors.Wrapf(err, binOpNodeFail, op) 308 } 309 case !transA && transB: 310 if dzdx, err = binOpNode(op, gradZ, y); err != nil { 311 return nil, errors.Wrapf(err, binOpNodeFail, op) 312 } 313 314 op.transA = true 315 if dzdy, err = binOpNode(op, gradZ, x); err != nil { 316 return nil, errors.Wrapf(err, binOpNodeFail, op) 317 } 318 case transA && !transB: 319 op.transB = true 320 if dzdx, err = binOpNode(op, y, gradZ); err != nil { 321 return nil, errors.Wrapf(err, binOpNodeFail, op) 322 } 323 324 op.transB = false 325 if dzdy, err = binOpNode(op, x, gradZ); err != nil { 326 return nil, errors.Wrapf(err, binOpNodeFail, op) 327 } 328 case !transA && !transB: 329 // dzdy 330 op.transA = false 331 op.transB = true 332 if dzdx, err = binOpNode(op, gradZ, y); err != nil { 333 return nil, errors.Wrapf(err, binOpNodeFail, op) 334 } 335 // do dzdx 336 op.transA = true 337 op.transB = false 338 if dzdy, err = binOpNode(op, x, gradZ); err != nil { 339 return nil, errors.Wrapf(err, binOpNodeFail, op) 340 } 341 } 342 retVal = Nodes{dzdx, dzdy} 343 return 344 } 345 346 func batchedMatMulDiff(ctx ExecutionContext, transA, transB bool, x, y, z *Node) (err error) { 347 xdv, ydv, zdv := getDV3(x, y, z) 348 349 op := linAlgBinOp{ 350 āBinaryOperator: batchedMatMulOperator, 351 } 352 353 switch { 354 case transA && transB: 355 op.transA = transA 356 op.transB = transB 357 358 // dzdx 359 err = op.IncrDo(xdv.d, ydv.Value, zdv.d) 360 if err = checkErrSetDeriv(err, xdv); err != nil { 361 return errors.Wrapf(err, autodiffFail, x) 362 } 363 364 // dzdy 365 err = op.IncrDo(ydv.d, zdv.d, xdv.Value) 366 if err = checkErrSetDeriv(err, ydv); err != nil { 367 return errors.Wrapf(err, autodiffFail, y) 368 } 369 370 return 371 372 case !transA && transB: 373 // dzdx 374 err = op.IncrDo(xdv.d, zdv.d, ydv.Value) 375 if err = checkErrSetDeriv(err, xdv); err != nil { 376 return errors.Wrapf(err, autodiffFail, x) 377 } 378 379 // dzdy 380 op.transA = true 381 err = op.IncrDo(ydv.d, zdv.d, xdv.Value) 382 if err = checkErrSetDeriv(err, ydv); err != nil { 383 return errors.Wrapf(err, autodiffFail, x) 384 } 385 386 return 387 388 case transA && !transB: 389 // dzdx 390 op.transB = true 391 err = op.IncrDo(xdv.d, ydv.Value, zdv.d) 392 if err = checkErrSetDeriv(err, xdv); err != nil { 393 return errors.Wrapf(err, autodiffFail, x) 394 } 395 396 // dzdy 397 op.transA = false 398 op.transB = false 399 err = op.IncrDo(ydv.d, xdv.Value, zdv.d) 400 if err = checkErrSetDeriv(err, ydv); err != nil { 401 return errors.Wrapf(err, autodiffFail, x) 402 } 403 return 404 case !transA && !transB: 405 op.transB = true 406 err = op.IncrDo(xdv.d, zdv.d, ydv.Value) 407 if err = checkErrSetDeriv(err, xdv); err != nil { 408 return errors.Wrapf(err, autodiffFail, x) 409 } 410 411 op.transA = true 412 op.transB = false 413 err = op.IncrDo(ydv.d, xdv.Value, zdv.d) 414 if err = checkErrSetDeriv(err, ydv); err != nil { 415 return errors.Wrapf(err, autodiffFail, x) 416 } 417 return 418 } 419 420 panic("unreachable") 421 } 422 423 func batchedMatMul(a, b, c tensor.Tensor, transA, transB, incr bool) (retVal tensor.Tensor, err error) { 424 shapeA := a.Shape().Clone() 425 shapeB := b.Shape().Clone() 426 outer := shapeA[:len(shapeA)-2] 427 innerA := shapeA[len(shapeA)-2:] 428 innerB := shapeB[len(shapeB)-2:] 429 430 if c == nil { 431 newShape := append(outer, innerA[0], innerB[1]) 432 c = tensor.New(tensor.Of(a.Dtype()), tensor.WithShape(newShape...), tensor.WithEngine(a.Engine())) 433 } 434 435 slices := make([]sli, len(outer)) 436 ss := make([]tensor.Slice, len(slices)) 437 for i := range slices { 438 slices[i].end = slices[i].start + 1 439 ss[i] = &slices[i] 440 } 441 442 var as, bs, cs tensor.Tensor 443 for halt := false; !halt; halt = incrSlices(slices, outer) { 444 if as, err = a.Slice(ss...); err != nil { 445 return nil, errors.Wrapf(err, "Slicing %v from a failed", ss) 446 } 447 if bs, err = b.Slice(ss...); err != nil { 448 return nil, errors.Wrapf(err, "Slicing %v from b failed", ss) 449 } 450 if cs, err = c.Slice(ss...); err != nil { 451 return nil, errors.Wrapf(err, "Slicing %v from c failed", ss) 452 } 453 454 if transA { 455 as.T() 456 } 457 if transB { 458 bs.T() 459 } 460 461 var fo tensor.FuncOpt 462 if incr { 463 fo = tensor.WithIncr(cs) 464 } else { 465 fo = tensor.WithReuse(cs) 466 } 467 468 if _, err = tensor.MatMul(as, bs, fo); err != nil { 469 return nil, errors.Wrapf(err, "MatMul on batch %v failed.", ss) 470 } 471 472 } 473 474 return c, nil 475 } 476 477 // incrSlices increments the slices. If everything has matched then return true 478 func incrSlices(a []sli, shp tensor.Shape) (halt bool) { 479 for i := len(a) - 1; i >= 0; i-- { 480 if shp[i]-a[i].start == 1 { 481 a[i].start = 0 482 a[i].end = 1 483 if i == 0 { 484 return true 485 } 486 continue 487 } 488 489 a[i].start++ 490 a[i].end = a[i].start + 1 491 return false 492 } 493 return true 494 }