github.com/songzhibin97/gkit@v1.2.13/internal/sys/queue/queue.go (about)

     1  package queue
     2  
     3  import (
     4  	"sync/atomic"
     5  	"unsafe"
     6  )
     7  
     8  // LKQueue lock-free的queue
     9  type LKQueue struct {
    10  	head unsafe.Pointer
    11  	tail unsafe.Pointer
    12  }
    13  
    14  // 通过链表实现,这个数据结构代表链表中的节点
    15  type node struct {
    16  	value interface{}
    17  	next  unsafe.Pointer
    18  }
    19  
    20  func NewLKQueue() *LKQueue {
    21  	n := unsafe.Pointer(&node{})
    22  	return &LKQueue{head: n, tail: n}
    23  }
    24  
    25  // Enqueue 入队
    26  func (q *LKQueue) Enqueue(v interface{}) {
    27  	n := &node{value: v}
    28  	for {
    29  		tail := load(&q.tail)
    30  		next := load(&tail.next)
    31  		if tail == load(&q.tail) { // 尾还是尾
    32  			if next == nil { // 还没有新数据入队
    33  				if cas(&tail.next, next, n) { // 增加到队尾
    34  					cas(&q.tail, tail, n) // 入队成功,移动尾巴指针
    35  					return
    36  				}
    37  			} else { // 已有新数据加到队列后面,需要移动尾指针
    38  				cas(&q.tail, tail, next)
    39  			}
    40  		}
    41  	}
    42  }
    43  
    44  // Dequeue 出队,没有元素则返回nil
    45  func (q *LKQueue) Dequeue() interface{} {
    46  	for {
    47  		head := load(&q.head)
    48  		tail := load(&q.tail)
    49  		next := load(&head.next)
    50  		if head == load(&q.head) { // head还是那个head
    51  			if head == tail { // head和tail一样
    52  				if next == nil { // 说明是空队列
    53  					return nil
    54  				}
    55  				// 只是尾指针还没有调整,尝试调整它指向下一个
    56  				cas(&q.tail, tail, next)
    57  			} else {
    58  				// 读取出队的数据
    59  				v := next.value
    60  				// 既然要出队了,头指针移动到下一个
    61  				if cas(&q.head, head, next) {
    62  					return v // Dequeue is done.  return
    63  				}
    64  			}
    65  		}
    66  	}
    67  }
    68  
    69  // 将unsafe.Pointer原子加载转换成node
    70  func load(p *unsafe.Pointer) (n *node) {
    71  	return (*node)(atomic.LoadPointer(p))
    72  }
    73  
    74  // 封装CAS,避免直接将*node转换成unsafe.Pointer
    75  func cas(p *unsafe.Pointer, old, new *node) (ok bool) {
    76  	return atomic.CompareAndSwapPointer(
    77  		p, unsafe.Pointer(old), unsafe.Pointer(new))
    78  }