github.com/qiaogw/arrgo@v0.0.8/stats_test.go (about)

     1  package arrgo
     2  
     3  import "testing"
     4  
     5  func TestSum(t *testing.T) {
     6  	var arr = Arange(100).ReShape(2, 5, 10)
     7  	if arr.Sum(0).NotEqual(Array(
     8  		[]float64{
     9  			50, 52, 54, 56, 58, 60, 62, 64, 66, 68,
    10  			70, 72, 74, 76, 78, 80, 82, 84, 86, 88,
    11  			90, 92, 94, 96, 98, 100, 102, 104, 106, 108,
    12  			110, 112, 114, 116, 118, 120, 122, 124, 126, 128,
    13  			130, 132, 134, 136, 138, 140, 142, 144, 146, 148},
    14  		5, 10)).AnyTrue() {
    15  		t.Error(`Expected [[ 50,  52,  54,  56,  58,  60,  62,  64,  66,  68],
    16  							 [ 70,  72,  74,  76,  78,  80,  82,  84,  86,  88],
    17  							 [ 90,  92,  94,  96,  98, 100, 102, 104, 106, 108],
    18  							 [110, 112, 114, 116, 118, 120, 122, 124, 126, 128],
    19  							 [130, 132, 134, 136, 138, 140, 142, 144, 146, 148]], got `,
    20  			arr.Sum(0))
    21  	}
    22  
    23  	if arr.Sum(1).NotEqual(Array(
    24  		[]float64{
    25  			100, 105, 110, 115, 120, 125, 130, 135, 140, 145,
    26  			350, 355, 360, 365, 370, 375, 380, 385, 390, 395},
    27  		2, 10)).AnyTrue() {
    28  		t.Error(`Expected 
    29  				[[100, 105, 110, 115, 120, 125, 130, 135, 140, 145],
    30   				[350, 355, 360, 365, 370, 375, 380, 385, 390, 395]], got `,
    31  			arr.Sum(1))
    32  	}
    33  
    34  	if arr.Sum(2).NotEqual(Array(
    35  		[]float64{
    36  			45, 145, 245, 345, 445,
    37  			545, 645, 745, 845, 945},
    38  		2, 5)).AnyTrue() {
    39  		t.Error(`Expected 
    40  			[[ 45, 145, 245, 345, 445],
    41  	 		[545, 645, 745, 845, 945]]
    42  		, got 
    43  		`, arr.Sum(2))
    44  	}
    45  
    46  	if arr.Sum(0, 1).NotEqual(Array(
    47  		[]float64{
    48  			450, 460, 470, 480, 490, 500, 510, 520, 530, 540},
    49  		10)).AnyTrue() {
    50  		t.Error(`Expected 
    51  			[450, 460, 470, 480, 490, 500, 510, 520, 530, 540]
    52  			, got`,
    53  			arr.Sum(0, 1))
    54  	}
    55  
    56  	if arr.Sum(0, 2).NotEqual(Array([]float64{590, 790, 990, 1190, 1390}, 5)).AnyTrue() {
    57  		t.Error(`Expected [ 590,  790,  990, 1190, 1390], got `, arr.Sum(0, 2))
    58  	}
    59  
    60  	if arr.Sum(1, 2).NotEqual(Array([]float64{1225, 3725})).AnyTrue() {
    61  		t.Error("Expected [1225, 3725], got ", arr.Sum(1, 2))
    62  	}
    63  
    64  	if arr.Sum(0, 1, 2).NotEqual(Array([]float64{4950})).AnyTrue() {
    65  		t.Error("expected [4950], got ", arr.Sum(0, 1, 2))
    66  	}
    67  
    68  	if arr.Sum().NotEqual(Array([]float64{4950})).AnyTrue() {
    69  		t.Error("expected [4950], got ", arr.Sum())
    70  	}
    71  }
    72  
    73  func TestArgMax(t *testing.T) {
    74  	arr := Array([]float64{17, 10, 22, 3, 2, 7, 15, 9, 23, 4, 14, 18, 5, 8, 0, 12, 1,
    75  		19, 20, 11, 6, 16, 21, 13}, 2, 3, 4)
    76  
    77  	if arr.ArgMax(0).NotEqual(Array([]float64{0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0}, 3, 4)).AnyTrue() {
    78  		t.Error(`Expected
    79  		[[0, 0, 0, 1],
    80  		[0, 1, 1, 1],
    81  		[0, 1, 1, 0]], got `, arr.ArgMax(0))
    82  	}
    83  
    84  	if arr.ArgMax(-3).NotEqual(Array([]float64{0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0}, 3, 4)).AnyTrue() {
    85  		t.Error(`Expected
    86  		[[0, 0, 0, 1],
    87  		[0, 1, 1, 1],
    88  		[0, 1, 1, 0]], got `, arr.ArgMax(0))
    89  	}
    90  
    91  	if arr.ArgMax(1).NotEqual(Array([]float64{2, 0, 0, 2, 2, 1, 2, 2}, 2, 4)).AnyTrue() {
    92  		t.Error(`Expected
    93  		[[2, 0, 0, 2],
    94         		[2, 1, 2, 2]], got `, arr.ArgMax(1))
    95  	}
    96  
    97  	if arr.ArgMax(-2).NotEqual(Array([]float64{2, 0, 0, 2, 2, 1, 2, 2}, 2, 4)).AnyTrue() {
    98  		t.Error(`Expected
    99  		[[2, 0, 0, 2],
   100         		[2, 1, 2, 2]], got `, arr.ArgMax(1))
   101  	}
   102  
   103  	if arr.ArgMax(2).NotEqual(Array([]float64{2, 2, 0, 3, 2, 2}, 2, 3)).AnyTrue() {
   104  		t.Error(`Expected
   105  		[[2, 2, 0],
   106         		[3, 2, 2]], got `, arr.ArgMax(2))
   107  	}
   108  
   109  	if arr.ArgMax(-1).NotEqual(Array([]float64{2, 2, 0, 3, 2, 2}, 2, 3)).AnyTrue() {
   110  		t.Error(`Expected
   111  		[[2, 2, 0],
   112         		[3, 2, 2]], got `, arr.ArgMax(2))
   113  	}
   114  
   115  }
   116  
   117  func TestArgMin(t *testing.T) {
   118  	arr := Array([]float64{17, 10, 22, 3, 2, 7, 15, 9, 23, 4, 14, 18, 5, 8, 0, 12, 1,
   119  		19, 20, 11, 6, 16, 21, 13}, 2, 3, 4)
   120  
   121  	if arr.ArgMin(0).NotEqual(Array([]float64{1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1}, 3, 4)).AnyTrue() {
   122  		t.Error(`Expected
   123  		[[1, 1, 1, 0],
   124  		[1, 0, 0, 0],
   125  		[1, 0, 0, 1]], got `, arr.ArgMin(0))
   126  	}
   127  
   128  	if arr.ArgMin(1).NotEqual(Array([]float64{1, 2, 2, 0, 1, 0, 0, 1}, 2, 4)).AnyTrue() {
   129  		t.Error(`Expected
   130  		[[1, 2, 2, 0],
   131         		[1, 0, 0, 1]], got `, arr.ArgMin(1))
   132  	}
   133  
   134  	if arr.ArgMin(2).NotEqual(Array([]float64{3, 0, 1, 2, 0, 0}, 2, 3)).AnyTrue() {
   135  		t.Error(`Expected
   136  		[[3, 0, 1],
   137         		[2, 0, 0]], got `, arr.ArgMin(2))
   138  	}
   139  
   140  	if arr.ArgMin(-3).NotEqual(Array([]float64{1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1}, 3, 4)).AnyTrue() {
   141  		t.Error(`Expected
   142  		[[1, 1, 1, 0],
   143  		[1, 0, 0, 0],
   144  		[1, 0, 0, 1]], got `, arr.ArgMin(0))
   145  	}
   146  
   147  	if arr.ArgMin(-2).NotEqual(Array([]float64{1, 2, 2, 0, 1, 0, 0, 1}, 2, 4)).AnyTrue() {
   148  		t.Error(`Expected
   149  		[[1, 2, 2, 0],
   150         		[1, 0, 0, 1]], got `, arr.ArgMin(1))
   151  	}
   152  
   153  	if arr.ArgMin(-1).NotEqual(Array([]float64{3, 0, 1, 2, 0, 0}, 2, 3)).AnyTrue() {
   154  		t.Error(`Expected
   155  		[[3, 0, 1],
   156         		[2, 0, 0]], got `, arr.ArgMin(2))
   157  	}
   158  }