github.com/jxskiss/gopkg/v2@v2.14.9-0.20240514120614-899f3e7952b4/collection/heapx/heap.go (about) 1 package heapx 2 3 import "unsafe" 4 5 // LessFunc is a comparator function to build a Heap. 6 type LessFunc[T any] func(lhs, rhs T) bool 7 8 // Heap implements the classic heap data-structure. 9 // A Heap is not safe for concurrent operations. 10 type Heap[T any] struct { 11 items heapItems[T] 12 } 13 14 // NewHeap creates a new Heap. 15 func NewHeap[T any](cmp LessFunc[T]) *Heap[T] { 16 h := &Heap[T]{} 17 h.init(cmp) 18 return h 19 } 20 21 func (h *Heap[T]) init(lessFunc LessFunc[T]) { 22 h.items.elemSz = unsafe.Sizeof(*new(T)) 23 h.items.lessFunc = lessFunc 24 } 25 26 // Len returns the size of the heap. 27 func (h *Heap[T]) Len() int { 28 return h.items.Len() 29 } 30 31 // Push pushes the element x onto the heap. 32 // The complexity is O(log n) where n = h.Len(). 33 func (h *Heap[T]) Push(x T) { 34 /* 35 heap.Push(&h.items, x) 36 */ 37 h.items.Push(x) 38 h.items.up(h.Len() - 1) 39 } 40 41 // Peek returns the minium element (according to the LessFunc) in the heap, 42 // it does not remove the item from the heap. 43 // The complexity is O(1). 44 func (h *Heap[T]) Peek() (x T, ok bool) { 45 if h.items.Len() == 0 { 46 return 47 } 48 return h.items.s0[0], true 49 } 50 51 // Pop removes and returns the minimum element (according to the LessFunc) from the heap. 52 // The complexity is O(log n) where n = h.Len(). 53 // Pop is equivalent to Remove(h, 0). 54 func (h *Heap[T]) Pop() (x T, ok bool) { 55 if h.items.Len() == 0 { 56 return 57 } 58 59 /* 60 return heap.Pop(&h.items).(T), true 61 */ 62 n := h.Len() - 1 63 h.items.Swap(0, n) 64 h.items.down(0, n) 65 return h.items.Pop().(T), true 66 } 67 68 const ( 69 bktShift = 11 70 bktSize = 1 << bktShift 71 bktMask = bktSize - 1 72 initSize = 64 73 ptrSize = unsafe.Sizeof(unsafe.Pointer(nil)) 74 75 shrinkThreshold = bktSize / 2 76 ) 77 78 type heapItems[T any] struct { 79 elemSz uintptr 80 lessFunc LessFunc[T] 81 82 cap int 83 len int 84 s0 []T 85 ss []unsafe.Pointer 86 ssPtr unsafe.Pointer 87 } 88 89 // index uses unsafe trick to eliminate slice bounds checking. 90 func (p *heapItems[T]) index(i int) *T { 91 i, j := i>>bktShift, i&bktMask 92 sPtr := unsafe.Pointer(uintptr(p.ssPtr) + uintptr(i)*ptrSize) 93 return (*T)(unsafe.Pointer(uintptr(*(*unsafe.Pointer)(sPtr)) + uintptr(j)*p.elemSz)) 94 } 95 96 func (p *heapItems[T]) addBucket(bkt []T) { 97 p.ss = append(p.ss, unsafe.Pointer(&bkt[0])) 98 p.ssPtr = unsafe.Pointer(&p.ss[0]) 99 } 100 101 func (p *heapItems[T]) Len() int { 102 return p.len 103 } 104 105 func (p *heapItems[T]) Less(i, j int) bool { 106 return p.lessFunc(*p.index(i), *p.index(j)) 107 } 108 109 func (p *heapItems[T]) Swap(i, j int) { 110 p1, p2 := p.index(i), p.index(j) 111 *p1, *p2 = *p2, *p1 112 return 113 } 114 115 func (p *heapItems[T]) Push(x any) { 116 if p.cap < p.len+1 { 117 if p.len == 0 { 118 p.s0 = make([]T, initSize) 119 p.addBucket(p.s0) 120 p.cap = initSize 121 } else if p.cap < bktSize { 122 newBkt := make([]T, p.cap*2) 123 copy(newBkt, p.s0[:p.len]) 124 p.s0 = newBkt 125 p.ss[0] = unsafe.Pointer(&newBkt[0]) 126 p.cap *= 2 127 } else { 128 newBkt := make([]T, bktSize) 129 p.addBucket(newBkt) 130 p.cap += bktSize 131 } 132 } 133 *p.index(p.len) = x.(T) 134 p.len++ 135 } 136 137 func (p *heapItems[T]) Pop() any { 138 var ret, zero T 139 if p.len > 0 { 140 ptr := p.index(p.len - 1) 141 ret, *ptr = *ptr, zero 142 p.len-- 143 } 144 // Shrink buckets and free the underlying memory. 145 if (p.len+shrinkThreshold)&bktMask == 0 && (p.cap-p.len) > bktSize { 146 p.ss[len(p.ss)-1] = nil 147 p.ss = p.ss[:len(p.ss)-1] 148 p.cap -= bktSize 149 } 150 return ret 151 } 152 153 func (p *heapItems[T]) up(j int) { 154 for { 155 i := (j - 1) / 2 // parent 156 if i == j || !p.lessFunc(*p.index(j), *p.index(i)) { 157 break 158 } 159 p1, p2 := p.index(i), p.index(j) 160 *p1, *p2 = *p2, *p1 // swap(i, j) 161 j = i 162 } 163 } 164 165 func (p *heapItems[T]) down(i0, n int) bool { 166 i := i0 167 for { 168 j1 := 2*i + 1 169 if j1 >= n || j1 < 0 { // j1 < 0 after int overflow 170 break 171 } 172 j := j1 // left child 173 if j2 := j1 + 1; j2 < n && p.lessFunc(*p.index(j2), *p.index(j1)) { 174 j = j2 // = 2*i + 2 // right child 175 } 176 if !p.lessFunc(*p.index(j), *p.index(i)) { 177 break 178 } 179 p1, p2 := p.index(i), p.index(j) 180 *p1, *p2 = *p2, *p1 // swap(i, j) 181 i = j 182 } 183 return i > i0 184 }