github.com/jfcg/sorty@v1.2.0/sortyLsw.go (about)

     1  /*	Copyright (c) 2019, Serhat Şevki Dinçer.
     2  	This Source Code Form is subject to the terms of the Mozilla Public
     3  	License, v. 2.0. If a copy of the MPL was not distributed with this
     4  	file, You can obtain one at http://mozilla.org/MPL/2.0/.
     5  */
     6  
     7  package sorty
     8  
     9  import "sync/atomic"
    10  
// Lesswap function operates on an underlying collection to be sorted as:
//  if less(i, k) { // strict ordering like < or >
//  	if r != s {
//  		swap(r, s)
//  	}
//  	return true
//  }
//  return false
// i, k are the indices of the two elements that are compared; r, s are the
// indices of the two elements that are exchanged when the comparison holds.
// Passing r == s asks for a comparison without any swap.
type Lesswap func(i, k, r, s int) bool
    20  
    21  // IsSorted returns 0 if underlying collection of length n is sorted,
    22  // otherwise it returns i > 0 with less(i,i-1) = true.
    23  func IsSorted(n int, lsw Lesswap) int {
    24  	for i := n - 1; i > 0; i-- {
    25  		if lsw(i, i-1, i, i) { // 3rd=4th disables swap
    26  			return i
    27  		}
    28  	}
    29  	return 0
    30  }
    31  
    32  // insertion sort ar[lo..hi], assumes lo < hi
    33  func insertion(lsw Lesswap, lo, hi int) {
    34  
    35  	for l, h := mid(lo, hi-1)-1, hi; l >= lo; {
    36  		lsw(h, l, h, l)
    37  		l--
    38  		h--
    39  	}
    40  	for h := lo; ; {
    41  		l := h
    42  		h++
    43  		k := h
    44  		for lsw(k, l, k, l) {
    45  			k = l
    46  			l--
    47  			if l < lo {
    48  				break
    49  			}
    50  		}
    51  		if h >= hi {
    52  			break
    53  		}
    54  	}
    55  }
    56  
    57  // pivot divides ar[lo..hi] into 2n+1 equal intervals, sorts mid-points of them
    58  // to find median-of-2n+1 pivot. ensures lo/hi ranges have at least n elements by
    59  // moving 2n of mid-points to n positions at lo/hi ends.
    60  // assumes n > 0, lo+4n+1 < hi. returns start,pivot,end for partitioning.
    61  func pivot(lsw Lesswap, lo, hi, n int) (int, int, int) {
    62  	m := mid(lo, hi)
    63  	s := (hi - lo + 1) / (2*n + 1) // step > 1
    64  	l, h := m-n*s, m+n*s
    65  
    66  	for q, k := h, m-2*s; k >= l; { // insertion sort ar[m+i*s], i=-n..n
    67  		lsw(q, k, q, k)
    68  		q -= s
    69  		k -= s
    70  	}
    71  	for r := l; ; {
    72  		k := r
    73  		r += s
    74  		q := r
    75  		for lsw(q, k, q, k) {
    76  			q = k
    77  			k -= s
    78  			if k < l {
    79  				break
    80  			}
    81  		}
    82  		if r >= h {
    83  			break
    84  		}
    85  	}
    86  
    87  	// move hi mid-points to hi end
    88  	for {
    89  		if h == hi || lsw(hi, h, hi, h) {
    90  			h -= s
    91  		}
    92  		hi--
    93  		if h <= m {
    94  			break
    95  		}
    96  	}
    97  
    98  	// move lo mid-points to lo end
    99  	for {
   100  		if l == lo || lsw(l, lo, l, lo) {
   101  			l += s
   102  		}
   103  		lo++
   104  		if l >= m {
   105  			break
   106  		}
   107  	}
   108  	return lo, m, hi // lo <= m-s+1, m+s-1 <= hi
   109  }
   110  
   111  // partition ar[l..h] into <= and >= pivot, assumes l < h
   112  // returns m with ar[:m] <= pivot, ar[m:] >= pivot
   113  func partition1(lsw Lesswap, l, pv, h int) int {
   114  	// avoid unnecessary comparisons, extend ranges in balance
   115  	for {
   116  		if lsw(h, pv, h, h) { // 3rd=4th disables swap
   117  			for {
   118  				if lsw(pv, l, h, l) {
   119  					break
   120  				}
   121  				l++
   122  				if l >= h {
   123  					return l + 1
   124  				}
   125  			}
   126  		} else if lsw(pv, l, l, l) { // 3rd=4th disables swap
   127  			for {
   128  				h--
   129  				if l >= h {
   130  					return l
   131  				}
   132  				if lsw(h, pv, h, l) {
   133  					break
   134  				}
   135  			}
   136  		}
   137  		l++
   138  		h--
   139  		if l >= h {
   140  			break
   141  		}
   142  	}
   143  	// classify mid element
   144  	if l == h && (h == pv || lsw(h, pv, h, h)) { // 3rd=4th disables swap
   145  		l++
   146  	}
   147  	return l
   148  }
   149  
   150  // rearrange ar[l..a] and ar[b..h] into <= and >= pivot, assumes l <= a < pv < b <= h
   151  // gap (a..b) expands until one of the intervals is fully consumed
   152  func partition2(lsw Lesswap, l, a, pv, b, h int) (int, int) {
   153  	// avoid unnecessary comparisons, extend ranges in balance
   154  	for {
   155  		if lsw(b, pv, b, b) { // 3rd=4th disables swap
   156  			for {
   157  				if lsw(pv, a, b, a) {
   158  					break
   159  				}
   160  				a--
   161  				if a < l {
   162  					return a, b
   163  				}
   164  			}
   165  		} else if lsw(pv, a, a, a) { // 3rd=4th disables swap
   166  			for {
   167  				b++
   168  				if b > h {
   169  					return a, b
   170  				}
   171  				if lsw(b, pv, b, a) {
   172  					break
   173  				}
   174  			}
   175  		}
   176  		a--
   177  		b++
   178  		if a < l || b > h {
   179  			return a, b
   180  		}
   181  	}
   182  }
   183  
   184  // new-goroutine partition
   185  func gpart1(lsw Lesswap, l, pv, h int, ch chan int) {
   186  	ch <- partition1(lsw, l, pv, h)
   187  }
   188  
