gorgonia.org/tensor@v0.9.24/api_utils.go (about) 1 package tensor 2 3 import ( 4 "log" 5 "math" 6 "math/rand" 7 "reflect" 8 "sort" 9 10 "github.com/chewxy/math32" 11 ) 12 13 // SortIndex is similar to numpy's argsort 14 // TODO: tidy this up 15 func SortIndex(in interface{}) (out []int) { 16 switch list := in.(type) { 17 case []int: 18 orig := make([]int, len(list)) 19 out = make([]int, len(list)) 20 copy(orig, list) 21 sort.Ints(list) 22 for i, s := range list { 23 for j, o := range orig { 24 if o == s { 25 out[i] = j 26 break 27 } 28 } 29 } 30 case []float64: 31 orig := make([]float64, len(list)) 32 out = make([]int, len(list)) 33 copy(orig, list) 34 sort.Float64s(list) 35 36 for i, s := range list { 37 for j, o := range orig { 38 if o == s { 39 out[i] = j 40 break 41 } 42 } 43 } 44 case sort.Interface: 45 sort.Sort(list) 46 47 log.Printf("TODO: SortIndex for sort.Interface not yet done.") 48 } 49 50 return 51 } 52 53 // SampleIndex samples a slice or a Tensor. 54 // TODO: tidy this up. 55 func SampleIndex(in interface{}) int { 56 // var l int 57 switch list := in.(type) { 58 case []int: 59 var sum, i int 60 // l = len(list) 61 r := rand.Int() 62 for { 63 sum += list[i] 64 if sum > r && i > 0 { 65 return i 66 } 67 i++ 68 } 69 case []float64: 70 var sum float64 71 var i int 72 // l = len(list) 73 r := rand.Float64() 74 for { 75 sum += list[i] 76 if sum > r && i > 0 { 77 return i 78 } 79 i++ 80 } 81 case *Dense: 82 var i int 83 switch list.t.Kind() { 84 case reflect.Float64: 85 var sum float64 86 r := rand.Float64() 87 data := list.Float64s() 88 // l = len(data) 89 for { 90 datum := data[i] 91 if math.IsNaN(datum) || math.IsInf(datum, 0) { 92 return i 93 } 94 95 sum += datum 96 if sum > r && i > 0 { 97 return i 98 } 99 i++ 100 } 101 case reflect.Float32: 102 var sum float32 103 r := rand.Float32() 104 data := list.Float32s() 105 // l = len(data) 106 for { 107 datum := data[i] 108 if math32.IsNaN(datum) || math32.IsInf(datum, 0) { 109 return i 110 } 111 112 sum += datum 113 if sum > r && i > 0 { 114 return i 115 } 116 i++ 117 } 118 default: 119 panic("not yet implemented") 120 } 121 default: 122 panic("Not yet implemented") 123 } 124 return -1 125 }