gorgonia.org/tensor@v0.9.24/shape.go (about)

     1  package tensor
     2  
     3  import (
     4  	"fmt"
     5  
     6  	"github.com/pkg/errors"
     7  )
     8  
     9  var scalarShape = Shape{}
    10  
    11  // ScalarShape represents a scalar. It has no dimensions, no sizes
    12  func ScalarShape() Shape { return scalarShape }
    13  
    14  // Shape represents the dimensions of a Tensor. A (2,3) matrix has a shape of (2,3) - 2 rows, 3 columns.
    15  // Likewise, a shape of (2,3,4) means a Tensor has 3 dimensions: 2 layers, 3 rows, 4 columns.
    16  //
    17  // Vectors are of particular note. This package defines a shape of (x, 1) as a column vector and
    18  // a (1, x) as a row vector. Row vectors and column vectors are matrices as well. It is important to note that
    19  // row and column vectors and vanilla vectors are comparable under some circumstances
    20  type Shape []int
    21  
    22  // TotalSize returns the number of elements expected in a Tensor of a certain shape
    23  func (s Shape) TotalSize() int {
    24  	return ProdInts([]int(s))
    25  }
    26  
    27  // CalcStrides calculates the default strides for a shape
    28  func (s Shape) CalcStrides() []int {
    29  	if s.IsScalar() {
    30  		return nil
    31  	}
    32  
    33  	retVal := BorrowInts(len(s))
    34  	// if s.IsVector() {
    35  	// 	retVal[0] = 1
    36  	// 	retVal = retVal[:1]
    37  	// 	return retVal
    38  	// }
    39  
    40  	acc := 1
    41  	for i := len(s) - 1; i >= 0; i-- {
    42  		retVal[i] = acc
    43  		d := s[i]
    44  		if d < 0 {
    45  			panic("negative dimension size does not make sense")
    46  		}
    47  		acc *= d
    48  	}
    49  	return retVal
    50  }
    51  
    52  // CalcStridesWithMask is similar to CalcStrides, except that it has an argument, masks. It is used to mask out given dimensions
    53  // during calculation of stride
    54  func (s Shape) CalcStridesWithMask(mask []bool) []int {
    55  	if s.IsScalarEquiv() {
    56  		return nil
    57  	}
    58  
    59  	retVal := BorrowInts(len(s))
    60  	if s.IsVector() {
    61  		retVal[0] = 1
    62  		retVal = retVal[:1]
    63  		return retVal
    64  	}
    65  
    66  	if len(mask) != s.Dims() {
    67  		panic("mask length must be equal to number of shape dimensions")
    68  	}
    69  	acc := 1
    70  	for i := len(s) - 1; i >= 0; i-- {
    71  		if mask[i] {
    72  			retVal[i] = acc
    73  		} else {
    74  			retVal[i] = 0
    75  		}
    76  		d := s[i]
    77  		if d < 0 {
    78  			panic("negative dimension size does not make sense")
    79  		}
    80  		if mask[i] {
    81  			acc *= d
    82  		}
    83  	}
    84  
    85  	return retVal
    86  }
    87  
    88  // CalcStridesColMajor is like CalcStrides, but assumes a col major layout
    89  func (s Shape) CalcStridesColMajor() []int {
    90  	if s.IsScalarEquiv() {
    91  		return nil
    92  	}
    93  
    94  	retVal := BorrowInts(len(s))
    95  	if s.IsVector() {
    96  		retVal[0] = 1
    97  		retVal = retVal[:1]
    98  		return retVal
    99  	}
   100  
   101  	acc := 1
   102  	for i := 0; i < len(s); i++ {
   103  		retVal[i] = acc
   104  		d := s[i]
   105  		if d < 0 {
   106  			panic("negative dimension size does not make sense")
   107  		}
   108  		acc *= d
   109  	}
   110  	return retVal
   111  }
   112  
   113  // Eq indicates if a shape is equal with another. There is a soft concept of equality when it comes to vectors.
   114  //
   115  // If s is a column vector and other is a vanilla vector, they're considered equal if the size of the column dimension is the same as the vector size;
   116  // if s is a row vector and other is a vanilla vector, they're considered equal if the size of the row dimension is the same as the vector size
   117  func (s Shape) Eq(other Shape) bool {
   118  	if s.IsScalar() && other.IsScalar() {
   119  		return true
   120  	}
   121  
   122  	if s.IsVector() && other.IsVector() {
   123  		switch {
   124  		case len(s) == 2 && len(other) == 1:
   125  			if (s.IsColVec() && s[0] == other[0]) || (s.IsRowVec() && s[1] == other[0]) {
   126  				return true
   127  			}
   128  			return false
   129  		case len(s) == 1 && len(other) == 2:
   130  			if (other.IsColVec() && other[0] == s[0]) || (other.IsRowVec() && other[1] == s[0]) {
   131  				return true
   132  			}
   133  			return false
   134  		}
   135  	}
   136  
   137  	if len(s) != len(other) {
   138  		return false
   139  	}
   140  
   141  	for i, v := range s {
   142  		if other[i] != v {
   143  			return false
   144  		}
   145  	}
   146  	return true
   147  }
   148  
   149  // Clone clones a shape.
   150  func (s Shape) Clone() Shape {
   151  	retVal := BorrowInts(len(s))
   152  	copy(retVal, s)
   153  	return retVal
   154  }
   155  
   156  // IsScalar returns true if the access pattern indicates it's a scalar value
   157  func (s Shape) IsScalar() bool {
   158  	return len(s) == 0
   159  }
   160  
   161  // IsScalarEquiv returns true if the access pattern indicates it's a scalar-like value
   162  func (s Shape) IsScalarEquiv() bool {
   163  	if len(s) == 0 {
   164  		return true
   165  	}
   166  	isEquiv := true
   167  	for i := range s {
   168  		if s[i] != 1 {
   169  			return false
   170  		}
   171  	}
   172  	return isEquiv
   173  }
   174  
   175  // IsVector returns whether the access pattern falls into one of three possible definitions of vectors:
   176  //		vanilla vector (not a row or a col)
   177  //		column vector
   178  //		row vector
   179  func (s Shape) IsVector() bool { return s.IsColVec() || s.IsRowVec() || (len(s) == 1) }
   180  
   181  // IsColVec returns true when the access pattern has the shape (x, 1)
   182  func (s Shape) IsColVec() bool { return len(s) == 2 && (s[1] == 1 && s[0] > 1) }
   183  
   184  // IsRowVec returns true when the access pattern has the shape (1, x)
   185  func (s Shape) IsRowVec() bool { return len(s) == 2 && (s[0] == 1 && s[1] > 1) }
   186  
   187  // IsVectorLike returns true when the shape looks like a vector
   188  // e.g. a number that is surrounded by 1s:
   189  // 	(1, 1, ... 1, 10, 1, 1... 1)
   190  func (s Shape) IsVectorLike() bool {
   191  	var nonOnes int
   192  	for _, i := range s {
   193  		if i != 1 {
   194  			nonOnes++
   195  		}
   196  	}
   197  	return nonOnes == 1 || nonOnes == 0 // if there is only one non-one then it's a vector or a scalarlike.
   198  }
   199  
   200  // IsMatrix returns true if it's a matrix. This is mostly a convenience method. RowVec and ColVecs are also considered matrices
   201  func (s Shape) IsMatrix() bool { return len(s) == 2 }
   202  
   203  // Dims returns the number of dimensions in the shape
   204  func (s Shape) Dims() int { return len(s) }
   205  
   206  // DimSize returns the size of the dimension wanted.
   207  //
   208  // This method implemnents the DimSizer interface in Gorgonia.
   209  func (s Shape) DimSize(d int) (size int, err error) {
   210  	if (s.IsScalar() && d != 0) || (!s.IsScalar() && d >= len(s)) {
   211  		err = errors.Errorf(dimMismatch, len(s), d)
   212  		return
   213  	}
   214  
   215  	switch {
   216  	case s.IsScalar():
   217  		return 0, nil
   218  	default:
   219  		return s[d], nil
   220  	}
   221  }
   222  
   223  // S gives the new shape after a shape has been sliced. It's repeated from the AP S() method mainly because there are other functions in Gorgonia that uses only shape
   224  func (s Shape) S(slices ...Slice) (retVal Shape, err error) {
   225  	opDims := len(s)
   226  	if len(slices) > opDims {
   227  		err = errors.Errorf(dimMismatch, opDims, len(slices))
   228  		return
   229  	}
   230  
   231  	retVal = s.Clone()
   232  
   233  	for d, size := range s {
   234  		var sl Slice // default is a nil Slice
   235  		if d <= len(slices)-1 {
   236  			sl = slices[d]
   237  		}
   238  
   239  		var start, end, step int
   240  		if start, end, step, err = SliceDetails(sl, size); err != nil {
   241  			return
   242  		}
   243  
   244  		if step > 0 {
   245  			retVal[d] = (end - start) / step
   246  
   247  			//fix
   248  			if retVal[d] <= 0 {
   249  				retVal[d] = 1
   250  			}
   251  		} else {
   252  			retVal[d] = (end - start)
   253  		}
   254  
   255  	}
   256  
   257  	// drop any dimension with size 1, except the last dimension
   258  	offset := 0
   259  	dims := s.Dims()
   260  	for d := 0; d < dims; d++ {
   261  		if retVal[d] == 1 && offset+d <= len(slices)-1 && slices[offset+d] != nil /*&& d != t.dims-1  && dims > 2*/ {
   262  			retVal = append(retVal[:d], retVal[d+1:]...)
   263  			d--
   264  			dims--
   265  			offset++
   266  		}
   267  	}
   268  
   269  	if retVal.IsScalar() {
   270  		ReturnInts(retVal)
   271  		return ScalarShape(), nil
   272  	}
   273  
   274  	return
   275  }
   276  
   277  // Repeat returns the expected new shape given the repetition parameters.
   278  func (s Shape) Repeat(axis int, repeats ...int) (newShape Shape, finalRepeats []int, size int, err error) {
   279  	switch {
   280  	case axis == AllAxes:
   281  		size = s.TotalSize()
   282  		newShape = Shape{size}
   283  		axis = 0
   284  	case s.IsScalar():
   285  		size = 1
   286  		// special case for row vecs
   287  		if axis == 1 {
   288  			newShape = Shape{1, 0}
   289  		} else {
   290  			// otherwise it will be repeated into a vanilla vector
   291  			newShape = Shape{0}
   292  		}
   293  	case s.IsVector() && !s.IsRowVec() && !s.IsColVec() && axis == 1:
   294  		size = 1
   295  		newShape = s.Clone()
   296  		newShape = append(newShape, 1)
   297  	default:
   298  		if axis >= len(s) {
   299  			// error
   300  			err = errors.Errorf(invalidAxis, axis, s.Dims())
   301  			return
   302  		}
   303  		size = s[axis]
   304  		newShape = s.Clone()
   305  	}
   306  
   307  	// special case to allow generic repeats
   308  	if len(repeats) == 1 {
   309  		rep := repeats[0]
   310  		repeats = make([]int, size)
   311  		for i := range repeats {
   312  			repeats[i] = rep
   313  		}
   314  	}
   315  	reps := len(repeats)
   316  	if reps != size {
   317  		err = errors.Errorf(broadcastError, size, reps)
   318  		return
   319  	}
   320  
   321  	newSize := SumInts(repeats)
   322  	newShape[axis] = newSize
   323  	finalRepeats = repeats
   324  	return
   325  }
   326  
   327  // Concat returns the expected new shape given the concatenation parameters
   328  func (s Shape) Concat(axis int, ss ...Shape) (newShape Shape, err error) {
   329  	dims := s.Dims()
   330  
   331  	// check that all the concatenates have the same dimensions
   332  	for _, shp := range ss {
   333  		if shp.Dims() != dims {
   334  			err = errors.Errorf(dimMismatch, dims, shp.Dims())
   335  			return
   336  		}
   337  	}
   338  
   339  	// special case
   340  	if axis == AllAxes {
   341  		axis = 0
   342  	}
   343  
   344  	// nope... no negative indexing here.
   345  	if axis < 0 {
   346  		err = errors.Errorf(invalidAxis, axis, len(s))
   347  		return
   348  	}
   349  
   350  	if axis >= dims {
   351  		err = errors.Errorf(invalidAxis, axis, len(s))
   352  		return
   353  	}
   354  
   355  	newShape = Shape(BorrowInts(dims))
   356  	copy(newShape, s)
   357  
   358  	for _, shp := range ss {
   359  		for d := 0; d < dims; d++ {
   360  			if d == axis {
   361  				newShape[d] += shp[d]
   362  			} else {
   363  				// validate that the rest of the dimensions match up
   364  				if newShape[d] != shp[d] {
   365  					err = errors.Wrapf(errors.Errorf(dimMismatch, newShape[d], shp[d]), "Axis: %d, dimension it failed at: %d", axis, d)
   366  					return
   367  				}
   368  			}
   369  		}
   370  	}
   371  	return
   372  }
   373  
   374  // Format implements fmt.Formatter, and formats a shape nicely
   375  func (s Shape) Format(st fmt.State, r rune) {
   376  	switch r {
   377  	case 'v', 's':
   378  		st.Write([]byte("("))
   379  		for i, v := range s {
   380  			fmt.Fprintf(st, "%d", v)
   381  			if i < len(s)-1 {
   382  				st.Write([]byte(", "))
   383  			}
   384  		}
   385  		st.Write([]byte(")"))
   386  	default:
   387  		fmt.Fprintf(st, "%v", []int(s))
   388  	}
   389  }