github.com/pawelgaczynski/gain@v0.4.0-alpha.0.20230821120126-41f1e60a18da/pkg/queue/queue_ms.go (about)

     1  // Copyright (c) 2023 Paweł Gaczyński
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package queue
    16  
    17  import (
    18  	"sync/atomic"
    19  	"unsafe"
    20  )
    21  
    22  type msQueue[T any] struct {
    23  	head      unsafe.Pointer
    24  	tail      unsafe.Pointer
    25  	queueSize int32
    26  }
    27  
    28  type node[T any] struct {
    29  	value *T
    30  	next  unsafe.Pointer
    31  }
    32  
    33  func NewQueue[T any]() LockFreeQueue[T] {
    34  	n := unsafe.Pointer(&node[T]{})
    35  
    36  	return &msQueue[T]{head: n, tail: n}
    37  }
    38  
    39  func NewIntQueue() LockFreeQueue[int] {
    40  	n := unsafe.Pointer(&node[int]{})
    41  
    42  	return &msQueue[int]{head: n, tail: n}
    43  }
    44  
    45  func (q *msQueue[T]) Enqueue(value T) {
    46  	node := &node[T]{value: &value}
    47  
    48  	for {
    49  		var (
    50  			tail = load[T](&q.tail)
    51  			next = load[T](&tail.next)
    52  		)
    53  
    54  		if tail == load[T](&q.tail) {
    55  			if next == nil {
    56  				if cas(&tail.next, next, node) {
    57  					cas(&q.tail, tail, node)
    58  					atomic.AddInt32(&q.queueSize, 1)
    59  
    60  					return
    61  				}
    62  			} else {
    63  				cas(&q.tail, tail, next)
    64  			}
    65  		}
    66  	}
    67  }
    68  
    69  func (q *msQueue[T]) Dequeue() T {
    70  	for {
    71  		var (
    72  			head = load[T](&q.head)
    73  			tail = load[T](&q.tail)
    74  			next = load[T](&head.next)
    75  		)
    76  
    77  		if head == load[T](&q.head) {
    78  			if head == tail {
    79  				if next == nil {
    80  					return getZero[T]()
    81  				}
    82  
    83  				cas(&q.tail, tail, next)
    84  			} else {
    85  				value := *next.value
    86  				if cas(&q.head, head, next) {
    87  					atomic.AddInt32(&q.queueSize, -1)
    88  
    89  					return value
    90  				}
    91  			}
    92  		}
    93  	}
    94  }
    95  
    96  func (q *msQueue[T]) IsEmpty() bool {
    97  	return atomic.LoadInt32(&q.queueSize) == 0
    98  }
    99  
   100  func (q *msQueue[T]) Size() int32 {
   101  	return atomic.LoadInt32(&q.queueSize)
   102  }
   103  
   104  func load[T any](p *unsafe.Pointer) *node[T] {
   105  	return (*node[T])(atomic.LoadPointer(p))
   106  }
   107  
   108  func cas[T any](p *unsafe.Pointer, oldNode, newNode *node[T]) bool {
   109  	return atomic.CompareAndSwapPointer(p, unsafe.Pointer(oldNode), unsafe.Pointer(newNode))
   110  }
   111  
   112  func getZero[T any]() T {
   113  	var result T
   114  
   115  	return result
   116  }