gitee.com/quant1x/num@v0.3.2/argmax.go (about)

     1  package num
     2  
     3  import (
     4  	"gitee.com/quant1x/num/x32"
     5  	"gitee.com/quant1x/num/x64"
     6  )
     7  
     8  // ArgMax Returns the indices of the maximum values along an axis.
     9  //
    10  //	返回轴上最大值的索引
    11  func ArgMax[T Number](x []T) int {
    12  	ret := UnaryOperations2[T, int](x, x32.ArgMax, x64.ArgMax, __go_arg_max[T])
    13  	return ret
    14  }
    15  
    16  func ArgMax2[T BaseType](x []T) int {
    17  	var d int
    18  	switch vs := any(x).(type) {
    19  	case []float32:
    20  		d = ArgMax(vs)
    21  	case []float64:
    22  		d = ArgMax(vs)
    23  	case []int:
    24  		d = ArgMax(vs)
    25  	case []int8:
    26  		d = ArgMax(vs)
    27  	case []int16:
    28  		d = ArgMax(vs)
    29  	case []int32:
    30  		d = ArgMax(vs)
    31  	case []int64:
    32  		d = ArgMax(vs)
    33  	case []uint:
    34  		d = ArgMax(vs)
    35  	case []uint8:
    36  		d = ArgMax(vs)
    37  	case []uint16:
    38  		d = ArgMax(vs)
    39  	case []uint32:
    40  		d = ArgMax(vs)
    41  	case []uint64:
    42  		d = ArgMax(vs)
    43  	case []uintptr:
    44  		d = ArgMax(vs)
    45  	case []string:
    46  		d = __go_arg_max(vs)
    47  	case []bool:
    48  		d = __go_bool_arg_max(vs)
    49  	default:
    50  		// 其它类型原样返回
    51  		panic(TypeError(any(x)))
    52  	}
    53  
    54  	return d
    55  }
    56  
    57  func __go_arg_max[T Ordered](x []T) int {
    58  	maxValue := x[0]
    59  	idx := 0
    60  	for i, v := range x[1:] {
    61  		if v > maxValue {
    62  			maxValue = v
    63  			idx = 1 + i
    64  		}
    65  	}
    66  	return idx
    67  }
    68  
    69  func __go_bool_arg_max(x []bool) int {
    70  	maxValue := BoolToInt(x[0])
    71  	idx := 0
    72  	for i, v := range x[1:] {
    73  		if BoolToInt(v) > maxValue {
    74  			maxValue = BoolToInt(v)
    75  			idx = 1 + i
    76  		}
    77  	}
    78  	return idx
    79  }