github.com/jxskiss/gopkg/v2@v2.14.9-0.20240514120614-899f3e7952b4/collection/heapx/heap.go (about)

     1  package heapx
     2  
     3  import "unsafe"
     4  
     5  // LessFunc is a comparator function to build a Heap.
     6  type LessFunc[T any] func(lhs, rhs T) bool
     7  
     8  // Heap implements the classic heap data-structure.
     9  // A Heap is not safe for concurrent operations.
    10  type Heap[T any] struct {
    11  	items heapItems[T]
    12  }
    13  
    14  // NewHeap creates a new Heap.
    15  func NewHeap[T any](cmp LessFunc[T]) *Heap[T] {
    16  	h := &Heap[T]{}
    17  	h.init(cmp)
    18  	return h
    19  }
    20  
    21  func (h *Heap[T]) init(lessFunc LessFunc[T]) {
    22  	h.items.elemSz = unsafe.Sizeof(*new(T))
    23  	h.items.lessFunc = lessFunc
    24  }
    25  
    26  // Len returns the size of the heap.
    27  func (h *Heap[T]) Len() int {
    28  	return h.items.Len()
    29  }
    30  
    31  // Push pushes the element x onto the heap.
    32  // The complexity is O(log n) where n = h.Len().
    33  func (h *Heap[T]) Push(x T) {
    34  	/*
    35  		heap.Push(&h.items, x)
    36  	*/
    37  	h.items.Push(x)
    38  	h.items.up(h.Len() - 1)
    39  }
    40  
    41  // Peek returns the minium element (according to the LessFunc) in the heap,
    42  // it does not remove the item from the heap.
    43  // The complexity is O(1).
    44  func (h *Heap[T]) Peek() (x T, ok bool) {
    45  	if h.items.Len() == 0 {
    46  		return
    47  	}
    48  	return h.items.s0[0], true
    49  }
    50  
    51  // Pop removes and returns the minimum element (according to the LessFunc) from the heap.
    52  // The complexity is O(log n) where n = h.Len().
    53  // Pop is equivalent to Remove(h, 0).
    54  func (h *Heap[T]) Pop() (x T, ok bool) {
    55  	if h.items.Len() == 0 {
    56  		return
    57  	}
    58  
    59  	/*
    60  		return heap.Pop(&h.items).(T), true
    61  	*/
    62  	n := h.Len() - 1
    63  	h.items.Swap(0, n)
    64  	h.items.down(0, n)
    65  	return h.items.Pop().(T), true
    66  }
    67  
    68  const (
    69  	bktShift = 11
    70  	bktSize  = 1 << bktShift
    71  	bktMask  = bktSize - 1
    72  	initSize = 64
    73  	ptrSize  = unsafe.Sizeof(unsafe.Pointer(nil))
    74  
    75  	shrinkThreshold = bktSize / 2
    76  )
    77  
    78  type heapItems[T any] struct {
    79  	elemSz   uintptr
    80  	lessFunc LessFunc[T]
    81  
    82  	cap   int
    83  	len   int
    84  	s0    []T
    85  	ss    []unsafe.Pointer
    86  	ssPtr unsafe.Pointer
    87  }
    88  
    89  // index uses unsafe trick to eliminate slice bounds checking.
    90  func (p *heapItems[T]) index(i int) *T {
    91  	i, j := i>>bktShift, i&bktMask
    92  	sPtr := unsafe.Pointer(uintptr(p.ssPtr) + uintptr(i)*ptrSize)
    93  	return (*T)(unsafe.Pointer(uintptr(*(*unsafe.Pointer)(sPtr)) + uintptr(j)*p.elemSz))
    94  }
    95  
    96  func (p *heapItems[T]) addBucket(bkt []T) {
    97  	p.ss = append(p.ss, unsafe.Pointer(&bkt[0]))
    98  	p.ssPtr = unsafe.Pointer(&p.ss[0])
    99  }
   100  
   101  func (p *heapItems[T]) Len() int {
   102  	return p.len
   103  }
   104  
   105  func (p *heapItems[T]) Less(i, j int) bool {
   106  	return p.lessFunc(*p.index(i), *p.index(j))
   107  }
   108  
   109  func (p *heapItems[T]) Swap(i, j int) {
   110  	p1, p2 := p.index(i), p.index(j)
   111  	*p1, *p2 = *p2, *p1
   112  	return
   113  }
   114  
   115  func (p *heapItems[T]) Push(x any) {
   116  	if p.cap < p.len+1 {
   117  		if p.len == 0 {
   118  			p.s0 = make([]T, initSize)
   119  			p.addBucket(p.s0)
   120  			p.cap = initSize
   121  		} else if p.cap < bktSize {
   122  			newBkt := make([]T, p.cap*2)
   123  			copy(newBkt, p.s0[:p.len])
   124  			p.s0 = newBkt
   125  			p.ss[0] = unsafe.Pointer(&newBkt[0])
   126  			p.cap *= 2
   127  		} else {
   128  			newBkt := make([]T, bktSize)
   129  			p.addBucket(newBkt)
   130  			p.cap += bktSize
   131  		}
   132  	}
   133  	*p.index(p.len) = x.(T)
   134  	p.len++
   135  }
   136  
   137  func (p *heapItems[T]) Pop() any {
   138  	var ret, zero T
   139  	if p.len > 0 {
   140  		ptr := p.index(p.len - 1)
   141  		ret, *ptr = *ptr, zero
   142  		p.len--
   143  	}
   144  	// Shrink buckets and free the underlying memory.
   145  	if (p.len+shrinkThreshold)&bktMask == 0 && (p.cap-p.len) > bktSize {
   146  		p.ss[len(p.ss)-1] = nil
   147  		p.ss = p.ss[:len(p.ss)-1]
   148  		p.cap -= bktSize
   149  	}
   150  	return ret
   151  }
   152  
   153  func (p *heapItems[T]) up(j int) {
   154  	for {
   155  		i := (j - 1) / 2 // parent
   156  		if i == j || !p.lessFunc(*p.index(j), *p.index(i)) {
   157  			break
   158  		}
   159  		p1, p2 := p.index(i), p.index(j)
   160  		*p1, *p2 = *p2, *p1 // swap(i, j)
   161  		j = i
   162  	}
   163  }
   164  
   165  func (p *heapItems[T]) down(i0, n int) bool {
   166  	i := i0
   167  	for {
   168  		j1 := 2*i + 1
   169  		if j1 >= n || j1 < 0 { // j1 < 0 after int overflow
   170  			break
   171  		}
   172  		j := j1 // left child
   173  		if j2 := j1 + 1; j2 < n && p.lessFunc(*p.index(j2), *p.index(j1)) {
   174  			j = j2 // = 2*i + 2 // right child
   175  		}
   176  		if !p.lessFunc(*p.index(j), *p.index(i)) {
   177  			break
   178  		}
   179  		p1, p2 := p.index(i), p.index(j)
   180  		*p1, *p2 = *p2, *p1 // swap(i, j)
   181  		i = j
   182  	}
   183  	return i > i0
   184  }