github.com/jfcg/sorty@v1.2.0/sortyLsw.go (about) 1 /* Copyright (c) 2019, Serhat Şevki Dinçer. 2 This Source Code Form is subject to the terms of the Mozilla Public 3 License, v. 2.0. If a copy of the MPL was not distributed with this 4 file, You can obtain one at http://mozilla.org/MPL/2.0/. 5 */ 6 7 package sorty 8 9 import "sync/atomic" 10 11 // Lesswap function operates on an underlying collection to be sorted as: 12 // if less(i, k) { // strict ordering like < or > 13 // if r != s { 14 // swap(r, s) 15 // } 16 // return true 17 // } 18 // return false 19 type Lesswap func(i, k, r, s int) bool 20 21 // IsSorted returns 0 if underlying collection of length n is sorted, 22 // otherwise it returns i > 0 with less(i,i-1) = true. 23 func IsSorted(n int, lsw Lesswap) int { 24 for i := n - 1; i > 0; i-- { 25 if lsw(i, i-1, i, i) { // 3rd=4th disables swap 26 return i 27 } 28 } 29 return 0 30 } 31 32 // insertion sort ar[lo..hi], assumes lo < hi 33 func insertion(lsw Lesswap, lo, hi int) { 34 35 for l, h := mid(lo, hi-1)-1, hi; l >= lo; { 36 lsw(h, l, h, l) 37 l-- 38 h-- 39 } 40 for h := lo; ; { 41 l := h 42 h++ 43 k := h 44 for lsw(k, l, k, l) { 45 k = l 46 l-- 47 if l < lo { 48 break 49 } 50 } 51 if h >= hi { 52 break 53 } 54 } 55 } 56 57 // pivot divides ar[lo..hi] into 2n+1 equal intervals, sorts mid-points of them 58 // to find median-of-2n+1 pivot. ensures lo/hi ranges have at least n elements by 59 // moving 2n of mid-points to n positions at lo/hi ends. 60 // assumes n > 0, lo+4n+1 < hi. returns start,pivot,end for partitioning. 61 func pivot(lsw Lesswap, lo, hi, n int) (int, int, int) { 62 m := mid(lo, hi) 63 s := (hi - lo + 1) / (2*n + 1) // step > 1 64 l, h := m-n*s, m+n*s 65 66 for q, k := h, m-2*s; k >= l; { // insertion sort ar[m+i*s], i=-n..n 67 lsw(q, k, q, k) 68 q -= s 69 k -= s 70 } 71 for r := l; ; { 72 k := r 73 r += s 74 q := r 75 for lsw(q, k, q, k) { 76 q = k 77 k -= s 78 if k < l { 79 break 80 } 81 } 82 if r >= h { 83 break 84 } 85 } 86 87 // move hi mid-points to hi end 88 for { 89 if h == hi || lsw(hi, h, hi, h) { 90 h -= s 91 } 92 hi-- 93 if h <= m { 94 break 95 } 96 } 97 98 // move lo mid-points to lo end 99 for { 100 if l == lo || lsw(l, lo, l, lo) { 101 l += s 102 } 103 lo++ 104 if l >= m { 105 break 106 } 107 } 108 return lo, m, hi // lo <= m-s+1, m+s-1 <= hi 109 } 110 111 // partition ar[l..h] into <= and >= pivot, assumes l < h 112 // returns m with ar[:m] <= pivot, ar[m:] >= pivot 113 func partition1(lsw Lesswap, l, pv, h int) int { 114 // avoid unnecessary comparisons, extend ranges in balance 115 for { 116 if lsw(h, pv, h, h) { // 3rd=4th disables swap 117 for { 118 if lsw(pv, l, h, l) { 119 break 120 } 121 l++ 122 if l >= h { 123 return l + 1 124 } 125 } 126 } else if lsw(pv, l, l, l) { // 3rd=4th disables swap 127 for { 128 h-- 129 if l >= h { 130 return l 131 } 132 if lsw(h, pv, h, l) { 133 break 134 } 135 } 136 } 137 l++ 138 h-- 139 if l >= h { 140 break 141 } 142 } 143 // classify mid element 144 if l == h && (h == pv || lsw(h, pv, h, h)) { // 3rd=4th disables swap 145 l++ 146 } 147 return l 148 } 149 150 // rearrange ar[l..a] and ar[b..h] into <= and >= pivot, assumes l <= a < pv < b <= h 151 // gap (a..b) expands until one of the intervals is fully consumed 152 func partition2(lsw Lesswap, l, a, pv, b, h int) (int, int) { 153 // avoid unnecessary comparisons, extend ranges in balance 154 for { 155 if lsw(b, pv, b, b) { // 3rd=4th disables swap 156 for { 157 if lsw(pv, a, b, a) { 158 break 159 } 160 a-- 161 if a < l { 162 return a, b 163 } 164 } 165 } else if lsw(pv, a, a, a) { // 3rd=4th disables swap 166 for { 167 b++ 168 if b > h { 169 return a, b 170 } 171 if lsw(b, pv, b, a) { 172 break 173 } 174 } 175 } 176 a-- 177 b++ 178 if a < l || b > h { 179 return a, b 180 } 181 } 182 } 183 184 // new-goroutine partition 185 func gpart1(lsw Lesswap, l, pv, h int, ch chan int) { 186 ch <- partition1(lsw, l, pv, h) 187 } 188 189 // concurrent dual partitioning 190 // returns m with ar[:m] <= pivot, ar[m:] >= pivot 191 func cdualpar(lsw Lesswap, lo, hi int, ch chan int) int { 192 193 lo, pv, hi := pivot(lsw, lo, hi, 4) // median-of-9 194 195 if hi-lo <= 2*Mlr { // guard against short remaining range 196 return partition1(lsw, lo, pv, hi) 197 } 198 199 m := mid(lo, hi) // in pivot() lo/hi changed by possibly unequal amounts 200 a, b := mid(lo, m), mid(m, hi) 201 202 go gpart1(lsw, a+1, pv, b-1, ch) // mid half range 203 204 a, b = partition2(lsw, lo, a, pv, b, hi) // left/right quarter ranges 205 m = <-ch 206 207 // only one gap is possible 208 for ; lo <= a; a-- { // gap left in low range? 209 if lsw(pv, a, m-1, a) { 210 m-- 211 if m == pv { // swapped pivot when closing gap? 212 pv = a // Thanks to my wife Tansu who discovered this 213 } 214 } 215 } 216 for ; b <= hi; b++ { // gap left in high range? 217 if lsw(b, pv, b, m) { 218 if m == pv { // swapped pivot when closing gap? 219 pv = b // It took days of agony to discover these two if's :D 220 } 221 m++ 222 } 223 } 224 return m 225 } 226 227 // short range sort function, assumes Hmli <= hi-lo < Mlr 228 func short(lsw Lesswap, lo, hi int) { 229 start: 230 l, pv, h := pivot(lsw, lo, hi, 2) 231 l = partition1(lsw, l, pv, h) // median-of-5 partitioning 232 h = l - 1 233 no, n := h-lo, hi-l 234 235 if no < n { 236 n, no = no, n // [lo,hi] is the longer range 237 l, lo = lo, l 238 } else { 239 h, hi = hi, h 240 } 241 242 if n >= Hmli { 243 short(lsw, l, h) // recurse on the shorter range 244 goto start 245 } 246 insertion(lsw, l, h) // at least one insertion range 247 248 if no >= Hmli { 249 goto start 250 } 251 insertion(lsw, lo, hi) // two insertion ranges 252 } 253 254 // long range sort function (single goroutine), assumes hi-lo >= Mlr 255 func slong(lsw Lesswap, lo, hi int) { 256 start: 257 l, pv, h := pivot(lsw, lo, hi, 3) 258 l = partition1(lsw, l, pv, h) // median-of-7 partitioning 259 h = l - 1 260 no, n := h-lo, hi-l 261 262 if no < n { 263 n, no = no, n // [lo,hi] is the longer range 264 l, lo = lo, l 265 } else { 266 h, hi = hi, h 267 } 268 269 if n >= Mlr { // at least one not-long range? 270 slong(lsw, l, h) // recurse on the shorter range 271 goto start 272 } 273 274 if n >= Hmli { 275 short(lsw, l, h) 276 } else { 277 insertion(lsw, l, h) 278 } 279 280 if no >= Mlr { // two not-long ranges? 281 goto start 282 } 283 short(lsw, lo, hi) // we know no >= Hmli 284 } 285 286 // new-goroutine sort function 287 func glong(lsw Lesswap, lo, hi int, sv *syncVar) { 288 long(lsw, lo, hi, sv) 289 290 if atomic.AddUint32(&sv.ngr, ^uint32(0)) == 0 { // decrease goroutine counter 291 sv.done <- 0 // we are the last, all done 292 } 293 } 294 295 // long range sort function, assumes hi-lo >= Mlr 296 func long(lsw Lesswap, lo, hi int, sv *syncVar) { 297 start: 298 l, pv, h := pivot(lsw, lo, hi, 3) 299 l = partition1(lsw, l, pv, h) // median-of-7 partitioning 300 h = l - 1 301 no, n := h-lo, hi-l 302 303 if no < n { 304 n, no = no, n // [lo,hi] is the longer range 305 l, lo = lo, l 306 } else { 307 h, hi = hi, h 308 } 309 310 // branches below are optimal for fewer total jumps 311 if n < Mlr { // at least one not-long range? 312 313 if n >= Hmli { 314 short(lsw, l, h) 315 } else { 316 insertion(lsw, l, h) 317 } 318 319 if no >= Mlr { // two not-long ranges? 320 goto start 321 } 322 short(lsw, lo, hi) // we know no >= Hmli 323 return 324 } 325 326 // max goroutines? not atomic but good enough 327 if sv.ngr >= Mxg { 328 long(lsw, l, h, sv) // recurse on the shorter range 329 goto start 330 } 331 332 if atomic.AddUint32(&sv.ngr, 1) == 0 { // increase goroutine counter 333 panic("sorty: long: counter overflow") 334 } 335 // new-goroutine sort on the longer range only when 336 // both ranges are big and max goroutines is not exceeded 337 go glong(lsw, lo, hi, sv) 338 lo, hi = l, h 339 goto start 340 } 341 342 // Sort concurrently sorts underlying collection of length n via lsw(). 343 // Once for each non-trivial type you want to sort in a certain way, you 344 // can implement a custom sorting routine (for a slice for example) as: 345 // func SortObjAsc(c []Obj) { 346 // lsw := func(i, k, r, s int) bool { 347 // if c[i].Key < c[k].Key { // strict comparator like < or > 348 // if r != s { 349 // c[r], c[s] = c[s], c[r] 350 // } 351 // return true 352 // } 353 // return false 354 // } 355 // sorty.Sort(len(c), lsw) 356 // } 357 // Lesswap is a 'contract' between users and sorty library. Strict 358 // comparator, r!=s check, swap and returns are all strictly necessary. 359 func Sort(n int, lsw Lesswap) { 360 361 n-- // high indice 362 if n <= 2*Mlr || Mxg <= 1 { 363 364 // single-goroutine sorting 365 if n >= Mlr { 366 slong(lsw, 0, n) 367 } else if n >= Hmli { 368 short(lsw, 0, n) 369 } else if n > 0 { 370 insertion(lsw, 0, n) 371 } 372 return 373 } 374 375 // create channel only when concurrent partitioning & sorting 376 sv := syncVar{1, // number of goroutines including this 377 make(chan int)} // end signal 378 lo, hi := 0, n 379 for { 380 // median-of-9 concurrent dual partitioning with done 381 l := cdualpar(lsw, lo, hi, sv.done) 382 h := l - 1 383 no, n := h-lo, hi-l 384 385 if no < n { 386 n, no = no, n // [lo,hi] is the longer range 387 l, lo = lo, l 388 } else { 389 h, hi = hi, h 390 } 391 392 // handle shorter range 393 if n >= Mlr { 394 if atomic.AddUint32(&sv.ngr, 1) == 0 { // increase goroutine counter 395 panic("sorty: Sort: counter overflow") 396 } 397 go glong(lsw, l, h, &sv) 398 399 } else if n >= Hmli { 400 short(lsw, l, h) 401 } else { 402 insertion(lsw, l, h) 403 } 404 405 // longer range big enough? max goroutines? 406 if no <= 2*Mlr || sv.ngr >= Mxg { 407 break 408 } 409 // dual partition longer range 410 } 411 412 long(lsw, lo, hi, &sv) // we know hi-lo >= Mlr 413 414 if atomic.AddUint32(&sv.ngr, ^uint32(0)) != 0 { // decrease goroutine counter 415 <-sv.done // we are not the last, wait 416 } 417 }