gorgonia.org/gorgonia@v0.9.17/shape.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"github.com/pkg/errors"
     5  	"gorgonia.org/tensor"
     6  )
     7  
     8  var scalarShape = tensor.ScalarShape()
     9  
    10  type axes []int
    11  type coordinates []int
    12  
    13  // only works for 2D
    14  func transpose2D(shape tensor.Shape) tensor.Shape {
    15  	if len(shape) != 2 {
    16  		return shape
    17  	}
    18  	retVal := tensor.BorrowInts(2)
    19  	retVal[0] = shape[1]
    20  	retVal[1] = shape[0]
    21  	return retVal
    22  }
    23  
    24  // for batched matmul
    25  func transposeBatch2D(shape tensor.Shape) tensor.Shape {
    26  	if len(shape) != 3 {
    27  		return shape
    28  	}
    29  	retVal := tensor.BorrowInts(3)
    30  	retVal[0] = shape[0]
    31  	retVal[1] = shape[2]
    32  	retVal[2] = shape[1]
    33  	return retVal
    34  }
    35  
    36  // calcBroadcastShape calculates the new shape of a given Node and broadcast axes.
    37  // Note that `a` will be the *Node reshaped to the newShape.
    38  func calcBroadcastShape(a *Node, expectedDims int, broadcastAlong []int) (newShape tensor.Shape) {
    39  	shp := a.Shape()
    40  	if shp.Dims() == expectedDims {
    41  		newShape = shp.Clone()
    42  	} else {
    43  		newShape = make(tensor.Shape, expectedDims)
    44  		for _, i := range broadcastAlong {
    45  			newShape[i] = 1
    46  		}
    47  	}
    48  
    49  	switch {
    50  	case a.Shape().Eq(tensor.ScalarShape()):
    51  		for i := range newShape {
    52  			newShape[i] = 1
    53  		}
    54  	case shp.Dims() == expectedDims:
    55  	default:
    56  		for _, s := range a.Shape() {
    57  			// search for first non 0
    58  			for j := range newShape {
    59  				if newShape[j] == 0 {
    60  					newShape[j] = s
    61  					break
    62  				}
    63  			}
    64  		}
    65  	}
    66  
    67  	return
    68  }
    69  
    70  // KeepDims is a function that ensures that input and output dimensions are the same though the shape may change.
    71  //
    72  // The expandLeft flag in the function indicates if any shape expansion should be done leftwards or rightwards.
    73  // For example, if fn() returns a tensor with a shape (3) and the desired dimension is 2,
    74  // then if `expandLeft` is true the result will be `(1, 3)`. Otherwise the result will be `(3, 1)`.
    75  //
    76  // At the moment, results that turn into scalars cannot have their dimensions kept - the semantics isn't well established yet and is a work in progress.
    77  func KeepDims(a *Node, expandLeft bool, fn func(a *Node) (*Node, error)) (*Node, error) {
    78  	oshape := a.Shape()
    79  	adims := oshape.Dims()
    80  	b, err := fn(a)
    81  	if err != nil {
    82  		return nil, err
    83  	}
    84  
    85  	// happy path = quick exit
    86  	newShape := b.Shape()
    87  	if newShape.Eq(oshape) {
    88  		return b, nil
    89  	}
    90  
    91  	bdims := newShape.Dims()
    92  	diff := adims - bdims
    93  	if diff < 0 {
    94  		return b, errors.Errorf("Unable to KeepDims for a result with shape %v. It has more dimensions than input %v", newShape, oshape)
    95  	}
    96  	var retShape tensor.Shape
    97  	if expandLeft {
    98  		retShape = tensor.BorrowInts(diff + newShape.Dims())
    99  		for i := 0; i < diff; i++ {
   100  			retShape[i] = 1
   101  		}
   102  		copy(retShape[diff:], newShape)
   103  	} else {
   104  		retShape = newShape.Clone()
   105  		for i := 0; i < diff; i++ {
   106  			retShape = append(retShape, 1)
   107  		}
   108  
   109  	}
   110  	return Reshape(b, retShape)
   111  }