github.com/Schaudge/grailbase@v0.0.0-20240223061707-44c758a471c0/psort/mergesort_test.go (about)

     1  package psort
     2  
     3  import (
     4  	"fmt"
     5  	"math/rand"
     6  	"reflect"
     7  	"sort"
     8  	"testing"
     9  )
    10  
// TestInput identifies which fill pattern TestSlice uses to build the slice
// it sorts.
type TestInput int
    12  
const (
	// Random fills the input with pseudo-random values in [0, size).
	Random TestInput = iota
	// Ascending fills the input with 0, 1, ..., size-1.
	Ascending
	// Descending fills the input with size, size-1, ..., 1.
	Descending
)
    18  
    19  func TestSlice(t *testing.T) {
    20  	tests := []struct {
    21  		input       TestInput
    22  		size        int
    23  		parallelism int
    24  		reps        int
    25  	}{
    26  		{
    27  			input:       Random,
    28  			size:        10000,
    29  			parallelism: 7,
    30  			reps:        100,
    31  		},
    32  		{
    33  			input:       Random,
    34  			size:        1000000,
    35  			parallelism: 6,
    36  			reps:        4,
    37  		},
    38  		{
    39  			input:       Ascending,
    40  			size:        10000,
    41  			parallelism: 9,
    42  			reps:        1,
    43  		},
    44  		{
    45  			input:       Descending,
    46  			size:        10000,
    47  			parallelism: 8,
    48  			reps:        1,
    49  		},
    50  	}
    51  
    52  	for _, test := range tests {
    53  		random := rand.New(rand.NewSource(0))
    54  		for rep := 0; rep < test.reps; rep++ {
    55  			in := make([]int, test.size)
    56  			switch test.input {
    57  			case Random:
    58  				for i := range in {
    59  					in[i] = random.Intn(test.size)
    60  				}
    61  			case Ascending:
    62  				for i := range in {
    63  					in[i] = i
    64  				}
    65  			case Descending:
    66  				for i := range in {
    67  					in[i] = len(in) - i
    68  				}
    69  			}
    70  			expected := make([]int, len(in))
    71  			copy(expected, in)
    72  			sort.Slice(expected, func(i, j int) bool {
    73  				return expected[i] < expected[j]
    74  			})
    75  			Slice(in, func(i, j int) bool {
    76  				return in[i] < in[j]
    77  			}, test.parallelism)
    78  			if !reflect.DeepEqual(expected, in) {
    79  				t.Errorf("Wrong sort result: want %v\n, got %v\n", expected, in)
    80  			}
    81  		}
    82  	}
    83  }
    84  
    85  func BenchmarkSlice(b *testing.B) {
    86  	tests := []struct {
    87  		size        int
    88  		parallelism int //parallelism = 0 means use sort.Slice() sort
    89  	}{
    90  		{
    91  			size:        100000000,
    92  			parallelism: 4096,
    93  		},
    94  		{
    95  			size:        100000000,
    96  			parallelism: 2048,
    97  		},
    98  		{
    99  			size:        100000000,
   100  			parallelism: 1024,
   101  		},
   102  		{
   103  			size:        100000000,
   104  			parallelism: 512,
   105  		},
   106  		{
   107  			size:        100000000,
   108  			parallelism: 256,
   109  		},
   110  		{
   111  			size:        100000000,
   112  			parallelism: 128,
   113  		},
   114  		{
   115  			size:        100000000,
   116  			parallelism: 64,
   117  		},
   118  		{
   119  			size:        100000000,
   120  			parallelism: 32,
   121  		},
   122  		{
   123  			size:        100000000,
   124  			parallelism: 16,
   125  		},
   126  		{
   127  			size:        100000000,
   128  			parallelism: 8,
   129  		},
   130  		{
   131  			size:        100000000,
   132  			parallelism: 4,
   133  		},
   134  		{
   135  			size:        100000000,
   136  			parallelism: 2,
   137  		},
   138  		{
   139  			size:        100000000,
   140  			parallelism: 1,
   141  		},
   142  		{
   143  			size:        100000000,
   144  			parallelism: 0,
   145  		},
   146  	}
   147  
   148  	for _, test := range tests {
   149  		b.Run(fmt.Sprintf("size:%d-%d", test.size, test.parallelism), func(b *testing.B) {
   150  			data := make([]float64, test.size)
   151  			r := rand.New(rand.NewSource(0))
   152  			dataCopy := make([]float64, len(data))
   153  			for i := range data {
   154  				data[i] = r.Float64()
   155  			}
   156  			b.ResetTimer()
   157  			for i := 0; i < b.N; i++ {
   158  				b.StopTimer()
   159  				copy(dataCopy, data)
   160  				b.StartTimer()
   161  				if test.parallelism == 0 {
   162  					sort.Slice(dataCopy, func(i, j int) bool {
   163  						return dataCopy[i] < dataCopy[j]
   164  					})
   165  				} else {
   166  					Slice(dataCopy, func(i, j int) bool {
   167  						return dataCopy[i] < dataCopy[j]
   168  					}, test.parallelism)
   169  				}
   170  			}
   171  		})
   172  	}
   173  }