gitee.com/quant1x/num@v0.3.2/argmax_test.go (about)

     1  package num
     2  
     3  import (
     4  	"gitee.com/quant1x/num/labs"
     5  	"testing"
     6  )
     7  
     8  func TestArgMax(t *testing.T) {
     9  	type testCase struct {
    10  		Name     string
    11  		Args     any
    12  		Want     any
    13  		TestFunc func(v any) any
    14  	}
    15  	tests := []testCase{
    16  		//{
    17  		//	Name: "bool",
    18  		//	Args: []bool{false, true},
    19  		//	Want: []bool{false, true},
    20  		//	TestFunc: func(v any) any {
    21  		//		return ArgMax(v.([]bool))
    22  		//	},
    23  		//},
    24  		//{
    25  		//	Name: "string",
    26  		//	Args: []string{"1"},
    27  		//	Want: []string{"1"},
    28  		//	TestFunc: func(v any) any {
    29  		//		return Abs(v.([]string))
    30  		//	},
    31  		//},
    32  		{
    33  			Name: "float32",
    34  			Args: []float32{-0.1, 1.0, -2.00, -3},
    35  			Want: 1,
    36  			TestFunc: func(v any) any {
    37  				return ArgMax(v.([]float32))
    38  			},
    39  		},
    40  		{
    41  			Name: "float64",
    42  			Args: []float64{1.2, 1.2, 3.3},
    43  			Want: 2,
    44  			TestFunc: func(v any) any {
    45  				return ArgMax(v.([]float64))
    46  			},
    47  		},
    48  		{
    49  			Name: "int32",
    50  			Args: []int32{11, 12, 33},
    51  			Want: 2,
    52  			TestFunc: func(v any) any {
    53  				return ArgMax(v.([]int32))
    54  			},
    55  		},
    56  		{
    57  			Name: "int64",
    58  			Args: []int64{11, 12, 33},
    59  			Want: 2,
    60  			TestFunc: func(v any) any {
    61  				return ArgMax(v.([]int64))
    62  			},
    63  		},
    64  	}
    65  
    66  	for _, tt := range tests {
    67  		t.Run(tt.Name, func(t *testing.T) {
    68  			if got := tt.TestFunc(tt.Args); !labs.DeepEqual(got, tt.Want) {
    69  				t.Errorf("ArgMax() = %v, want %v", got, tt.Want)
    70  			}
    71  		})
    72  	}
    73  }