github.com/wzzhu/tensor@v0.9.24/utils.go (about) 1 package tensor 2 3 import ( 4 "github.com/pkg/errors" 5 ) 6 7 const AllAxes int = -1 8 9 // MinInt returns the lowest between two ints. If both are the same it returns the first 10 func MinInt(a, b int) int { 11 if a <= b { 12 return a 13 } 14 return b 15 } 16 17 // MaxInt returns the highest between two ints. If both are the same, it returns the first 18 func MaxInt(a, b int) int { 19 if a >= b { 20 return a 21 } 22 return b 23 } 24 25 // MaxInts returns the max of a slice of ints. 26 func MaxInts(is ...int) (retVal int) { 27 for _, i := range is { 28 if i > retVal { 29 retVal = i 30 } 31 } 32 return 33 } 34 35 // SumInts sums a slice of ints 36 func SumInts(a []int) (retVal int) { 37 for _, v := range a { 38 retVal += v 39 } 40 return 41 } 42 43 // ProdInts returns the internal product of an int slice 44 func ProdInts(a []int) (retVal int) { 45 retVal = 1 46 if len(a) == 0 { 47 return 48 } 49 for _, v := range a { 50 retVal *= v 51 } 52 return 53 } 54 55 // IsMonotonicInts returns true if the slice of ints is monotonically increasing. It also returns true for incr1 if every succession is a succession of 1 56 func IsMonotonicInts(a []int) (monotonic bool, incr1 bool) { 57 var prev int 58 incr1 = true 59 for i, v := range a { 60 if i == 0 { 61 prev = v 62 continue 63 } 64 65 if v < prev { 66 return false, false 67 } 68 if v != prev+1 { 69 incr1 = false 70 } 71 prev = v 72 } 73 monotonic = true 74 return 75 } 76 77 // Ltoi is Location to Index. Provide a shape, a strides, and a list of integers as coordinates, and returns the index at which the element is. 
78 func Ltoi(shape Shape, strides []int, coords ...int) (at int, err error) { 79 if shape.IsScalarEquiv() { 80 for _, v := range coords { 81 if v != 0 { 82 return -1, errors.Errorf("Scalar shape only allows 0 as an index") 83 } 84 } 85 return 0, nil 86 } 87 for i, coord := range coords { 88 if i >= len(shape) { 89 err = errors.Errorf(dimMismatch, len(shape), i) 90 return 91 } 92 93 size := shape[i] 94 95 if coord >= size { 96 err = errors.Errorf(indexOOBAxis, i, coord, size) 97 return 98 } 99 100 var stride int 101 switch { 102 case shape.IsVector() && len(strides) == 1: 103 stride = strides[0] 104 case i >= len(strides): 105 err = errors.Errorf(dimMismatch, len(strides), i) 106 return 107 default: 108 stride = strides[i] 109 } 110 111 at += stride * coord 112 } 113 return at, nil 114 } 115 116 // Itol is Index to Location. 117 func Itol(i int, shape Shape, strides []int) (coords []int, err error) { 118 dims := len(strides) 119 120 for d := 0; d < dims; d++ { 121 var coord int 122 coord, i = divmod(i, strides[d]) 123 124 if coord >= shape[d] { 125 err = errors.Errorf(indexOOBAxis, d, coord, shape[d]) 126 // return 127 } 128 129 coords = append(coords, coord) 130 } 131 return 132 } 133 134 func UnsafePermute(pattern []int, xs ...[]int) (err error) { 135 if len(xs) == 0 { 136 err = errors.New("Permute requres something to permute") 137 return 138 } 139 140 dims := -1 141 patLen := len(pattern) 142 for _, x := range xs { 143 if dims == -1 { 144 dims = len(x) 145 if patLen != dims { 146 err = errors.Errorf(dimMismatch, len(x), len(pattern)) 147 return 148 } 149 } else { 150 if len(x) != dims { 151 err = errors.Errorf(dimMismatch, len(x), len(pattern)) 152 return 153 } 154 } 155 } 156 157 // check that all the axes are < nDims 158 // and that there are no axis repeated 159 seen := make(map[int]struct{}) 160 for _, a := range pattern { 161 if a >= dims { 162 err = errors.Errorf(invalidAxis, a, dims) 163 return 164 } 165 166 if _, ok := seen[a]; ok { 167 err = 
errors.Errorf(repeatedAxis, a) 168 return 169 } 170 171 seen[a] = struct{}{} 172 } 173 174 // no op really... we did the checks for no reason too. Maybe move this up? 175 if monotonic, incr1 := IsMonotonicInts(pattern); monotonic && incr1 { 176 err = noopError{} 177 return 178 } 179 180 switch dims { 181 case 0, 1: 182 case 2: 183 for _, x := range xs { 184 x[0], x[1] = x[1], x[0] 185 } 186 default: 187 for i := 0; i < dims; i++ { 188 to := pattern[i] 189 for to < i { 190 to = pattern[to] 191 } 192 for _, x := range xs { 193 x[i], x[to] = x[to], x[i] 194 } 195 } 196 } 197 return nil 198 } 199 200 // CheckSlice checks a slice to see if it's sane 201 func CheckSlice(s Slice, size int) error { 202 start := s.Start() 203 end := s.End() 204 step := s.Step() 205 206 if start > end { 207 return errors.Errorf(invalidSliceIndex, start, end) 208 } 209 210 if start < 0 { 211 return errors.Errorf(invalidSliceIndex, start, 0) 212 } 213 214 if step == 0 && end-start > 1 { 215 return errors.Errorf("Slice has 0 steps. Start is %d and end is %d", start, end) 216 } 217 218 if start >= size { 219 return errors.Errorf("Start %d is greater than size %d", start, size) 220 } 221 222 return nil 223 } 224 225 // SliceDetails is a function that takes a slice and spits out its details. The whole reason for this is to handle the nil Slice, which is this: a[:] 226 func SliceDetails(s Slice, size int) (start, end, step int, err error) { 227 if s == nil { 228 start = 0 229 end = size 230 step = 1 231 } else { 232 if err = CheckSlice(s, size); err != nil { 233 return 234 } 235 236 start = s.Start() 237 end = s.End() 238 step = s.Step() 239 240 if end > size { 241 end = size 242 } 243 } 244 return 245 } 246 247 // reuseDenseCheck checks a reuse tensor, and reshapes it to be the correct one 248 func reuseDenseCheck(reuse DenseTensor, as DenseTensor) (err error) { 249 if reuse.DataSize() != as.Size() { 250 err = errors.Errorf("Reused Tensor %p does not have expected shape %v. Got %v instead. 
Reuse Size: %v, as Size %v (real: %d)", reuse, as.Shape(), reuse.Shape(), reuse.DataSize(), as.Size(), as.DataSize()) 251 return 252 } 253 return reuseCheckShape(reuse, as.Shape()) 254 255 } 256 257 // reuseCheckShape checks the shape and reshapes it to be correct if the size fits but the shape doesn't. 258 func reuseCheckShape(reuse DenseTensor, s Shape) (err error) { 259 throw := BorrowInts(len(s)) 260 copy(throw, s) 261 262 if err = reuse.reshape(throw...); err != nil { 263 err = errors.Wrapf(err, reuseReshapeErr, s, reuse.DataSize()) 264 return 265 } 266 267 // clean up any funny things that may be in the reuse 268 if oldAP := reuse.oldAP(); !oldAP.IsZero() { 269 oldAP.zero() 270 } 271 272 if axes := reuse.transposeAxes(); axes != nil { 273 ReturnInts(axes) 274 } 275 276 if viewOf := reuse.parentTensor(); viewOf != nil { 277 reuse.setParentTensor(nil) 278 } 279 return nil 280 } 281 282 // memsetBools sets boolean slice to value. 283 // Reference http://stackoverflow.com/questions/30614165/is-there-analog-of-memset-in-go 284 func memsetBools(a []bool, v bool) { 285 if len(a) == 0 { 286 return 287 } 288 a[0] = v 289 for bp := 1; bp < len(a); bp *= 2 { 290 copy(a[bp:], a[:bp]) 291 } 292 } 293 294 func allones(a []int) bool { 295 for i := range a { 296 if a[i] != 1 { 297 return false 298 } 299 } 300 return true 301 } 302 303 func getFloat64s(a Tensor) []float64 { 304 if um, ok := a.(unsafeMem); ok { 305 return um.Float64s() 306 } 307 return a.Data().([]float64) 308 } 309 310 func getFloat32s(a Tensor) []float32 { 311 if um, ok := a.(unsafeMem); ok { 312 return um.Float32s() 313 } 314 return a.Data().([]float32) 315 } 316 317 func getInts(a Tensor) []int { 318 if um, ok := a.(unsafeMem); ok { 319 return um.Ints() 320 } 321 return a.Data().([]int) 322 } 323 324 /* FOR ILLUSTRATIVE PURPOSES */ 325 326 // Permute permutates a pattern according to xs. This function exists for illustrative purposes (i.e. 
// the dumb, unoptimized version).
//
// In reality, the UnsafePermute function is used.
/*
func Permute(pattern []int, xs ...[]int) (retVal [][]int, err error) {
	if len(xs) == 0 {
		err = errors.New("Permute requires something to permute")
		return
	}

	dims := -1
	patLen := len(pattern)
	for _, x := range xs {
		if dims == -1 {
			dims = len(x)
			if patLen != dims {
				err = errors.Errorf(dimMismatch, len(x), len(pattern))
				return
			}
		} else {
			if len(x) != dims {
				err = errors.Errorf(dimMismatch, len(x), len(pattern))
				return
			}
		}
	}

	// check that all the axes are < nDims
	// and that there are no axis repeated
	seen := make(map[int]struct{})
	for _, a := range pattern {
		if a >= dims {
			err = errors.Errorf(invalidAxis, a, dims)
			return
		}

		if _, ok := seen[a]; ok {
			err = errors.Errorf(repeatedAxis, a)
			return
		}

		seen[a] = struct{}{}
	}

	// no op really... we did the checks for no reason too. Maybe move this up?
	if monotonic, incr1 := IsMonotonicInts(pattern); monotonic && incr1 {
		retVal = xs
		err = noopError{}
		return
	}

	switch dims {
	case 0, 1:
		retVal = xs
	case 2:
		for _, x := range xs {
			rv := []int{x[1], x[0]}
			retVal = append(retVal, rv)
		}
	default:
		retVal = make([][]int, len(xs))
		for i := range retVal {
			retVal[i] = make([]int, dims)
		}

		for i, v := range pattern {
			for j, x := range xs {
				retVal[j][i] = x[v]
			}
		}
	}
	return
}
*/