github.com/wzzhu/tensor@v0.9.24/utils.go (about)

     1  package tensor
     2  
     3  import (
     4  	"github.com/pkg/errors"
     5  )
     6  
     7  const AllAxes int = -1
     8  
     9  // MinInt returns the lowest between two ints. If both are the  same it returns the first
    10  func MinInt(a, b int) int {
    11  	if a <= b {
    12  		return a
    13  	}
    14  	return b
    15  }
    16  
    17  // MaxInt returns the highest between two ints. If both are the same, it  returns the first
    18  func MaxInt(a, b int) int {
    19  	if a >= b {
    20  		return a
    21  	}
    22  	return b
    23  }
    24  
    25  // MaxInts returns the max of a slice of ints.
    26  func MaxInts(is ...int) (retVal int) {
    27  	for _, i := range is {
    28  		if i > retVal {
    29  			retVal = i
    30  		}
    31  	}
    32  	return
    33  }
    34  
    35  // SumInts sums a slice of ints
    36  func SumInts(a []int) (retVal int) {
    37  	for _, v := range a {
    38  		retVal += v
    39  	}
    40  	return
    41  }
    42  
    43  // ProdInts returns the internal product of an int slice
    44  func ProdInts(a []int) (retVal int) {
    45  	retVal = 1
    46  	if len(a) == 0 {
    47  		return
    48  	}
    49  	for _, v := range a {
    50  		retVal *= v
    51  	}
    52  	return
    53  }
    54  
    55  // IsMonotonicInts returns true if the slice of ints is monotonically increasing. It also returns true for incr1 if every succession is a succession of 1
    56  func IsMonotonicInts(a []int) (monotonic bool, incr1 bool) {
    57  	var prev int
    58  	incr1 = true
    59  	for i, v := range a {
    60  		if i == 0 {
    61  			prev = v
    62  			continue
    63  		}
    64  
    65  		if v < prev {
    66  			return false, false
    67  		}
    68  		if v != prev+1 {
    69  			incr1 = false
    70  		}
    71  		prev = v
    72  	}
    73  	monotonic = true
    74  	return
    75  }
    76  
    77  // Ltoi is Location to Index. Provide a shape, a strides, and a list of integers as coordinates, and returns the index at which the element is.
    78  func Ltoi(shape Shape, strides []int, coords ...int) (at int, err error) {
    79  	if shape.IsScalarEquiv() {
    80  		for _, v := range coords {
    81  			if v != 0 {
    82  				return -1, errors.Errorf("Scalar shape only allows 0 as an index")
    83  			}
    84  		}
    85  		return 0, nil
    86  	}
    87  	for i, coord := range coords {
    88  		if i >= len(shape) {
    89  			err = errors.Errorf(dimMismatch, len(shape), i)
    90  			return
    91  		}
    92  
    93  		size := shape[i]
    94  
    95  		if coord >= size {
    96  			err = errors.Errorf(indexOOBAxis, i, coord, size)
    97  			return
    98  		}
    99  
   100  		var stride int
   101  		switch {
   102  		case shape.IsVector() && len(strides) == 1:
   103  			stride = strides[0]
   104  		case i >= len(strides):
   105  			err = errors.Errorf(dimMismatch, len(strides), i)
   106  			return
   107  		default:
   108  			stride = strides[i]
   109  		}
   110  
   111  		at += stride * coord
   112  	}
   113  	return at, nil
   114  }
   115  
   116  // Itol is Index to Location.
   117  func Itol(i int, shape Shape, strides []int) (coords []int, err error) {
   118  	dims := len(strides)
   119  
   120  	for d := 0; d < dims; d++ {
   121  		var coord int
   122  		coord, i = divmod(i, strides[d])
   123  
   124  		if coord >= shape[d] {
   125  			err = errors.Errorf(indexOOBAxis, d, coord, shape[d])
   126  			// return
   127  		}
   128  
   129  		coords = append(coords, coord)
   130  	}
   131  	return
   132  }
   133  
   134  func UnsafePermute(pattern []int, xs ...[]int) (err error) {
   135  	if len(xs) == 0 {
   136  		err = errors.New("Permute requres something to permute")
   137  		return
   138  	}
   139  
   140  	dims := -1
   141  	patLen := len(pattern)
   142  	for _, x := range xs {
   143  		if dims == -1 {
   144  			dims = len(x)
   145  			if patLen != dims {
   146  				err = errors.Errorf(dimMismatch, len(x), len(pattern))
   147  				return
   148  			}
   149  		} else {
   150  			if len(x) != dims {
   151  				err = errors.Errorf(dimMismatch, len(x), len(pattern))
   152  				return
   153  			}
   154  		}
   155  	}
   156  
   157  	// check that all the axes are < nDims
   158  	// and that there are no axis repeated
   159  	seen := make(map[int]struct{})
   160  	for _, a := range pattern {
   161  		if a >= dims {
   162  			err = errors.Errorf(invalidAxis, a, dims)
   163  			return
   164  		}
   165  
   166  		if _, ok := seen[a]; ok {
   167  			err = errors.Errorf(repeatedAxis, a)
   168  			return
   169  		}
   170  
   171  		seen[a] = struct{}{}
   172  	}
   173  
   174  	// no op really... we did the checks for no reason too. Maybe move this up?
   175  	if monotonic, incr1 := IsMonotonicInts(pattern); monotonic && incr1 {
   176  		err = noopError{}
   177  		return
   178  	}
   179  
   180  	switch dims {
   181  	case 0, 1:
   182  	case 2:
   183  		for _, x := range xs {
   184  			x[0], x[1] = x[1], x[0]
   185  		}
   186  	default:
   187  		for i := 0; i < dims; i++ {
   188  			to := pattern[i]
   189  			for to < i {
   190  				to = pattern[to]
   191  			}
   192  			for _, x := range xs {
   193  				x[i], x[to] = x[to], x[i]
   194  			}
   195  		}
   196  	}
   197  	return nil
   198  }
   199  
   200  // CheckSlice checks a slice to see if it's sane
   201  func CheckSlice(s Slice, size int) error {
   202  	start := s.Start()
   203  	end := s.End()
   204  	step := s.Step()
   205  
   206  	if start > end {
   207  		return errors.Errorf(invalidSliceIndex, start, end)
   208  	}
   209  
   210  	if start < 0 {
   211  		return errors.Errorf(invalidSliceIndex, start, 0)
   212  	}
   213  
   214  	if step == 0 && end-start > 1 {
   215  		return errors.Errorf("Slice has 0 steps. Start is %d and end is %d", start, end)
   216  	}
   217  
   218  	if start >= size {
   219  		return errors.Errorf("Start %d is greater than size %d", start, size)
   220  	}
   221  
   222  	return nil
   223  }
   224  
   225  // SliceDetails is a function that takes a slice and spits out its details. The whole reason for this is to handle the nil Slice, which is this: a[:]
   226  func SliceDetails(s Slice, size int) (start, end, step int, err error) {
   227  	if s == nil {
   228  		start = 0
   229  		end = size
   230  		step = 1
   231  	} else {
   232  		if err = CheckSlice(s, size); err != nil {
   233  			return
   234  		}
   235  
   236  		start = s.Start()
   237  		end = s.End()
   238  		step = s.Step()
   239  
   240  		if end > size {
   241  			end = size
   242  		}
   243  	}
   244  	return
   245  }
   246  
   247  // reuseDenseCheck checks a reuse tensor, and reshapes it to be the correct one
   248  func reuseDenseCheck(reuse DenseTensor, as DenseTensor) (err error) {
   249  	if reuse.DataSize() != as.Size() {
   250  		err = errors.Errorf("Reused Tensor %p does not have expected shape %v. Got %v instead. Reuse Size: %v, as Size %v (real: %d)", reuse, as.Shape(), reuse.Shape(), reuse.DataSize(), as.Size(), as.DataSize())
   251  		return
   252  	}
   253  	return reuseCheckShape(reuse, as.Shape())
   254  
   255  }
   256  
   257  // reuseCheckShape  checks the shape and reshapes it to be correct if the size fits but the shape doesn't.
   258  func reuseCheckShape(reuse DenseTensor, s Shape) (err error) {
   259  	throw := BorrowInts(len(s))
   260  	copy(throw, s)
   261  
   262  	if err = reuse.reshape(throw...); err != nil {
   263  		err = errors.Wrapf(err, reuseReshapeErr, s, reuse.DataSize())
   264  		return
   265  	}
   266  
   267  	// clean up any funny things that may be in the reuse
   268  	if oldAP := reuse.oldAP(); !oldAP.IsZero() {
   269  		oldAP.zero()
   270  	}
   271  
   272  	if axes := reuse.transposeAxes(); axes != nil {
   273  		ReturnInts(axes)
   274  	}
   275  
   276  	if viewOf := reuse.parentTensor(); viewOf != nil {
   277  		reuse.setParentTensor(nil)
   278  	}
   279  	return nil
   280  }
   281  
   282  // memsetBools sets boolean slice to value.
   283  // Reference http://stackoverflow.com/questions/30614165/is-there-analog-of-memset-in-go
   284  func memsetBools(a []bool, v bool) {
   285  	if len(a) == 0 {
   286  		return
   287  	}
   288  	a[0] = v
   289  	for bp := 1; bp < len(a); bp *= 2 {
   290  		copy(a[bp:], a[:bp])
   291  	}
   292  }
   293  
   294  func allones(a []int) bool {
   295  	for i := range a {
   296  		if a[i] != 1 {
   297  			return false
   298  		}
   299  	}
   300  	return true
   301  }
   302  
   303  func getFloat64s(a Tensor) []float64 {
   304  	if um, ok := a.(unsafeMem); ok {
   305  		return um.Float64s()
   306  	}
   307  	return a.Data().([]float64)
   308  }
   309  
   310  func getFloat32s(a Tensor) []float32 {
   311  	if um, ok := a.(unsafeMem); ok {
   312  		return um.Float32s()
   313  	}
   314  	return a.Data().([]float32)
   315  }
   316  
   317  func getInts(a Tensor) []int {
   318  	if um, ok := a.(unsafeMem); ok {
   319  		return um.Ints()
   320  	}
   321  	return a.Data().([]int)
   322  }
   323  
   324  /* FOR ILLUSTRATIVE PURPOSES */
   325  
   326  // Permute permutes each slice in xs according to pattern. This function exists for illustrative purposes (i.e. the dumb, unoptimized version)
   327  //
   328  // In reality, the UnsafePermute function is used.
   329  /*
   330  func Permute(pattern []int, xs ...[]int) (retVal [][]int, err error) {
   331  	if len(xs) == 0 {
   332  		err = errors.New("Permute requires something to permute")
   333  		return
   334  	}
   335  
   336  	dims := -1
   337  	patLen := len(pattern)
   338  	for _, x := range xs {
   339  		if dims == -1 {
   340  			dims = len(x)
   341  			if patLen != dims {
   342  				err = errors.Errorf(dimMismatch, len(x), len(pattern))
   343  				return
   344  			}
   345  		} else {
   346  			if len(x) != dims {
   347  				err = errors.Errorf(dimMismatch, len(x), len(pattern))
   348  				return
   349  			}
   350  		}
   351  	}
   352  
   353  	// check that all the axes are < nDims
   354  	// and that there are no axis repeated
   355  	seen := make(map[int]struct{})
   356  	for _, a := range pattern {
   357  		if a >= dims {
   358  			err = errors.Errorf(invalidAxis, a, dims)
   359  			return
   360  		}
   361  
   362  		if _, ok := seen[a]; ok {
   363  			err = errors.Errorf(repeatedAxis, a)
   364  			return
   365  		}
   366  
   367  		seen[a] = struct{}{}
   368  	}
   369  
   370  	// no op really... we did the checks for no reason too. Maybe move this up?
   371  	if monotonic, incr1 := IsMonotonicInts(pattern); monotonic && incr1 {
   372  		retVal = xs
   373  		err = noopError{}
   374  		return
   375  	}
   376  
   377  	switch dims {
   378  	case 0, 1:
   379  		retVal = xs
   380  	case 2:
   381  		for _, x := range xs {
   382  			rv := []int{x[1], x[0]}
   383  			retVal = append(retVal, rv)
   384  		}
   385  	default:
   386  		retVal = make([][]int, len(xs))
   387  		for i := range retVal {
   388  			retVal[i] = make([]int, dims)
   389  		}
   390  
   391  		for i, v := range pattern {
   392  			for j, x := range xs {
   393  				retVal[j][i] = x[v]
   394  			}
   395  		}
   396  	}
   397  	return
   398  }
   399  */