gorgonia.org/tensor@v0.9.24/shape.go (about) 1 package tensor 2 3 import ( 4 "fmt" 5 6 "github.com/pkg/errors" 7 ) 8 9 var scalarShape = Shape{} 10 11 // ScalarShape represents a scalar. It has no dimensions, no sizes 12 func ScalarShape() Shape { return scalarShape } 13 14 // Shape represents the dimensions of a Tensor. A (2,3) matrix has a shape of (2,3) - 2 rows, 3 columns. 15 // Likewise, a shape of (2,3,4) means a Tensor has 3 dimensions: 2 layers, 3 rows, 4 columns. 16 // 17 // Vectors are of particular note. This package defines a shape of (x, 1) as a column vector and 18 // a (1, x) as a row vector. Row vectors and column vectors are matrices as well. It is important to note that 19 // row and column vectors and vanilla vectors are comparable under some circumstances 20 type Shape []int 21 22 // TotalSize returns the number of elements expected in a Tensor of a certain shape 23 func (s Shape) TotalSize() int { 24 return ProdInts([]int(s)) 25 } 26 27 // CalcStrides calculates the default strides for a shape 28 func (s Shape) CalcStrides() []int { 29 if s.IsScalar() { 30 return nil 31 } 32 33 retVal := BorrowInts(len(s)) 34 // if s.IsVector() { 35 // retVal[0] = 1 36 // retVal = retVal[:1] 37 // return retVal 38 // } 39 40 acc := 1 41 for i := len(s) - 1; i >= 0; i-- { 42 retVal[i] = acc 43 d := s[i] 44 if d < 0 { 45 panic("negative dimension size does not make sense") 46 } 47 acc *= d 48 } 49 return retVal 50 } 51 52 // CalcStridesWithMask is similar to CalcStrides, except that it has an argument, masks. It is used to mask out given dimensions 53 // during calculation of stride 54 func (s Shape) CalcStridesWithMask(mask []bool) []int { 55 if s.IsScalarEquiv() { 56 return nil 57 } 58 59 retVal := BorrowInts(len(s)) 60 if s.IsVector() { 61 retVal[0] = 1 62 retVal = retVal[:1] 63 return retVal 64 } 65 66 if len(mask) != s.Dims() { 67 panic("mask length must be equal to number of shape dimensions") 68 } 69 acc := 1 70 for i := len(s) - 1; i >= 0; i-- { 71 if mask[i] { 72 retVal[i] = acc 73 } else { 74 retVal[i] = 0 75 } 76 d := s[i] 77 if d < 0 { 78 panic("negative dimension size does not make sense") 79 } 80 if mask[i] { 81 acc *= d 82 } 83 } 84 85 return retVal 86 } 87 88 // CalcStridesColMajor is like CalcStrides, but assumes a col major layout 89 func (s Shape) CalcStridesColMajor() []int { 90 if s.IsScalarEquiv() { 91 return nil 92 } 93 94 retVal := BorrowInts(len(s)) 95 if s.IsVector() { 96 retVal[0] = 1 97 retVal = retVal[:1] 98 return retVal 99 } 100 101 acc := 1 102 for i := 0; i < len(s); i++ { 103 retVal[i] = acc 104 d := s[i] 105 if d < 0 { 106 panic("negative dimension size does not make sense") 107 } 108 acc *= d 109 } 110 return retVal 111 } 112 113 // Eq indicates if a shape is equal with another. There is a soft concept of equality when it comes to vectors. 114 // 115 // If s is a column vector and other is a vanilla vector, they're considered equal if the size of the column dimension is the same as the vector size; 116 // if s is a row vector and other is a vanilla vector, they're considered equal if the size of the row dimension is the same as the vector size 117 func (s Shape) Eq(other Shape) bool { 118 if s.IsScalar() && other.IsScalar() { 119 return true 120 } 121 122 if s.IsVector() && other.IsVector() { 123 switch { 124 case len(s) == 2 && len(other) == 1: 125 if (s.IsColVec() && s[0] == other[0]) || (s.IsRowVec() && s[1] == other[0]) { 126 return true 127 } 128 return false 129 case len(s) == 1 && len(other) == 2: 130 if (other.IsColVec() && other[0] == s[0]) || (other.IsRowVec() && other[1] == s[0]) { 131 return true 132 } 133 return false 134 } 135 } 136 137 if len(s) != len(other) { 138 return false 139 } 140 141 for i, v := range s { 142 if other[i] != v { 143 return false 144 } 145 } 146 return true 147 } 148 149 // Clone clones a shape. 150 func (s Shape) Clone() Shape { 151 retVal := BorrowInts(len(s)) 152 copy(retVal, s) 153 return retVal 154 } 155 156 // IsScalar returns true if the access pattern indicates it's a scalar value 157 func (s Shape) IsScalar() bool { 158 return len(s) == 0 159 } 160 161 // IsScalarEquiv returns true if the access pattern indicates it's a scalar-like value 162 func (s Shape) IsScalarEquiv() bool { 163 if len(s) == 0 { 164 return true 165 } 166 isEquiv := true 167 for i := range s { 168 if s[i] != 1 { 169 return false 170 } 171 } 172 return isEquiv 173 } 174 175 // IsVector returns whether the access pattern falls into one of three possible definitions of vectors: 176 // vanilla vector (not a row or a col) 177 // column vector 178 // row vector 179 func (s Shape) IsVector() bool { return s.IsColVec() || s.IsRowVec() || (len(s) == 1) } 180 181 // IsColVec returns true when the access pattern has the shape (x, 1) 182 func (s Shape) IsColVec() bool { return len(s) == 2 && (s[1] == 1 && s[0] > 1) } 183 184 // IsRowVec returns true when the access pattern has the shape (1, x) 185 func (s Shape) IsRowVec() bool { return len(s) == 2 && (s[0] == 1 && s[1] > 1) } 186 187 // IsVectorLike returns true when the shape looks like a vector 188 // e.g. a number that is surrounded by 1s: 189 // (1, 1, ... 1, 10, 1, 1... 1) 190 func (s Shape) IsVectorLike() bool { 191 var nonOnes int 192 for _, i := range s { 193 if i != 1 { 194 nonOnes++ 195 } 196 } 197 return nonOnes == 1 || nonOnes == 0 // if there is only one non-one then it's a vector or a scalarlike. 198 } 199 200 // IsMatrix returns true if it's a matrix. This is mostly a convenience method. RowVec and ColVecs are also considered matrices 201 func (s Shape) IsMatrix() bool { return len(s) == 2 } 202 203 // Dims returns the number of dimensions in the shape 204 func (s Shape) Dims() int { return len(s) } 205 206 // DimSize returns the size of the dimension wanted. 207 // 208 // This method implemnents the DimSizer interface in Gorgonia. 209 func (s Shape) DimSize(d int) (size int, err error) { 210 if (s.IsScalar() && d != 0) || (!s.IsScalar() && d >= len(s)) { 211 err = errors.Errorf(dimMismatch, len(s), d) 212 return 213 } 214 215 switch { 216 case s.IsScalar(): 217 return 0, nil 218 default: 219 return s[d], nil 220 } 221 } 222 223 // S gives the new shape after a shape has been sliced. It's repeated from the AP S() method mainly because there are other functions in Gorgonia that uses only shape 224 func (s Shape) S(slices ...Slice) (retVal Shape, err error) { 225 opDims := len(s) 226 if len(slices) > opDims { 227 err = errors.Errorf(dimMismatch, opDims, len(slices)) 228 return 229 } 230 231 retVal = s.Clone() 232 233 for d, size := range s { 234 var sl Slice // default is a nil Slice 235 if d <= len(slices)-1 { 236 sl = slices[d] 237 } 238 239 var start, end, step int 240 if start, end, step, err = SliceDetails(sl, size); err != nil { 241 return 242 } 243 244 if step > 0 { 245 retVal[d] = (end - start) / step 246 247 //fix 248 if retVal[d] <= 0 { 249 retVal[d] = 1 250 } 251 } else { 252 retVal[d] = (end - start) 253 } 254 255 } 256 257 // drop any dimension with size 1, except the last dimension 258 offset := 0 259 dims := s.Dims() 260 for d := 0; d < dims; d++ { 261 if retVal[d] == 1 && offset+d <= len(slices)-1 && slices[offset+d] != nil /*&& d != t.dims-1 && dims > 2*/ { 262 retVal = append(retVal[:d], retVal[d+1:]...) 263 d-- 264 dims-- 265 offset++ 266 } 267 } 268 269 if retVal.IsScalar() { 270 ReturnInts(retVal) 271 return ScalarShape(), nil 272 } 273 274 return 275 } 276 277 // Repeat returns the expected new shape given the repetition parameters. 278 func (s Shape) Repeat(axis int, repeats ...int) (newShape Shape, finalRepeats []int, size int, err error) { 279 switch { 280 case axis == AllAxes: 281 size = s.TotalSize() 282 newShape = Shape{size} 283 axis = 0 284 case s.IsScalar(): 285 size = 1 286 // special case for row vecs 287 if axis == 1 { 288 newShape = Shape{1, 0} 289 } else { 290 // otherwise it will be repeated into a vanilla vector 291 newShape = Shape{0} 292 } 293 case s.IsVector() && !s.IsRowVec() && !s.IsColVec() && axis == 1: 294 size = 1 295 newShape = s.Clone() 296 newShape = append(newShape, 1) 297 default: 298 if axis >= len(s) { 299 // error 300 err = errors.Errorf(invalidAxis, axis, s.Dims()) 301 return 302 } 303 size = s[axis] 304 newShape = s.Clone() 305 } 306 307 // special case to allow generic repeats 308 if len(repeats) == 1 { 309 rep := repeats[0] 310 repeats = make([]int, size) 311 for i := range repeats { 312 repeats[i] = rep 313 } 314 } 315 reps := len(repeats) 316 if reps != size { 317 err = errors.Errorf(broadcastError, size, reps) 318 return 319 } 320 321 newSize := SumInts(repeats) 322 newShape[axis] = newSize 323 finalRepeats = repeats 324 return 325 } 326 327 // Concat returns the expected new shape given the concatenation parameters 328 func (s Shape) Concat(axis int, ss ...Shape) (newShape Shape, err error) { 329 dims := s.Dims() 330 331 // check that all the concatenates have the same dimensions 332 for _, shp := range ss { 333 if shp.Dims() != dims { 334 err = errors.Errorf(dimMismatch, dims, shp.Dims()) 335 return 336 } 337 } 338 339 // special case 340 if axis == AllAxes { 341 axis = 0 342 } 343 344 // nope... no negative indexing here. 345 if axis < 0 { 346 err = errors.Errorf(invalidAxis, axis, len(s)) 347 return 348 } 349 350 if axis >= dims { 351 err = errors.Errorf(invalidAxis, axis, len(s)) 352 return 353 } 354 355 newShape = Shape(BorrowInts(dims)) 356 copy(newShape, s) 357 358 for _, shp := range ss { 359 for d := 0; d < dims; d++ { 360 if d == axis { 361 newShape[d] += shp[d] 362 } else { 363 // validate that the rest of the dimensions match up 364 if newShape[d] != shp[d] { 365 err = errors.Wrapf(errors.Errorf(dimMismatch, newShape[d], shp[d]), "Axis: %d, dimension it failed at: %d", axis, d) 366 return 367 } 368 } 369 } 370 } 371 return 372 } 373 374 // Format implements fmt.Formatter, and formats a shape nicely 375 func (s Shape) Format(st fmt.State, r rune) { 376 switch r { 377 case 'v', 's': 378 st.Write([]byte("(")) 379 for i, v := range s { 380 fmt.Fprintf(st, "%d", v) 381 if i < len(s)-1 { 382 st.Write([]byte(", ")) 383 } 384 } 385 st.Write([]byte(")")) 386 default: 387 fmt.Fprintf(st, "%v", []int(s)) 388 } 389 }