// concurrent dual partitioning of ar[lo..hi]
// the range between the two quarter points is partitioned by partition1 in a
// new goroutine while the ranges outside of it are handled by partition2 in
// the calling goroutine. ch carries the result of the new goroutine.
// returns m with ar[:m] <= pivot, ar[m:] >= pivot
func cdualpar(lsw Lesswap, lo, hi int, ch chan int) int {

	lo, pv, hi := pivot(lsw, lo, hi, 4) // median-of-9

	if hi-lo <= 2*Mlr { // guard against short remaining range
		return partition1(lsw, lo, pv, hi)
	}

	m := mid(lo, hi) // in pivot() lo/hi changed by possibly unequal amounts
	a, b := mid(lo, m), mid(m, hi)

	go gpart1(lsw, a+1, pv, b-1, ch) // mid half range

	a, b = partition2(lsw, lo, a, pv, b, hi) // left/right quarter ranges
	m = <-ch // wait for the mid half range result

	// partition2 stops as soon as one of its two ranges is used up, so
	// ar[lo..a] or ar[b..hi] may still hold unclassified elements
	// only one gap is possible
	for ; lo <= a; a-- { // gap left in low range?
		// an element > pivot is exchanged with ar[m-1], which moves the boundary down
		if lsw(pv, a, m-1, a) {
			m--
			if m == pv { // swapped pivot when closing gap?
				pv = a // Thanks to my wife Tansu who discovered this
			}
		}
	}
	for ; b <= hi; b++ { // gap left in high range?
		// an element < pivot is exchanged with ar[m], which moves the boundary up
		if lsw(b, pv, b, m) {
			if m == pv { // swapped pivot when closing gap?
				pv = b // It took days of agony to discover these two if's :D
			}
			m++
		}
	}
	return m
}
   226  
   227  // short range sort function, assumes Hmli <= hi-lo < Mlr
   228  func short(lsw Lesswap, lo, hi int) {
   229  start:
   230  	l, pv, h := pivot(lsw, lo, hi, 2)
   231  	l = partition1(lsw, l, pv, h) // median-of-5 partitioning
   232  	h = l - 1
   233  	no, n := h-lo, hi-l
   234  
   235  	if no < n {
   236  		n, no = no, n // [lo,hi] is the longer range
   237  		l, lo = lo, l
   238  	} else {
   239  		h, hi = hi, h
   240  	}
   241  
   242  	if n >= Hmli {
   243  		short(lsw, l, h) // recurse on the shorter range
   244  		goto start
   245  	}
   246  	insertion(lsw, l, h) // at least one insertion range
   247  
   248  	if no >= Hmli {
   249  		goto start
   250  	}
   251  	insertion(lsw, lo, hi) // two insertion ranges
   252  }
   253  
   254  // long range sort function (single goroutine), assumes hi-lo >= Mlr
   255  func slong(lsw Lesswap, lo, hi int) {
   256  start:
   257  	l, pv, h := pivot(lsw, lo, hi, 3)
   258  	l = partition1(lsw, l, pv, h) // median-of-7 partitioning
   259  	h = l - 1
   260  	no, n := h-lo, hi-l
   261  
   262  	if no < n {
   263  		n, no = no, n // [lo,hi] is the longer range
   264  		l, lo = lo, l
   265  	} else {
   266  		h, hi = hi, h
   267  	}
   268  
   269  	if n >= Mlr { // at least one not-long range?
   270  		slong(lsw, l, h) // recurse on the shorter range
   271  		goto start
   272  	}
   273  
   274  	if n >= Hmli {
   275  		short(lsw, l, h)
   276  	} else {
   277  		insertion(lsw, l, h)
   278  	}
   279  
   280  	if no >= Mlr { // two not-long ranges?
   281  		goto start
   282  	}
   283  	short(lsw, lo, hi) // we know no >= Hmli
   284  }
   285  
   286  // new-goroutine sort function
   287  func glong(lsw Lesswap, lo, hi int, sv *syncVar) {
   288  	long(lsw, lo, hi, sv)
   289  
   290  	if atomic.AddUint32(&sv.ngr, ^uint32(0)) == 0 { // decrease goroutine counter
   291  		sv.done <- 0 // we are the last, all done
   292  	}
   293  }
   294  
   295  // long range sort function, assumes hi-lo >= Mlr
   296  func long(lsw Lesswap, lo, hi int, sv *syncVar) {
   297  start:
   298  	l, pv, h := pivot(lsw, lo, hi, 3)
   299  	l = partition1(lsw, l, pv, h) // median-of-7 partitioning
   300  	h = l - 1
   301  	no, n := h-lo, hi-l
   302  
   303  	if no < n {
   304  		n, no = no, n // [lo,hi] is the longer range
   305  		l, lo = lo, l
   306  	} else {
   307  		h, hi = hi, h
   308  	}
   309  
   310  	// branches below are optimal for fewer total jumps
   311  	if n < Mlr { // at least one not-long range?
   312  
   313  		if n >= Hmli {
   314  			short(lsw, l, h)
   315  		} else {
   316  			insertion(lsw, l, h)
   317  		}
   318  
   319  		if no >= Mlr { // two not-long ranges?
   320  			goto start
   321  		}
   322  		short(lsw, lo, hi) // we know no >= Hmli
   323  		return
   324  	}
   325  
   326  	// max goroutines? not atomic but good enough
   327  	if sv.ngr >= Mxg {
   328  		long(lsw, l, h, sv) // recurse on the shorter range
   329  		goto start
   330  	}
   331  
   332  	if atomic.AddUint32(&sv.ngr, 1) == 0 { // increase goroutine counter
   333  		panic("sorty: long: counter overflow")
   334  	}
   335  	// new-goroutine sort on the longer range only when
   336  	// both ranges are big and max goroutines is not exceeded
   337  	go glong(lsw, lo, hi, sv)
   338  	lo, hi = l, h
   339  	goto start
   340  }
   341  
   342  // Sort concurrently sorts underlying collection of length n via lsw().
   343  // Once for each non-trivial type you want to sort in a certain way, you
   344  // can implement a custom sorting routine (for a slice for example) as:
   345  //  func SortObjAsc(c []Obj) {
   346  //  	lsw := func(i, k, r, s int) bool {
   347  //  		if c[i].Key < c[k].Key { // strict comparator like < or >
   348  //  			if r != s {
   349  //  				c[r], c[s] = c[s], c[r]
   350  //  			}
   351  //  			return true
   352  //  		}
   353  //  		return false
   354  //  	}
   355  //  	sorty.Sort(len(c), lsw)
   356  //  }
   357  // Lesswap is a 'contract' between users and sorty library. Strict
   358  // comparator, r!=s check, swap and returns are all strictly necessary.
   359  func Sort(n int, lsw Lesswap) {
   360  
   361  	n-- // high indice
   362  	if n <= 2*Mlr || Mxg <= 1 {
   363  
   364  		// single-goroutine sorting
   365  		if n >= Mlr {
   366  			slong(lsw, 0, n)
   367  		} else if n >= Hmli {
   368  			short(lsw, 0, n)
   369  		} else if n > 0 {
   370  			insertion(lsw, 0, n)
   371  		}
   372  		return
   373  	}
   374  
   375  	// create channel only when concurrent partitioning & sorting
   376  	sv := syncVar{1, // number of goroutines including this
   377  		make(chan int)} // end signal
   378  	lo, hi := 0, n
   379  	for {
   380  		// median-of-9 concurrent dual partitioning with done
   381  		l := cdualpar(lsw, lo, hi, sv.done)
   382  		h := l - 1
   383  		no, n := h-lo, hi-l
   384  
   385  		if no < n {
   386  			n, no = no, n // [lo,hi] is the longer range
   387  			l, lo = lo, l
   388  		} else {
   389  			h, hi = hi, h
   390  		}
   391  
   392  		// handle shorter range
   393  		if n >= Mlr {
   394  			if atomic.AddUint32(&sv.ngr, 1) == 0 { // increase goroutine counter
   395  				panic("sorty: Sort: counter overflow")
   396  			}
   397  			go glong(lsw, l, h, &sv)
   398  
   399  		} else if n >= Hmli {
   400  			short(lsw, l, h)
   401  		} else {
   402  			insertion(lsw, l, h)
   403  		}
   404  
   405  		// longer range big enough? max goroutines?
   406  		if no <= 2*Mlr || sv.ngr >= Mxg {
   407  			break
   408  		}
   409  		// dual partition longer range
   410  	}
   411  
   412  	long(lsw, lo, hi, &sv) // we know hi-lo >= Mlr
   413  
   414  	if atomic.AddUint32(&sv.ngr, ^uint32(0)) != 0 { // decrease goroutine counter
   415  		<-sv.done // we are not the last, wait
   416  	}
   417  }