golang.org/x/exp@v0.0.0-20240506185415-9bf2ced13842/slices/sort_test.go (about)

     1  // Copyright 2022 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package slices
     6  
     7  import (
     8  	"fmt"
     9  	"math"
    10  	"math/rand"
    11  	"sort"
    12  	"strconv"
    13  	"strings"
    14  	"testing"
    15  )
    16  
    17  var ints = [...]int{74, 59, 238, -784, 9845, 959, 905, 0, 0, 42, 7586, -5467984, 7586}
    18  var float64s = [...]float64{74.3, 59.0, math.Inf(1), 238.2, -784.0, 2.3, math.Inf(-1), 9845.768, -959.7485, 905, 7.8, 7.8, 74.3, 59.0, math.Inf(1), 238.2, -784.0, 2.3}
    19  var float64sWithNaNs = [...]float64{74.3, 59.0, math.Inf(1), 238.2, -784.0, 2.3, math.NaN(), math.NaN(), math.Inf(-1), 9845.768, -959.7485, 905, 7.8, 7.8}
    20  var strs = [...]string{"", "Hello", "foo", "bar", "foo", "f00", "%*&^*&^&", "***"}
    21  
    22  func TestSortIntSlice(t *testing.T) {
    23  	data := Clone(ints[:])
    24  	Sort(data)
    25  	if !IsSorted(data) {
    26  		t.Errorf("sorted %v", ints)
    27  		t.Errorf("   got %v", data)
    28  	}
    29  }
    30  
    31  func TestSortFuncIntSlice(t *testing.T) {
    32  	data := Clone(ints[:])
    33  	SortFunc(data, func(a, b int) int { return a - b })
    34  	if !IsSorted(data) {
    35  		t.Errorf("sorted %v", ints)
    36  		t.Errorf("   got %v", data)
    37  	}
    38  }
    39  
    40  func TestSortFloat64Slice(t *testing.T) {
    41  	data := Clone(float64s[:])
    42  	Sort(data)
    43  	if !IsSorted(data) {
    44  		t.Errorf("sorted %v", float64s)
    45  		t.Errorf("   got %v", data)
    46  	}
    47  }
    48  
    49  func TestSortFloat64SliceWithNaNs(t *testing.T) {
    50  	data := float64sWithNaNs[:]
    51  	data2 := Clone(data)
    52  
    53  	Sort(data)
    54  	sort.Float64s(data2)
    55  
    56  	if !IsSorted(data) {
    57  		t.Error("IsSorted indicates data isn't sorted")
    58  	}
    59  
    60  	// Compare for equality using cmp.Compare, which considers NaNs equal.
    61  	if !EqualFunc(data, data2, func(a, b float64) bool { return cmpCompare(a, b) == 0 }) {
    62  		t.Errorf("mismatch between Sort and sort.Float64: got %v, want %v", data, data2)
    63  	}
    64  }
    65  
    66  func TestSortStringSlice(t *testing.T) {
    67  	data := Clone(strs[:])
    68  	Sort(data)
    69  	if !IsSorted(data) {
    70  		t.Errorf("sorted %v", strs)
    71  		t.Errorf("   got %v", data)
    72  	}
    73  }
    74  
    75  func TestSortLarge_Random(t *testing.T) {
    76  	n := 1000000
    77  	if testing.Short() {
    78  		n /= 100
    79  	}
    80  	data := make([]int, n)
    81  	for i := 0; i < len(data); i++ {
    82  		data[i] = rand.Intn(100)
    83  	}
    84  	if IsSorted(data) {
    85  		t.Fatalf("terrible rand.rand")
    86  	}
    87  	Sort(data)
    88  	if !IsSorted(data) {
    89  		t.Errorf("sort didn't sort - 1M ints")
    90  	}
    91  }
    92  
    93  type intPair struct {
    94  	a, b int
    95  }
    96  
    97  type intPairs []intPair
    98  
    99  // Pairs compare on a only.
   100  func intPairCmp(x, y intPair) int {
   101  	return x.a - y.a
   102  }
   103  
   104  // Record initial order in B.
   105  func (d intPairs) initB() {
   106  	for i := range d {
   107  		d[i].b = i
   108  	}
   109  }
   110  
   111  // InOrder checks if a-equal elements were not reordered.
   112  func (d intPairs) inOrder() bool {
   113  	lastA, lastB := -1, 0
   114  	for i := 0; i < len(d); i++ {
   115  		if lastA != d[i].a {
   116  			lastA = d[i].a
   117  			lastB = d[i].b
   118  			continue
   119  		}
   120  		if d[i].b <= lastB {
   121  			return false
   122  		}
   123  		lastB = d[i].b
   124  	}
   125  	return true
   126  }
   127  
   128  func TestStability(t *testing.T) {
   129  	n, m := 100000, 1000
   130  	if testing.Short() {
   131  		n, m = 1000, 100
   132  	}
   133  	data := make(intPairs, n)
   134  
   135  	// random distribution
   136  	for i := 0; i < len(data); i++ {
   137  		data[i].a = rand.Intn(m)
   138  	}
   139  	if IsSortedFunc(data, intPairCmp) {
   140  		t.Fatalf("terrible rand.rand")
   141  	}
   142  	data.initB()
   143  	SortStableFunc(data, intPairCmp)
   144  	if !IsSortedFunc(data, intPairCmp) {
   145  		t.Errorf("Stable didn't sort %d ints", n)
   146  	}
   147  	if !data.inOrder() {
   148  		t.Errorf("Stable wasn't stable on %d ints", n)
   149  	}
   150  
   151  	// already sorted
   152  	data.initB()
   153  	SortStableFunc(data, intPairCmp)
   154  	if !IsSortedFunc(data, intPairCmp) {
   155  		t.Errorf("Stable shuffled sorted %d ints (order)", n)
   156  	}
   157  	if !data.inOrder() {
   158  		t.Errorf("Stable shuffled sorted %d ints (stability)", n)
   159  	}
   160  
   161  	// sorted reversed
   162  	for i := 0; i < len(data); i++ {
   163  		data[i].a = len(data) - i
   164  	}
   165  	data.initB()
   166  	SortStableFunc(data, intPairCmp)
   167  	if !IsSortedFunc(data, intPairCmp) {
   168  		t.Errorf("Stable didn't sort %d ints", n)
   169  	}
   170  	if !data.inOrder() {
   171  		t.Errorf("Stable wasn't stable on %d ints", n)
   172  	}
   173  }
   174  
   175  type S struct {
   176  	a int
   177  	b string
   178  }
   179  
   180  func cmpS(s1, s2 S) int {
   181  	return cmpCompare(s1.a, s2.a)
   182  }
   183  
   184  func TestMinMax(t *testing.T) {
   185  	intCmp := func(a, b int) int { return a - b }
   186  
   187  	tests := []struct {
   188  		data    []int
   189  		wantMin int
   190  		wantMax int
   191  	}{
   192  		{[]int{7}, 7, 7},
   193  		{[]int{1, 2}, 1, 2},
   194  		{[]int{2, 1}, 1, 2},
   195  		{[]int{1, 2, 3}, 1, 3},
   196  		{[]int{3, 2, 1}, 1, 3},
   197  		{[]int{2, 1, 3}, 1, 3},
   198  		{[]int{2, 2, 3}, 2, 3},
   199  		{[]int{3, 2, 3}, 2, 3},
   200  		{[]int{0, 2, -9}, -9, 2},
   201  	}
   202  	for _, tt := range tests {
   203  		t.Run(fmt.Sprintf("%v", tt.data), func(t *testing.T) {
   204  			gotMin := Min(tt.data)
   205  			if gotMin != tt.wantMin {
   206  				t.Errorf("Min got %v, want %v", gotMin, tt.wantMin)
   207  			}
   208  
   209  			gotMinFunc := MinFunc(tt.data, intCmp)
   210  			if gotMinFunc != tt.wantMin {
   211  				t.Errorf("MinFunc got %v, want %v", gotMinFunc, tt.wantMin)
   212  			}
   213  
   214  			gotMax := Max(tt.data)
   215  			if gotMax != tt.wantMax {
   216  				t.Errorf("Max got %v, want %v", gotMax, tt.wantMax)
   217  			}
   218  
   219  			gotMaxFunc := MaxFunc(tt.data, intCmp)
   220  			if gotMaxFunc != tt.wantMax {
   221  				t.Errorf("MaxFunc got %v, want %v", gotMaxFunc, tt.wantMax)
   222  			}
   223  		})
   224  	}
   225  
   226  	svals := []S{
   227  		{1, "a"},
   228  		{2, "a"},
   229  		{1, "b"},
   230  		{2, "b"},
   231  	}
   232  
   233  	gotMin := MinFunc(svals, cmpS)
   234  	wantMin := S{1, "a"}
   235  	if gotMin != wantMin {
   236  		t.Errorf("MinFunc(%v) = %v, want %v", svals, gotMin, wantMin)
   237  	}
   238  
   239  	gotMax := MaxFunc(svals, cmpS)
   240  	wantMax := S{2, "a"}
   241  	if gotMax != wantMax {
   242  		t.Errorf("MaxFunc(%v) = %v, want %v", svals, gotMax, wantMax)
   243  	}
   244  }
   245  
   246  func TestMinMaxNaNs(t *testing.T) {
   247  	fs := []float64{1.0, 999.9, 3.14, -400.4, -5.14}
   248  	if Min(fs) != -400.4 {
   249  		t.Errorf("got min %v, want -400.4", Min(fs))
   250  	}
   251  	if Max(fs) != 999.9 {
   252  		t.Errorf("got max %v, want 999.9", Max(fs))
   253  	}
   254  
   255  	// No matter which element of fs is replaced with a NaN, both Min and Max
   256  	// should propagate the NaN to their output.
   257  	for i := 0; i < len(fs); i++ {
   258  		testfs := Clone(fs)
   259  		testfs[i] = math.NaN()
   260  
   261  		fmin := Min(testfs)
   262  		if !math.IsNaN(fmin) {
   263  			t.Errorf("got min %v, want NaN", fmin)
   264  		}
   265  
   266  		fmax := Max(testfs)
   267  		if !math.IsNaN(fmax) {
   268  			t.Errorf("got max %v, want NaN", fmax)
   269  		}
   270  	}
   271  }
   272  
   273  func TestMinMaxPanics(t *testing.T) {
   274  	intCmp := func(a, b int) int { return a - b }
   275  	emptySlice := []int{}
   276  
   277  	if !panics(func() { Min(emptySlice) }) {
   278  		t.Errorf("Min([]): got no panic, want panic")
   279  	}
   280  
   281  	if !panics(func() { Max(emptySlice) }) {
   282  		t.Errorf("Max([]): got no panic, want panic")
   283  	}
   284  
   285  	if !panics(func() { MinFunc(emptySlice, intCmp) }) {
   286  		t.Errorf("MinFunc([]): got no panic, want panic")
   287  	}
   288  
   289  	if !panics(func() { MaxFunc(emptySlice, intCmp) }) {
   290  		t.Errorf("MaxFunc([]): got no panic, want panic")
   291  	}
   292  }
   293  
   294  func TestBinarySearch(t *testing.T) {
   295  	str1 := []string{"foo"}
   296  	str2 := []string{"ab", "ca"}
   297  	str3 := []string{"mo", "qo", "vo"}
   298  	str4 := []string{"ab", "ad", "ca", "xy"}
   299  
   300  	// slice with repeating elements
   301  	strRepeats := []string{"ba", "ca", "da", "da", "da", "ka", "ma", "ma", "ta"}
   302  
   303  	// slice with all element equal
   304  	strSame := []string{"xx", "xx", "xx"}
   305  
   306  	tests := []struct {
   307  		data      []string
   308  		target    string
   309  		wantPos   int
   310  		wantFound bool
   311  	}{
   312  		{[]string{}, "foo", 0, false},
   313  		{[]string{}, "", 0, false},
   314  
   315  		{str1, "foo", 0, true},
   316  		{str1, "bar", 0, false},
   317  		{str1, "zx", 1, false},
   318  
   319  		{str2, "aa", 0, false},
   320  		{str2, "ab", 0, true},
   321  		{str2, "ad", 1, false},
   322  		{str2, "ca", 1, true},
   323  		{str2, "ra", 2, false},
   324  
   325  		{str3, "bb", 0, false},
   326  		{str3, "mo", 0, true},
   327  		{str3, "nb", 1, false},
   328  		{str3, "qo", 1, true},
   329  		{str3, "tr", 2, false},
   330  		{str3, "vo", 2, true},
   331  		{str3, "xr", 3, false},
   332  
   333  		{str4, "aa", 0, false},
   334  		{str4, "ab", 0, true},
   335  		{str4, "ac", 1, false},
   336  		{str4, "ad", 1, true},
   337  		{str4, "ax", 2, false},
   338  		{str4, "ca", 2, true},
   339  		{str4, "cc", 3, false},
   340  		{str4, "dd", 3, false},
   341  		{str4, "xy", 3, true},
   342  		{str4, "zz", 4, false},
   343  
   344  		{strRepeats, "da", 2, true},
   345  		{strRepeats, "db", 5, false},
   346  		{strRepeats, "ma", 6, true},
   347  		{strRepeats, "mb", 8, false},
   348  
   349  		{strSame, "xx", 0, true},
   350  		{strSame, "ab", 0, false},
   351  		{strSame, "zz", 3, false},
   352  	}
   353  	for _, tt := range tests {
   354  		t.Run(tt.target, func(t *testing.T) {
   355  			{
   356  				pos, found := BinarySearch(tt.data, tt.target)
   357  				if pos != tt.wantPos || found != tt.wantFound {
   358  					t.Errorf("BinarySearch got (%v, %v), want (%v, %v)", pos, found, tt.wantPos, tt.wantFound)
   359  				}
   360  			}
   361  
   362  			{
   363  				pos, found := BinarySearchFunc(tt.data, tt.target, strings.Compare)
   364  				if pos != tt.wantPos || found != tt.wantFound {
   365  					t.Errorf("BinarySearchFunc got (%v, %v), want (%v, %v)", pos, found, tt.wantPos, tt.wantFound)
   366  				}
   367  			}
   368  		})
   369  	}
   370  }
   371  
   372  func TestBinarySearchInts(t *testing.T) {
   373  	data := []int{20, 30, 40, 50, 60, 70, 80, 90}
   374  	tests := []struct {
   375  		target    int
   376  		wantPos   int
   377  		wantFound bool
   378  	}{
   379  		{20, 0, true},
   380  		{23, 1, false},
   381  		{43, 3, false},
   382  		{80, 6, true},
   383  	}
   384  	for _, tt := range tests {
   385  		t.Run(strconv.Itoa(tt.target), func(t *testing.T) {
   386  			{
   387  				pos, found := BinarySearch(data, tt.target)
   388  				if pos != tt.wantPos || found != tt.wantFound {
   389  					t.Errorf("BinarySearch got (%v, %v), want (%v, %v)", pos, found, tt.wantPos, tt.wantFound)
   390  				}
   391  			}
   392  
   393  			{
   394  				cmp := func(a, b int) int {
   395  					return a - b
   396  				}
   397  				pos, found := BinarySearchFunc(data, tt.target, cmp)
   398  				if pos != tt.wantPos || found != tt.wantFound {
   399  					t.Errorf("BinarySearchFunc got (%v, %v), want (%v, %v)", pos, found, tt.wantPos, tt.wantFound)
   400  				}
   401  			}
   402  		})
   403  	}
   404  }
   405  
   406  func TestBinarySearchFloats(t *testing.T) {
   407  	data := []float64{math.NaN(), -0.25, 0.0, 1.4}
   408  	tests := []struct {
   409  		target    float64
   410  		wantPos   int
   411  		wantFound bool
   412  	}{
   413  		{math.NaN(), 0, true},
   414  		{math.Inf(-1), 1, false},
   415  		{-0.25, 1, true},
   416  		{0.0, 2, true},
   417  		{1.4, 3, true},
   418  		{1.5, 4, false},
   419  	}
   420  	for _, tt := range tests {
   421  		t.Run(fmt.Sprintf("%v", tt.target), func(t *testing.T) {
   422  			{
   423  				pos, found := BinarySearch(data, tt.target)
   424  				if pos != tt.wantPos || found != tt.wantFound {
   425  					t.Errorf("BinarySearch got (%v, %v), want (%v, %v)", pos, found, tt.wantPos, tt.wantFound)
   426  				}
   427  			}
   428  		})
   429  	}
   430  }
   431  
   432  func TestBinarySearchFunc(t *testing.T) {
   433  	data := []int{1, 10, 11, 2} // sorted lexicographically
   434  	cmp := func(a int, b string) int {
   435  		return strings.Compare(strconv.Itoa(a), b)
   436  	}
   437  	pos, found := BinarySearchFunc(data, "2", cmp)
   438  	if pos != 3 || !found {
   439  		t.Errorf("BinarySearchFunc(%v, %q, cmp) = %v, %v, want %v, %v", data, "2", pos, found, 3, true)
   440  	}
   441  }