github.com/Schaudge/grailbase@v0.0.0-20240223061707-44c758a471c0/psort/mergesort.go (about)

     1  package psort
     2  
     3  import (
     4  	"reflect"
     5  	"sort"
     6  	"sync"
     7  
     8  	"github.com/Schaudge/grailbase/traverse"
     9  )
    10  
    11  const (
    12  	serialThreshold = 128
    13  )
    14  
    15  // Slice sorts the given slice according to the ordering induced by the provided
    16  // less function. Parallel computation will be attempted, up to the limit imposed by
    17  // parallelism. This function can be much faster than the standard library's sort.Slice()
    18  // when sorting large slices on multicore machines.
    19  func Slice(slice interface{}, less func(i, j int) bool, parallelism int) {
    20  	if parallelism < 1 {
    21  		panic("parallelism must be at least 1")
    22  	}
    23  	if reflect.TypeOf(slice).Kind() != reflect.Slice {
    24  		panic("input interface was not of slice type")
    25  	}
    26  	rv := reflect.ValueOf(slice)
    27  	if rv.Len() < 2 {
    28  		return
    29  	}
    30  	// For clarity, we will sort a slice containing indices from the input slice. Then,
    31  	// we will set the elements of the input slice according to this permutation. This
    32  	// avoids difficult-to-understand reflection types and calls in most of the code.
    33  	perm := make([]int, rv.Len())
    34  	for i := range perm {
    35  		perm[i] = i
    36  	}
    37  	scratch := make([]int, len(perm))
    38  	mergeSort(perm, less, parallelism, scratch)
    39  	result := reflect.MakeSlice(rv.Type(), rv.Len(), rv.Len())
    40  	_ = traverse.Limit(parallelism).Range(rv.Len(), func(start, end int) error {
    41  		for i := start; i < end; i++ {
    42  			result.Index(i).Set(rv.Index(perm[i]))
    43  		}
    44  		return nil
    45  	})
    46  	_ = traverse.Limit(parallelism).Range(rv.Len(), func(start, end int) error {
    47  		reflect.Copy(rv.Slice(start, end), result.Slice(start, end))
    48  		return nil
    49  	})
    50  }
    51  
    52  func mergeSort(perm []int, less func(i, j int) bool, parallelism int, scratch []int) {
    53  	if parallelism == 1 || len(perm) < serialThreshold {
    54  		sortSerial(perm, less)
    55  		return
    56  	}
    57  
    58  	// Sort two halves of the slice in parallel, allocating half of our parallelism to
    59  	// each subroutine.
    60  	left := perm[:len(perm)/2]
    61  	right := perm[len(perm)/2:]
    62  	var waitGroup sync.WaitGroup
    63  	waitGroup.Add(1)
    64  	go func() {
    65  		mergeSort(left, less, (parallelism+1)/2, scratch[:len(perm)/2])
    66  		waitGroup.Done()
    67  	}()
    68  	mergeSort(right, less, parallelism/2, scratch[len(perm)/2:])
    69  	waitGroup.Wait()
    70  
    71  	merge(left, right, less, parallelism, scratch)
    72  	parallelCopy(perm, scratch, parallelism)
    73  }
    74  
    75  func parallelCopy(dst, src []int, parallelism int) {
    76  	_ = traverse.Limit(parallelism).Range(len(dst), func(start, end int) error {
    77  		copy(dst[start:end], src[start:end])
    78  		return nil
    79  	})
    80  }
    81  
    82  func sortSerial(perm []int, less func(i, j int) bool) {
    83  	sort.Slice(perm, func(i, j int) bool {
    84  		return less(perm[i], perm[j])
    85  	})
    86  }
    87  
    88  func merge(perm1, perm2 []int, less func(i, j int) bool, parallelism int, out []int) {
    89  	if parallelism == 1 || len(perm1)+len(perm2) < serialThreshold {
    90  		mergeSerial(perm1, perm2, less, out)
    91  		return
    92  	}
    93  
    94  	if len(perm1) < len(perm2) {
    95  		perm1, perm2 = perm2, perm1
    96  	}
    97  	// Find the index in perm2 such that all elements to the left are smaller than
    98  	// the midpoint element of perm1.
    99  	r := len(perm1) / 2
   100  	s := sort.Search(len(perm2), func(i int) bool {
   101  		return !less(perm2[i], perm1[r])
   102  	})
   103  	// Merge in parallel, allocating half of our parallelism to each subroutine.
   104  	var waitGroup sync.WaitGroup
   105  	waitGroup.Add(1)
   106  	go func() {
   107  		merge(perm1[:r], perm2[:s], less, (parallelism+1)/2, out[:r+s])
   108  		waitGroup.Done()
   109  	}()
   110  	merge(perm1[r:], perm2[s:], less, parallelism/2, out[r+s:])
   111  	waitGroup.Wait()
   112  }
   113  
   114  func mergeSerial(perm1, perm2 []int, less func(i, j int) bool, out []int) {
   115  	var idx1, idx2, idxOut int
   116  	for idx1 < len(perm1) && idx2 < len(perm2) {
   117  		if less(perm1[idx1], perm2[idx2]) {
   118  			out[idxOut] = perm1[idx1]
   119  			idx1++
   120  		} else {
   121  			out[idxOut] = perm2[idx2]
   122  			idx2++
   123  		}
   124  		idxOut++
   125  	}
   126  	for idx1 < len(perm1) {
   127  		out[idxOut] = perm1[idx1]
   128  		idx1++
   129  		idxOut++
   130  	}
   131  	for idx2 < len(perm2) {
   132  		out[idxOut] = perm2[idx2]
   133  		idx2++
   134  		idxOut++
   135  	}
   136  }