github.com/Schaudge/grailbase@v0.0.0-20240223061707-44c758a471c0/psort/mergesort.go (about) 1 package psort 2 3 import ( 4 "reflect" 5 "sort" 6 "sync" 7 8 "github.com/Schaudge/grailbase/traverse" 9 ) 10 11 const ( 12 serialThreshold = 128 13 ) 14 15 // Slice sorts the given slice according to the ordering induced by the provided 16 // less function. Parallel computation will be attempted, up to the limit imposed by 17 // parallelism. This function can be much faster than the standard library's sort.Slice() 18 // when sorting large slices on multicore machines. 19 func Slice(slice interface{}, less func(i, j int) bool, parallelism int) { 20 if parallelism < 1 { 21 panic("parallelism must be at least 1") 22 } 23 if reflect.TypeOf(slice).Kind() != reflect.Slice { 24 panic("input interface was not of slice type") 25 } 26 rv := reflect.ValueOf(slice) 27 if rv.Len() < 2 { 28 return 29 } 30 // For clarity, we will sort a slice containing indices from the input slice. Then, 31 // we will set the elements of the input slice according to this permutation. This 32 // avoids difficult-to-understand reflection types and calls in most of the code. 33 perm := make([]int, rv.Len()) 34 for i := range perm { 35 perm[i] = i 36 } 37 scratch := make([]int, len(perm)) 38 mergeSort(perm, less, parallelism, scratch) 39 result := reflect.MakeSlice(rv.Type(), rv.Len(), rv.Len()) 40 _ = traverse.Limit(parallelism).Range(rv.Len(), func(start, end int) error { 41 for i := start; i < end; i++ { 42 result.Index(i).Set(rv.Index(perm[i])) 43 } 44 return nil 45 }) 46 _ = traverse.Limit(parallelism).Range(rv.Len(), func(start, end int) error { 47 reflect.Copy(rv.Slice(start, end), result.Slice(start, end)) 48 return nil 49 }) 50 } 51 52 func mergeSort(perm []int, less func(i, j int) bool, parallelism int, scratch []int) { 53 if parallelism == 1 || len(perm) < serialThreshold { 54 sortSerial(perm, less) 55 return 56 } 57 58 // Sort two halves of the slice in parallel, allocating half of our parallelism to 59 // each subroutine. 60 left := perm[:len(perm)/2] 61 right := perm[len(perm)/2:] 62 var waitGroup sync.WaitGroup 63 waitGroup.Add(1) 64 go func() { 65 mergeSort(left, less, (parallelism+1)/2, scratch[:len(perm)/2]) 66 waitGroup.Done() 67 }() 68 mergeSort(right, less, parallelism/2, scratch[len(perm)/2:]) 69 waitGroup.Wait() 70 71 merge(left, right, less, parallelism, scratch) 72 parallelCopy(perm, scratch, parallelism) 73 } 74 75 func parallelCopy(dst, src []int, parallelism int) { 76 _ = traverse.Limit(parallelism).Range(len(dst), func(start, end int) error { 77 copy(dst[start:end], src[start:end]) 78 return nil 79 }) 80 } 81 82 func sortSerial(perm []int, less func(i, j int) bool) { 83 sort.Slice(perm, func(i, j int) bool { 84 return less(perm[i], perm[j]) 85 }) 86 } 87 88 func merge(perm1, perm2 []int, less func(i, j int) bool, parallelism int, out []int) { 89 if parallelism == 1 || len(perm1)+len(perm2) < serialThreshold { 90 mergeSerial(perm1, perm2, less, out) 91 return 92 } 93 94 if len(perm1) < len(perm2) { 95 perm1, perm2 = perm2, perm1 96 } 97 // Find the index in perm2 such that all elements to the left are smaller than 98 // the midpoint element of perm1. 99 r := len(perm1) / 2 100 s := sort.Search(len(perm2), func(i int) bool { 101 return !less(perm2[i], perm1[r]) 102 }) 103 // Merge in parallel, allocating half of our parallelism to each subroutine. 104 var waitGroup sync.WaitGroup 105 waitGroup.Add(1) 106 go func() { 107 merge(perm1[:r], perm2[:s], less, (parallelism+1)/2, out[:r+s]) 108 waitGroup.Done() 109 }() 110 merge(perm1[r:], perm2[s:], less, parallelism/2, out[r+s:]) 111 waitGroup.Wait() 112 } 113 114 func mergeSerial(perm1, perm2 []int, less func(i, j int) bool, out []int) { 115 var idx1, idx2, idxOut int 116 for idx1 < len(perm1) && idx2 < len(perm2) { 117 if less(perm1[idx1], perm2[idx2]) { 118 out[idxOut] = perm1[idx1] 119 idx1++ 120 } else { 121 out[idxOut] = perm2[idx2] 122 idx2++ 123 } 124 idxOut++ 125 } 126 for idx1 < len(perm1) { 127 out[idxOut] = perm1[idx1] 128 idx1++ 129 idxOut++ 130 } 131 for idx2 < len(perm2) { 132 out[idxOut] = perm2[idx2] 133 idx2++ 134 idxOut++ 135 } 136 }