gorgonia.org/gorgonia@v0.9.17/operations.go

package gorgonia

import (
	"fmt"

	"github.com/pkg/errors"
	"gorgonia.org/tensor"
)

// contains all public operations that can be performed on nodes
// all the functions here have the signature:
// func (...) (*Node, error)

/* BINARY FUNCTIONS */
func binOpNode(op BinaryOp, a, b *Node) (retVal *Node, err error) {
	stabLogf("Creating node for %v, a: %p, b: %p", op, a, b)
	enterLogScope()
	defer leaveLogScope()
	// maybe make stabilization a build tag?
	if stabilization {
		enterLogScope()
		if ebo, ok := op.(elemBinOp); ok {
			ot := ebo.binOpType()

			enterLogScope()
			for _, fn := range binOpStabilizationFns[ot] {
				if retVal, err = fn(a, b); err == nil {
					leaveLogScope()
					return
				}

				if _, ok := err.(errNoStabilization); !ok {
					leaveLogScope()
					return
				}
			}
			leaveLogScope()
		}
		leaveLogScope()
	}
	stabLogf("No bin op stabilization")
	return ApplyOp(op, a, b)
}

// Mul is the general handler for multiplication of nodes. It is extremely overloaded. Only use if you know what you're doing
//
// If any of the nodes are ScalarType, then it'll be redirected to HadamardProd() instead
// If the nodes are both vectors (that is, have a shape of (x, 1) or (1, x)), then the operator used will be a vectorDot
// If only one of the nodes is a vector, then a matrix-vector multiplication will be used, with a transpose applied when necessary
// If both nodes are matrices, then well, matrix multiplication will be done
func Mul(a, b *Node) (retVal *Node, err error) {
	if a.IsScalar() || b.IsScalar() {
		return HadamardProd(a, b)
	}

	var op BinaryOp
	switch {
	case a.IsVector() && b.IsVector():
		op = linAlgBinOp{āBinaryOperator: vecDotOperator}
		return binOpNode(op, a, b)
	case a.IsVector() && b.IsMatrix():
		op = linAlgBinOp{āBinaryOperator: matVecMulOperator, transA: true}
		return binOpNode(op, b, a)
	case a.IsMatrix() && b.IsVector():
		op = linAlgBinOp{āBinaryOperator: matVecMulOperator}
		return binOpNode(op, a, b)
	case a.IsMatrix() && b.IsMatrix():
		op = linAlgBinOp{āBinaryOperator: matMulOperator}
		return binOpNode(op, a, b)
	default:
		return nil, errors.Errorf(nyiFail, "Mul", fmt.Sprintf("a %v b %v", a.shape, b.shape))
	}
}

// BatchedMatMul returns a node representing the batched mat mul operation.
//
// A list of transpose options is allowed: the first option applies to a (transA), the second to b (transB).
func BatchedMatMul(a, b *Node, transes ...bool) (retVal *Node, err error) {
	op := linAlgBinOp{āBinaryOperator: batchedMatMulOperator}
	switch len(transes) {
	case 0:
		// noop
	case 1:
		// transA
		op.transA = transes[0]
	case 2:
		// transA and transB
		op.transA = transes[0]
		op.transB = transes[1]
	default:
		// more than two transpose options are not supported; only the first two are used
		op.transA = transes[0]
		op.transB = transes[1]
	}

	return binOpNode(op, a, b)
}
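
// The following is an illustrative sketch (not part of the original file) of how Mul
// dispatches on its operands' shapes and how BatchedMatMul's transpose options are
// consumed. The graph, node names and shapes are invented for illustration only.
func exampleMulDispatch() {
	g := NewGraph()

	s := NewScalar(g, Float64, WithName("s"))                  // scalar
	u := NewVector(g, Float64, WithShape(4), WithName("u"))    // vector of length 4
	v := NewVector(g, Float64, WithShape(4), WithName("v"))    // vector of length 4
	A := NewMatrix(g, Float64, WithShape(3, 4), WithName("A")) // (3, 4)
	B := NewMatrix(g, Float64, WithShape(4, 5), WithName("B")) // (4, 5)

	_ = Must(Mul(s, A)) // scalar operand: redirected to HadamardProd
	_ = Must(Mul(u, v)) // two vectors: vector dot product
	_ = Must(Mul(A, u)) // matrix × vector: matrix-vector multiplication
	_ = Must(Mul(A, B)) // matrix × matrix: matrix multiplication, shape (3, 5)

	// Batched matrix multiplication over a leading batch dimension of 8.
	// The optional booleans are transA and transB, in that order.
	X := NewTensor(g, Float64, 3, WithShape(8, 3, 4), WithName("X"))
	Y := NewTensor(g, Float64, 3, WithShape(8, 5, 4), WithName("Y"))
	_ = Must(BatchedMatMul(X, Y, false, true)) // transB: contract against the transpose of Y, shape (8, 3, 5)
}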

// OuterProd returns a Node representing the outer product of two vectors. This function will return an error if either input node is not a vector
func OuterProd(a, b *Node) (retVal *Node, err error) {
	if !a.IsVector() || !b.IsVector() {
		return nil, errors.Errorf("Expected only vectors to be able to do OuterProd. %v is %v. %v is %v", a, a.Shape(), b, b.Shape()) //for now
	}

	// TODO: maybe align shapes?
	op := linAlgBinOp{āBinaryOperator: outerProdOperator}
	return binOpNode(op, a, b)
}

// Div is a shortcut function for HadamardDiv for scalar values and values of equal shape. For other matrix/tensor values, the matrix division operation is not yet handled, and will panic.
func Div(a, b *Node) (retVal *Node, err error) {
	if a.IsScalar() || b.IsScalar() || a.Shape().Eq(b.Shape()) {
		return HadamardDiv(a, b)
	}

	// otherwise, matrix division
	panic("Unhandled")
}

// Auto automatically calculates the broadcast patterns for the given broadcast operation, for example:
//	gorgonia.Auto(gorgonia.BroadcastHadamardProd, a, b)
func Auto(op func(a, b *Node, leftPattern, rightPattern []byte) (*Node, error), a, b *Node) (*Node, error) {
	aShape := a.Shape()
	bShape := b.Shape()

	if aShape.Dims() != bShape.Dims() {
		return nil, fmt.Errorf("shapes %v and %v should have the same dimensions", aShape, bShape)
	}

	var (
		leftPattern, rightPattern []byte
	)

	for i := 0; i < aShape.Dims(); i++ {
		if aShape[i] > bShape[i] {
			leftPattern = append(leftPattern, byte(i))
		} else if aShape[i] < bShape[i] {
			rightPattern = append(rightPattern, byte(i))
		}
	}

	return op(a, b, rightPattern, leftPattern)
}
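
// An illustrative sketch (not part of the original file) of Auto in use: given two nodes
// with the same number of dimensions, it works out which axes need broadcasting and
// forwards the patterns to the supplied broadcast operation. The graph, names and shapes
// are invented for illustration; BroadcastHadamardProd is the broadcast variant named in
// the doc comment above.
func exampleAutoBroadcast() error {
	g := NewGraph()

	a := NewMatrix(g, Float64, WithShape(3, 4), WithName("a")) // (3, 4)
	b := NewMatrix(g, Float64, WithShape(3, 1), WithName("b")) // (3, 1)

	// Axis 1 differs (4 vs 1), so b is broadcast along axis 1 before the
	// elementwise product. The result has shape (3, 4).
	prod, err := Auto(BroadcastHadamardProd, a, b)
	if err != nil {
		return err
	}
	_ = prod
	return nil
}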

/* UNARY STUFF */

func unaryOpNode(op Op, a *Node) (retVal *Node, err error) {
	stabLogf("Creating node for %v, a: %p %v", op, a, a)
	enterLogScope()
	defer leaveLogScope()
	if stabilization {

		// do optimization/stabilization
		// TODO: maybe recursively stabilize?
		enterLogScope()
		ot := op.(elemUnaryOp).unaryOpType()
		for _, fn := range unaryOpStabilizationFns[ot] {
			if retVal, err = fn(a); err == nil {
				stabLogf("stabilized")
				leaveLogScope()
				return
			}

			if _, ok := err.(errNoStabilization); !ok {
				stabLogf("Actual error")
				leaveLogScope()
				return
			}
			stabLogf("No stabilization found")
		}
		leaveLogScope()
		stabLogf("No stabilizations - retVal: %v", retVal)
	}

	return ApplyOp(op, a)
}

// more complex unaries

// SoftMax implements the softmax operation. The softmax operation is a numerically stable operation.
func SoftMax(x *Node, axis ...int) (*Node, error) {
	xShape := x.Shape()
	op := newSoftmaxOp(xShape, axis...)

	return ApplyOp(op, x)
}

// LogSumExp performs addition in the log domain
func LogSumExp(a *Node, axis int) (retVal *Node, err error) {
	var max, exp, sum, logSum *Node
	if max, err = Max(a, axis); err != nil {
		return nil, errors.Wrap(err, operationError)
	}
	if retVal, err = Sub(a, max); err == nil {
		if exp, err = Exp(retVal); err == nil {
			if sum, err = Sum(exp, axis); err == nil {
				if sum, err = Add(sum, max); err == nil {
					if logSum, err = Log(sum); err == nil {
						return Sum(logSum, axis)
					}
				}
			}
		}
	}
	return nil, errors.Wrap(err, operationError)
}

/* Aggregate Functions */

// At is a symbolic operation for getting a value at the provided coordinates.
// If the input is a scalar, all the coordinates MUST be 0, or else an error will be returned.
func At(a *Node, coords ...int) (retVal *Node, err error) {
	if a.IsScalar() {
		for _, c := range coords {
			if c != 0 {
				return nil, errors.Errorf("At() only works with scalars when the coordinates are (0...0). Got %v instead", coords)
			}
		}
		return a, nil
	}

	dims := a.Dims()
	op := atOp{
		coordinates: coords,
		d:           dims,
	}

	return ApplyOp(op, a)
}

// Max performs a max() on the input and the provided axes.
func Max(a *Node, along ...int) (retVal *Node, err error) {
	if a.IsScalar() {
		// can't max a scalar. Should return error
		return a, nil
	}

	dims := a.Dims()
	if len(along) == 0 {
		along = intRange(0, dims)
	}

	op := newMaxOp(along, dims)

	return ApplyOp(op, a)
}

// Mean performs a mean() on the input and the provided axes.
func Mean(a *Node, along ...int) (retVal *Node, err error) {
	if a.IsScalar() {
		// can't mean a scalar... return error
		return a, nil
	}

	dims := a.Dims()

	if len(along) == 0 {
		along = intRange(0, dims)
	}

	var s *Node
	if s, err = Sum(a, along...); err != nil {
		return nil, errors.Wrap(err, operationError)
	}

	sizes := make(Nodes, len(along))
	for i, axis := range along {
		if sizes[i], err = SizeOf(axis, a); err != nil {
			return nil, errors.Wrap(err, operationError)
		}
	}

	var counts *Node
	if counts, err = ReduceMul(sizes); err == nil {
		return HadamardDiv(s, counts)
	}
	return nil, errors.Wrap(err, operationError)
}

// Sum performs a sum() on the input and the provided axes.
func Sum(a *Node, along ...int) (retVal *Node, err error) {
	if a.IsScalar() {
		retVal = a // or error?
		return
	}

	dims := a.Dims()
	if len(along) == 0 {
		along = intRange(0, dims)
	}

	op := newSumOp(along, a.shape, dims)
	return ApplyOp(op, a)
}
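
// An illustrative sketch (not part of the original file) of the reductions above. With no
// axes supplied, Sum, Mean and Max reduce over every axis; otherwise they reduce only
// along the axes given. The graph and shapes below are invented for illustration.
func exampleReductions() {
	g := NewGraph()
	X := NewMatrix(g, Float64, WithShape(3, 4), WithName("X"))

	_ = Must(Sum(X))     // no axes: reduce over every axis
	_ = Must(Sum(X, 0))  // reduce along axis 0: one sum per column
	_ = Must(Mean(X, 1)) // Sum along axis 1 divided by SizeOf(1, X), i.e. a row mean
	_ = Must(Max(X, 0))  // column-wise maximum
}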

// Norm returns the p-norm of a Value. Use p=2 if you want to use unordered norms.
//
// This is a simpler version of the norms found in the Tensor package, which specializes and optimizes even more
// (well, given it's adapted from Numpy, it is clearly way more optimized)
func Norm(a *Node, axis, p int) (retVal *Node, err error) {
	if p == 2 {
		if retVal, err = Square(a); err == nil {
			if retVal, err = Sum(retVal, axis); err == nil {
				if retVal, err = Sqrt(retVal); err != nil {
					return nil, errors.Wrap(err, operationError)
				}
			} else {
				return nil, errors.Wrap(err, operationError)
			}
		} else {
			return nil, errors.Wrap(err, operationError)
		}
		return
	}

	var dt tensor.Dtype
	if dt, err = dtypeOf(a.t); err != nil {
		return nil, errors.Wrapf(err, "Failed to determine the dtype of %T", a.t)
	}

	var b, inv *Node
	switch dt {
	case Float32:
		b = NewConstant(float32(p))
		inv = NewConstant(float32(1) / float32(p))
	case Float64:
		b = NewConstant(float64(p))
		inv = NewConstant(float64(1) / float64(p))
	default:
		return nil, errors.New("Cannot norm a non-floating point type")
	}

	if retVal, err = Pow(a, b); err == nil {
		if retVal, err = Sum(retVal, axis); err == nil {
			if retVal, err = Pow(retVal, inv); err != nil {
				return nil, errors.Wrap(err, operationError)
			}
		} else {
			return nil, errors.Wrap(err, operationError)
		}
	} else {
		return nil, errors.Wrap(err, operationError)
	}
	return
}

// Reduction

// ReduceAdd takes a slice of *Nodes, and folds them into one by adding
func ReduceAdd(nodes Nodes, opts ...NodeConsOpt) (retVal *Node, err error) {
	switch len(nodes) {
	case 0:
		return nil, nil // or error?
	case 1:
		return nodes[0], nil
	case 2:
		if retVal, err = Add(nodes[0], nodes[1]); err == nil {
			for _, opt := range opts {
				opt(retVal)
			}
		} else {
			return nil, errors.Wrap(err, operationError)
		}
		return
	}

	retVal = nodes[0]
	for i, n := range nodes {
		if i == 0 {
			continue
		}

		if retVal, err = Add(retVal, n); err != nil {
			err = errors.Wrap(err, operationError)
			return
		}
		for _, opt := range opts {
			opt(retVal)
		}
	}
	return
}

// ReduceMul is like foldl(*, nodes)
func ReduceMul(nodes Nodes, opts ...NodeConsOpt) (retVal *Node, err error) {
	switch len(nodes) {
	case 0:
		return nil, nil // or error?
	case 1:
		return nodes[0], nil
	case 2:
		if retVal, err = Mul(nodes[0], nodes[1]); err == nil {
			for _, opt := range opts {
				opt(retVal)
			}
		} else {
			return nil, errors.Wrap(err, operationError)
		}
		return
	}

	retVal = nodes[0]
	for i, n := range nodes {
		if i == 0 {
			continue
		}

		if retVal, err = Mul(retVal, n); err != nil {
			return nil, errors.Wrap(err, operationError)
		}
		for _, opt := range opts {
			opt(retVal)
		}
	}
	return
}

/* Shape related operations */

// SizeOf returns the size of a value along an axis
func SizeOf(axis int, x *Node) (retVal *Node, err error) {
	op := sizeOp{
		axis: axis,
		d:    x.Dims(),
	}

	// if the shape is known
	if x.shape != nil {
		op.val = x.shape[axis]
	}

	return ApplyOp(op, x)
}

// Slice slices a *Node. For T[:] slices, pass in nil. Will error out if node's type is not a Tensor
func Slice(n *Node, slices ...tensor.Slice) (retVal *Node, err error) {
	if _, ok := n.t.(TensorType); !ok {
		return nil, errors.Errorf("Cannot slice on a non-Tensor type. Got %T", n.t)
	}

	if len(slices) > n.shape.Dims() {
		return nil, errors.Errorf("Cannot slice %v. Shape: %v. Slices: %d", n, n.shape, len(slices))
	}

	retVal = n
	var dimsChanged int
	for i, s := range slices {
		var along int
		if i > 0 {
			if prev := slices[i-1]; prev != nil {
				if prev.End()-prev.Start() == 1 {
					dimsChanged++
				}
			}
		}
		along = i - dimsChanged

		op := newSliceOp(s, along, retVal.Dims())
		if retVal, err = ApplyOp(op, retVal); err != nil {
			return
		}
	}
	return
}

// Transpose performs a transpose on the input and provided permutation axes.
func Transpose(n *Node, axes ...int) (retVal *Node, err error) {
	// prep axes
	if len(axes) > 0 && len(axes) != n.Dims() {
		return nil, errors.Errorf("n has %d dims, while requested transposes is %d", n.Dims(), len(axes))
	}
	dims := len(n.shape)
	if len(axes) == 0 || axes == nil {
		axes = make([]int, dims)
		for i := 0; i < dims; i++ {
			axes[i] = dims - 1 - i
		}
	}

	// if axes is 0, 1, 2, 3... then no op
	if monotonic, incr1 := tensor.IsMonotonicInts(axes); monotonic && incr1 && axes[0] == 0 {
		retVal = n
		return
	}
	op := transposeOp{
		pattern: axes,
		d:       len(axes),
	}

	return ApplyOp(op, n)
}
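
// An illustrative sketch (not part of the original file) of Slice and Transpose. S is
// the package's slice constructor (also used by Unconcat below); nil stands for a full
// "[:]" slice. The graph, names and shapes are invented for illustration, and the
// commented result shapes are what the slicing rules above suggest, not guarantees.
func exampleSliceTranspose() {
	g := NewGraph()
	T := NewTensor(g, Float64, 3, WithShape(2, 3, 4), WithName("T"))

	_ = Must(Slice(T, S(0)))         // T[0]: a width-1 slice drops the first axis, shape (3, 4)
	_ = Must(Slice(T, nil, S(1, 3))) // T[:, 1:3]: shape (2, 2, 4)
	_ = Must(Transpose(T))           // no axes: reverse them all, shape (4, 3, 2)
	_ = Must(Transpose(T, 0, 2, 1))  // explicit permutation, shape (2, 4, 3)
}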

// Concat performs a concatenation along the provided axis on the inputs.
func Concat(axis int, ns ...*Node) (retVal *Node, err error) {
	// check that all the nodes have the same number of dimensions
	var d int
	for i, n := range ns {
		if i == 0 {
			d = n.shape.Dims()
			continue
		}

		if n.shape.Dims() != d {
			err = errors.Errorf("Dimension mismatch. Expected all the nodes to be concatenated to have %d dimensions. Got %d instead", d, n.shape.Dims())
			return
		}
	}

	if d == 0 {
		err = errors.Errorf("Concat only works on Tensor nodes")
		return
	}

	if axis >= d {
		err = errors.Errorf("Invalid axis. Nodes have %d dimensions. Axis is %d", d, axis)
		return
	}

	op := concatOp{axis: axis, d: d, children: len(ns)}
	return ApplyOp(op, ns...)
}

// Unconcat is the opposite of the built-in Concat function
// TODO: port this back to Gorgonia and use Gorgonia's sli instead
func Unconcat(a *Node, along int, n int) (Nodes, error) {
	aShape := a.Shape()
	if along < 0 || along > aShape.Dims() {
		return nil, errors.Errorf("Unable to Unconcat a of shape %v along axis %d", aShape, along)
	}

	if aShape[along]%n != 0 {
		return nil, errors.Errorf("Axis %d of %v cannot be nicely split into %d parts", along, aShape, n)
	}

	newShapeAlong := aShape[along] / n
	batches := aShape[along] / newShapeAlong

	var start int
	var retVal Nodes
	for i := 0; i < batches; i++ {
		ss := make([]tensor.Slice, len(aShape))
		for i := range ss {
			if i == along {
				ss[i] = S(start, start+newShapeAlong)
			} else {
				ss[i] = S(0, aShape[i])
			}
		}

		a2, err := Slice(a, ss...)
		if err != nil {
			return nil, errors.Wrapf(err, "Unable to slice a of shape %v along %d on batch %d. Slices were: %v", aShape, along, i, ss)
		}
		retVal = append(retVal, a2)
		start += newShapeAlong
	}
	return retVal, nil
}

// Reshape reshapes a node and returns a new node with the new shape
func Reshape(n *Node, to tensor.Shape) (retVal *Node, err error) {
	// check shape
	var negs int
	var infer int
	for i, s := range to {
		if s < 0 {
			negs++
			infer = i
		}
	}
	if negs > 1 {
		return nil, errors.Errorf("Unfortunately, inference of reshape parameters only allows for one variable (a negative number). Got %v instead", to)
	}

	if negs == 1 {
		prod := 1
		for i, s := range to {
			if i == infer {
				continue
			}
			prod *= s
		}
		inferred, rem := divmod(n.Shape().TotalSize(), prod)
		if rem != 0 {
			return nil, errors.Errorf("Cannot reshape %v to %v", n.Shape(), to)
		}
		to[infer] = inferred
	}

	// the Node n might not have a shape at this point; in that case we skip the check
	if n.Shape().Dims() > 0 && n.Shape().TotalSize() != to.TotalSize() {
		return nil, errors.Errorf("shape size doesn't match. Expected %v, got %v", n.Shape().TotalSize(), to.TotalSize())
	}

	op := reshapeOp{
		from: n.Shape(),
		to:   to,
	}
	return ApplyOp(op, n)
}

// Ravel flattens the given node and returns the new node
func Ravel(n *Node) (retVal *Node, err error) {
	return Reshape(n, tensor.Shape{n.shape.TotalSize()})
}
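
// An illustrative sketch (not part of the original file) tying together Concat,
// Unconcat, Reshape and Ravel. A single negative entry in the target shape asks Reshape
// to infer that dimension. The graph, names and shapes are invented for illustration.
func exampleShapeOps() error {
	g := NewGraph()
	a := NewMatrix(g, Float64, WithShape(2, 3), WithName("a"))
	b := NewMatrix(g, Float64, WithShape(2, 3), WithName("b"))

	c, err := Concat(1, a, b) // concatenate along axis 1: shape (2, 6)
	if err != nil {
		return err
	}

	parts, err := Unconcat(c, 1, 2) // split axis 1 back into 2 parts of shape (2, 3)
	if err != nil {
		return err
	}
	_ = parts

	r, err := Reshape(c, tensor.Shape{3, -1}) // -1 is inferred as 12/3 = 4, giving shape (3, 4)
	if err != nil {
		return err
	}
	_ = Must(Ravel(r)) // flatten to shape (12)
	return nil
}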

/* Contraction related operations */

// Tensordot performs a tensor contraction of a and b along specified axes.
func Tensordot(aAxes []int, bAxes []int, a, b *Node) (retVal *Node, err error) {

	// Check if input tensors actually have dim ⩾ 1
	if (len(a.Shape()) < 1) || (len(b.Shape()) < 1) || (a.Dims() < 1) || (b.Dims() < 1) {
		return nil, errors.New("Input Node's shape should have length at least 1")
	}

	// Check if number of specified axes for a and b matches
	if len(aAxes) != len(bAxes) {
		return nil, errors.New("Number of Axes supplied along which to contract tensors does not match")
	}

	// Check for duplicate indices
	if containsDuplicate(aAxes) || containsDuplicate(bAxes) {
		return nil, errors.New("Supplied axes to contract along contain duplicates")
	}

	// Check for more compatibility

	aShape := a.Shape()
	bShape := b.Shape()

	for _, aAxis := range aAxes {
		if aAxis >= len(aShape) {
			return nil, errors.New("Supplied higher axes number to contract along than Tensor's actual number of axes")
		}
	}

	for _, bAxis := range bAxes {
		if bAxis >= len(bShape) {
			return nil, errors.New("Supplied higher axes number to contract along than Tensor's actual number of axes")
		}
	}

	for aAxis, aDim := range aAxes {
		if aShape[aDim] != bShape[bAxes[aAxis]] {
			return nil, errors.New("Dimension mismatch: Can't contract tensors along supplied axes")
		}
	}

	// Otherwise, apply contraction
	op := makeTensordotOp(a, b, aAxes, bAxes)

	return ApplyOp(op, a, b)
}

// Mish is a novel activation function that is self-regularizing.
//
// https://arxiv.org/abs/1908.08681
func Mish(a *Node) (retVal *Node, err error) {
	var sp, tsp *Node
	if sp, err = Softplus(a); err != nil {
		return nil, errors.Wrap(err, "Mish() - SoftPlus failed")
	}
	if tsp, err = Tanh(sp); err != nil {
		return nil, errors.Wrap(err, "Mish() - Tanh failed")
	}
	return HadamardProd(a, tsp)
}

// Private functions

func containsDuplicate(slice []int) bool {
	if nil == slice {
		return false
	}

	for index1, value1 := range slice {
		for index2, value2 := range slice {
			if (value1 == value2) && (index1 != index2) {
				return true
			}
		}
	}

	return false
}
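
// An illustrative sketch (not part of the original file) of Tensordot. Contracting a
// matrix's second axis against another matrix's first axis is ordinary matrix
// multiplication; higher-order contractions follow the same pattern. The graph, names
// and shapes are invented for illustration.
func exampleTensordot() error {
	g := NewGraph()

	A := NewMatrix(g, Float64, WithShape(3, 4), WithName("A"))
	B := NewMatrix(g, Float64, WithShape(4, 5), WithName("B"))

	// Contract axis 1 of A with axis 0 of B: equivalent to Mul(A, B), shape (3, 5).
	AB, err := Tensordot([]int{1}, []int{0}, A, B)
	if err != nil {
		return err
	}
	_ = AB

	// A rank-3 example: contract axis 2 of T with axis 0 of B, giving shape (2, 3, 5).
	T := NewTensor(g, Float64, 3, WithShape(2, 3, 4), WithName("T"))
	TB, err := Tensordot([]int{2}, []int{0}, T, B)
	if err != nil {
		return err
	}
	_ = TB
	return nil
}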