gitee.com/quant1x/num@v0.3.2/argmax.go (about) 1 package num 2 3 import ( 4 "gitee.com/quant1x/num/x32" 5 "gitee.com/quant1x/num/x64" 6 ) 7 8 // ArgMax Returns the indices of the maximum values along an axis. 9 // 10 // 返回轴上最大值的索引 11 func ArgMax[T Number](x []T) int { 12 ret := UnaryOperations2[T, int](x, x32.ArgMax, x64.ArgMax, __go_arg_max[T]) 13 return ret 14 } 15 16 func ArgMax2[T BaseType](x []T) int { 17 var d int 18 switch vs := any(x).(type) { 19 case []float32: 20 d = ArgMax(vs) 21 case []float64: 22 d = ArgMax(vs) 23 case []int: 24 d = ArgMax(vs) 25 case []int8: 26 d = ArgMax(vs) 27 case []int16: 28 d = ArgMax(vs) 29 case []int32: 30 d = ArgMax(vs) 31 case []int64: 32 d = ArgMax(vs) 33 case []uint: 34 d = ArgMax(vs) 35 case []uint8: 36 d = ArgMax(vs) 37 case []uint16: 38 d = ArgMax(vs) 39 case []uint32: 40 d = ArgMax(vs) 41 case []uint64: 42 d = ArgMax(vs) 43 case []uintptr: 44 d = ArgMax(vs) 45 case []string: 46 d = __go_arg_max(vs) 47 case []bool: 48 d = __go_bool_arg_max(vs) 49 default: 50 // 其它类型原样返回 51 panic(TypeError(any(x))) 52 } 53 54 return d 55 } 56 57 func __go_arg_max[T Ordered](x []T) int { 58 maxValue := x[0] 59 idx := 0 60 for i, v := range x[1:] { 61 if v > maxValue { 62 maxValue = v 63 idx = 1 + i 64 } 65 } 66 return idx 67 } 68 69 func __go_bool_arg_max(x []bool) int { 70 maxValue := BoolToInt(x[0]) 71 idx := 0 72 for i, v := range x[1:] { 73 if BoolToInt(v) > maxValue { 74 maxValue = BoolToInt(v) 75 idx = 1 + i 76 } 77 } 78 return idx 79 }