gorgonia.org/tensor@v0.9.24/api_utils.go (about)

     1  package tensor
     2  
     3  import (
     4  	"log"
     5  	"math"
     6  	"math/rand"
     7  	"reflect"
     8  	"sort"
     9  
    10  	"github.com/chewxy/math32"
    11  )
    12  
    13  // SortIndex is similar to numpy's argsort
    14  // TODO: tidy this up
    15  func SortIndex(in interface{}) (out []int) {
    16  	switch list := in.(type) {
    17  	case []int:
    18  		orig := make([]int, len(list))
    19  		out = make([]int, len(list))
    20  		copy(orig, list)
    21  		sort.Ints(list)
    22  		for i, s := range list {
    23  			for j, o := range orig {
    24  				if o == s {
    25  					out[i] = j
    26  					break
    27  				}
    28  			}
    29  		}
    30  	case []float64:
    31  		orig := make([]float64, len(list))
    32  		out = make([]int, len(list))
    33  		copy(orig, list)
    34  		sort.Float64s(list)
    35  
    36  		for i, s := range list {
    37  			for j, o := range orig {
    38  				if o == s {
    39  					out[i] = j
    40  					break
    41  				}
    42  			}
    43  		}
    44  	case sort.Interface:
    45  		sort.Sort(list)
    46  
    47  		log.Printf("TODO: SortIndex for sort.Interface not yet done.")
    48  	}
    49  
    50  	return
    51  }
    52  
    53  // SampleIndex samples a slice or a Tensor.
    54  // TODO: tidy this up.
    55  func SampleIndex(in interface{}) int {
    56  	// var l int
    57  	switch list := in.(type) {
    58  	case []int:
    59  		var sum, i int
    60  		// l = len(list)
    61  		r := rand.Int()
    62  		for {
    63  			sum += list[i]
    64  			if sum > r && i > 0 {
    65  				return i
    66  			}
    67  			i++
    68  		}
    69  	case []float64:
    70  		var sum float64
    71  		var i int
    72  		// l = len(list)
    73  		r := rand.Float64()
    74  		for {
    75  			sum += list[i]
    76  			if sum > r && i > 0 {
    77  				return i
    78  			}
    79  			i++
    80  		}
    81  	case *Dense:
    82  		var i int
    83  		switch list.t.Kind() {
    84  		case reflect.Float64:
    85  			var sum float64
    86  			r := rand.Float64()
    87  			data := list.Float64s()
    88  			// l = len(data)
    89  			for {
    90  				datum := data[i]
    91  				if math.IsNaN(datum) || math.IsInf(datum, 0) {
    92  					return i
    93  				}
    94  
    95  				sum += datum
    96  				if sum > r && i > 0 {
    97  					return i
    98  				}
    99  				i++
   100  			}
   101  		case reflect.Float32:
   102  			var sum float32
   103  			r := rand.Float32()
   104  			data := list.Float32s()
   105  			// l = len(data)
   106  			for {
   107  				datum := data[i]
   108  				if math32.IsNaN(datum) || math32.IsInf(datum, 0) {
   109  					return i
   110  				}
   111  
   112  				sum += datum
   113  				if sum > r && i > 0 {
   114  					return i
   115  				}
   116  				i++
   117  			}
   118  		default:
   119  			panic("not yet implemented")
   120  		}
   121  	default:
   122  		panic("Not yet implemented")
   123  	}
   124  	return -1
   125  }