gorgonia.org/gorgonia@v0.9.17/shape.go (about) 1 package gorgonia 2 3 import ( 4 "github.com/pkg/errors" 5 "gorgonia.org/tensor" 6 ) 7 8 var scalarShape = tensor.ScalarShape() 9 10 type axes []int 11 type coordinates []int 12 13 // only works for 2D 14 func transpose2D(shape tensor.Shape) tensor.Shape { 15 if len(shape) != 2 { 16 return shape 17 } 18 retVal := tensor.BorrowInts(2) 19 retVal[0] = shape[1] 20 retVal[1] = shape[0] 21 return retVal 22 } 23 24 // for batched matmul 25 func transposeBatch2D(shape tensor.Shape) tensor.Shape { 26 if len(shape) != 3 { 27 return shape 28 } 29 retVal := tensor.BorrowInts(3) 30 retVal[0] = shape[0] 31 retVal[1] = shape[2] 32 retVal[2] = shape[1] 33 return retVal 34 } 35 36 // calcBroadcastShape calculates the new shape of a given Node and broadcast axes. 37 // Note that `a` will be the *Node reshaped to the newShape. 38 func calcBroadcastShape(a *Node, expectedDims int, broadcastAlong []int) (newShape tensor.Shape) { 39 shp := a.Shape() 40 if shp.Dims() == expectedDims { 41 newShape = shp.Clone() 42 } else { 43 newShape = make(tensor.Shape, expectedDims) 44 for _, i := range broadcastAlong { 45 newShape[i] = 1 46 } 47 } 48 49 switch { 50 case a.Shape().Eq(tensor.ScalarShape()): 51 for i := range newShape { 52 newShape[i] = 1 53 } 54 case shp.Dims() == expectedDims: 55 default: 56 for _, s := range a.Shape() { 57 // search for first non 0 58 for j := range newShape { 59 if newShape[j] == 0 { 60 newShape[j] = s 61 break 62 } 63 } 64 } 65 } 66 67 return 68 } 69 70 // KeepDims is a function that ensures that input and output dimensions are the same though the shape may change. 71 // 72 // The expandLeft flag in the function indicates if any shape expansion should be done leftwards or rightwards. 73 // For example, if fn() returns a tensor with a shape (3) and the desired dimension is 2, 74 // then if `expandLeft` is true the result will be `(1, 3)`. Otherwise the result will be `(3, 1)`. 75 // 76 // At the moment, results that turn into scalars cannot have their dimensions kept - the semantics isn't well established yet and is a work in progress. 77 func KeepDims(a *Node, expandLeft bool, fn func(a *Node) (*Node, error)) (*Node, error) { 78 oshape := a.Shape() 79 adims := oshape.Dims() 80 b, err := fn(a) 81 if err != nil { 82 return nil, err 83 } 84 85 // happy path = quick exit 86 newShape := b.Shape() 87 if newShape.Eq(oshape) { 88 return b, nil 89 } 90 91 bdims := newShape.Dims() 92 diff := adims - bdims 93 if diff < 0 { 94 return b, errors.Errorf("Unable to KeepDims for a result with shape %v. It has more dimensions than input %v", newShape, oshape) 95 } 96 var retShape tensor.Shape 97 if expandLeft { 98 retShape = tensor.BorrowInts(diff + newShape.Dims()) 99 for i := 0; i < diff; i++ { 100 retShape[i] = 1 101 } 102 copy(retShape[diff:], newShape) 103 } else { 104 retShape = newShape.Clone() 105 for i := 0; i < diff; i++ { 106 retShape = append(retShape, 1) 107 } 108 109 } 110 return Reshape(b, retShape) 111 }