github.com/songzhibin97/go-baseutils@v0.0.2-0.20240302024150-487d8ce9c082/structure/sets/skipset/skipset.go (about)

     1  // Package skipset is a high-performance, scalable, concurrent-safe set based on a skip list.
     2  // In the typical pattern (100000 operations, 90% Contains, 9% Add, 1% Remove, 8C16T), the skipset
     3  // is up to 15x faster than the built-in sync.Map.
     4  package skipset
     5  
     6  import (
     7  	"fmt"
     8  	"strings"
     9  	"sync"
    10  	"sync/atomic"
    11  	"unsafe"
    12  
    13  	"github.com/songzhibin97/go-baseutils/base/bcomparator"
    14  	"github.com/songzhibin97/go-baseutils/structure/sets"
    15  )
    16  
    17  // Assert Set implementation
    18  var _ sets.Set[any] = (*Set[any])(nil)
    19  
    20  type Set[E any] struct {
    21  	header       *node[E]
    22  	length       int64
    23  	highestLevel int64 // highest level for now
    24  	comparator   bcomparator.Comparator[E]
    25  }
    26  
// node is a single element of the skip list together with its per-layer
// successor links.
type node[E any] struct {
	value E             // element stored in this node
	next  optionalArray // [level]*node[E]
	mu    sync.Mutex    // held by AddB/RemoveB while this node's links or its marked flag are changed
	flags bitflag       // state bits; this file uses fullyLinked and marked
	level uint32        // number of layers requested for this node when it was created
}
    34  
    35  func newNode[E any](e E, level int) *node[E] {
    36  	node := &node[E]{
    37  		value: e,
    38  		level: uint32(level),
    39  	}
    40  	if level > op1 {
    41  		node.next.extra = new([op2]unsafe.Pointer)
    42  	}
    43  	return node
    44  }
    45  
    46  func (n *node[E]) loadNext(i int) *node[E] {
    47  	return (*node[E])(n.next.load(i))
    48  }
    49  
    50  func (n *node[E]) storeNext(i int, node *node[E]) {
    51  	n.next.store(i, unsafe.Pointer(node))
    52  }
    53  
    54  func (n *node[E]) atomicLoadNext(i int) *node[E] {
    55  	return (*node[E])(n.next.atomicLoad(i))
    56  }
    57  
    58  func (n *node[E]) atomicStoreNext(i int, node *node[E]) {
    59  	n.next.atomicStore(i, unsafe.Pointer(node))
    60  }
    61  
    62  func (n *node[E]) lessthan(value E, comparable bcomparator.Comparator[E]) bool {
    63  	return comparable(n.value, value) < 0
    64  }
    65  
    66  func (n *node[E]) equal(value E, comparable bcomparator.Comparator[E]) bool {
    67  	return comparable(n.value, value) == 0
    68  }
    69  
    70  func New[E any](comparator bcomparator.Comparator[E]) *Set[E] {
    71  	var zero E
    72  	h := newNode[E](zero, maxLevel)
    73  	h.flags.SetTrue(fullyLinked)
    74  	return &Set[E]{
    75  		header:       h,
    76  		highestLevel: defaultHighestLevel,
    77  		comparator:   comparator,
    78  	}
    79  }
    80  
    81  // findNodeRemove takes a value and two maximal-height arrays then searches exactly as in a sequential skip-list.
    82  // The returned preds and succs always satisfy preds[i] > value >= succs[i].
    83  func (s *Set[E]) findNodeRemove(value E, preds *[maxLevel]*node[E], succs *[maxLevel]*node[E]) int {
    84  	// lFound represents the index of the first layer at which it found a node.
    85  	lFound, x := -1, s.header
    86  	for i := int(atomic.LoadInt64(&s.highestLevel)) - 1; i >= 0; i-- {
    87  		succ := x.atomicLoadNext(i)
    88  		for succ != nil && succ.lessthan(value, s.comparator) {
    89  			x = succ
    90  			succ = x.atomicLoadNext(i)
    91  		}
    92  		preds[i] = x
    93  		succs[i] = succ
    94  
    95  		// Check if the value already in the skip list.
    96  		if lFound == -1 && succ != nil && succ.equal(value, s.comparator) {
    97  			lFound = i
    98  		}
    99  	}
   100  	return lFound
   101  }
   102  
   103  // findNodeAdd takes a value and two maximal-height arrays then searches exactly as in a sequential skip-set.
   104  // The returned preds and succs always satisfy preds[i] > value >= succs[i].
   105  func (s *Set[E]) findNodeAdd(value E, preds *[maxLevel]*node[E], succs *[maxLevel]*node[E]) int {
   106  	x := s.header
   107  	for i := int(atomic.LoadInt64(&s.highestLevel)) - 1; i >= 0; i-- {
   108  		succ := x.atomicLoadNext(i)
   109  		for succ != nil && succ.lessthan(value, s.comparator) {
   110  			x = succ
   111  			succ = x.atomicLoadNext(i)
   112  		}
   113  		preds[i] = x
   114  		succs[i] = succ
   115  
   116  		// Check if the value already in the skip list.
   117  		if succ != nil && succ.equal(value, s.comparator) {
   118  			return i
   119  		}
   120  	}
   121  	return -1
   122  }
   123  
   124  func unlockInt64[E any](preds [maxLevel]*node[E], highestLevel int) {
   125  	var prevPred *node[E]
   126  	for i := highestLevel; i >= 0; i-- {
   127  		if preds[i] != prevPred { // the node could be unlocked by previous loop
   128  			preds[i].mu.Unlock()
   129  			prevPred = preds[i]
   130  		}
   131  	}
   132  }
   133  
   134  // AddB add the value into skip set, return true if this process insert the value into skip set,
   135  // return false if this process can't insert this value, because another process has insert the same value.
   136  //
   137  // If the value is in the skip set but not fully linked, this process will wait until it is.
   138  func (s *Set[E]) AddB(value E) bool {
   139  	level := s.randomLevel()
   140  	var preds, succs [maxLevel]*node[E]
   141  	for {
   142  		lFound := s.findNodeAdd(value, &preds, &succs)
   143  		if lFound != -1 { // indicating the value is already in the skip-list
   144  			nodeFound := succs[lFound]
   145  			if !nodeFound.flags.Get(marked) {
   146  				for !nodeFound.flags.Get(fullyLinked) {
   147  					// The node is not yet fully linked, just waits until it is.
   148  				}
   149  				return false
   150  			}
   151  			// If the node is marked, represents some other thread is in the process of deleting this node,
   152  			// we need to add this node in next loop.
   153  			continue
   154  		}
   155  		// Add this node into skip list.
   156  		var (
   157  			highestLocked        = -1 // the highest level being locked by this process
   158  			valid                = true
   159  			pred, succ, prevPred *node[E]
   160  		)
   161  		for layer := 0; valid && layer < level; layer++ {
   162  			pred = preds[layer]   // target node's previous node
   163  			succ = succs[layer]   // target node's next node
   164  			if pred != prevPred { // the node in this layer could be locked by previous loop
   165  				pred.mu.Lock()
   166  				highestLocked = layer
   167  				prevPred = pred
   168  			}
   169  			// valid check if there is another node has inserted into the skip list in this layer during this process.
   170  			// It is valid if:
   171  			// 1. The previous node and next node both are not marked.
   172  			// 2. The previous node's next node is succ in this layer.
   173  			valid = !pred.flags.Get(marked) && (succ == nil || !succ.flags.Get(marked)) && pred.loadNext(layer) == succ
   174  		}
   175  		if !valid {
   176  			unlockInt64(preds, highestLocked)
   177  			continue
   178  		}
   179  
   180  		nn := newNode[E](value, level)
   181  		for layer := 0; layer < level; layer++ {
   182  			nn.storeNext(layer, succs[layer])
   183  			preds[layer].atomicStoreNext(layer, nn)
   184  		}
   185  		nn.flags.SetTrue(fullyLinked)
   186  		unlockInt64(preds, highestLocked)
   187  		atomic.AddInt64(&s.length, 1)
   188  		return true
   189  	}
   190  }
   191  
   192  func (s *Set[E]) Add(values ...E) {
   193  	for _, value := range values {
   194  		s.AddB(value)
   195  	}
   196  }
   197  
   198  func (s *Set[E]) randomLevel() int {
   199  	// Generate random level.
   200  	level := randomLevel()
   201  	// Update highest level if possible.
   202  	for {
   203  		hl := atomic.LoadInt64(&s.highestLevel)
   204  		if int64(level) <= hl {
   205  			break
   206  		}
   207  		if atomic.CompareAndSwapInt64(&s.highestLevel, hl, int64(level)) {
   208  			break
   209  		}
   210  	}
   211  	return level
   212  }
   213  
   214  // ContainsB check if the value is in the skip set.
   215  func (s *Set[E]) ContainsB(value E) bool {
   216  	x := s.header
   217  	for i := int(atomic.LoadInt64(&s.highestLevel)) - 1; i >= 0; i-- {
   218  		nex := x.atomicLoadNext(i)
   219  		for nex != nil && nex.lessthan(value, s.comparator) {
   220  			x = nex
   221  			nex = x.atomicLoadNext(i)
   222  		}
   223  
   224  		// Check if the value already in the skip list.
   225  		if nex != nil && nex.equal(value, s.comparator) {
   226  			return nex.flags.MGet(fullyLinked|marked, fullyLinked)
   227  		}
   228  	}
   229  	return false
   230  }
   231  
   232  func (s *Set[E]) Contains(values ...E) bool {
   233  	for _, value := range values {
   234  		if !s.ContainsB(value) {
   235  			return false
   236  		}
   237  	}
   238  	return true
   239  }
   240  
   241  // RemoveB a node from the skip set.
   242  func (s *Set[E]) RemoveB(value E) bool {
   243  	var (
   244  		nodeToRemove *node[E]
   245  		isMarked     bool // represents if this operation mark the node
   246  		topLayer     = -1
   247  		preds, succs [maxLevel]*node[E]
   248  	)
   249  	for {
   250  		lFound := s.findNodeRemove(value, &preds, &succs)
   251  		if isMarked || // this process mark this node or we can find this node in the skip list
   252  			lFound != -1 && succs[lFound].flags.MGet(fullyLinked|marked, fullyLinked) && (int(succs[lFound].level)-1) == lFound {
   253  			if !isMarked { // we don't mark this node for now
   254  				nodeToRemove = succs[lFound]
   255  				topLayer = lFound
   256  				nodeToRemove.mu.Lock()
   257  				if nodeToRemove.flags.Get(marked) {
   258  					// The node is marked by another process,
   259  					// the physical deletion will be accomplished by another process.
   260  					nodeToRemove.mu.Unlock()
   261  					return false
   262  				}
   263  				nodeToRemove.flags.SetTrue(marked)
   264  				isMarked = true
   265  			}
   266  			// Accomplish the physical deletion.
   267  			var (
   268  				highestLocked        = -1 // the highest level being locked by this process
   269  				valid                = true
   270  				pred, succ, prevPred *node[E]
   271  			)
   272  			for layer := 0; valid && (layer <= topLayer); layer++ {
   273  				pred, succ = preds[layer], succs[layer]
   274  				if pred != prevPred { // the node in this layer could be locked by previous loop
   275  					pred.mu.Lock()
   276  					highestLocked = layer
   277  					prevPred = pred
   278  				}
   279  				// valid check if there is another node has inserted into the skip list in this layer
   280  				// during this process, or the previous is removed by another process.
   281  				// It is valid if:
   282  				// 1. the previous node exists.
   283  				// 2. no another node has inserted into the skip list in this layer.
   284  				valid = !pred.flags.Get(marked) && pred.loadNext(layer) == succ
   285  			}
   286  			if !valid {
   287  				unlockInt64(preds, highestLocked)
   288  				continue
   289  			}
   290  			for i := topLayer; i >= 0; i-- {
   291  				// Now we own the `nodeToRemove`, no other goroutine will modify it.
   292  				// So we don't need `nodeToRemove.loadNext`
   293  				preds[i].atomicStoreNext(i, nodeToRemove.loadNext(i))
   294  			}
   295  			nodeToRemove.mu.Unlock()
   296  			unlockInt64(preds, highestLocked)
   297  			atomic.AddInt64(&s.length, -1)
   298  			return true
   299  		}
   300  		return false
   301  	}
   302  }
   303  
   304  func (s *Set[E]) Remove(values ...E) {
   305  	for _, value := range values {
   306  		s.RemoveB(value)
   307  	}
   308  }
   309  
   310  // Range calls f sequentially for each value present in the skip set.
   311  // If f returns false, range stops the iteration.
   312  func (s *Set[E]) Range(f func(value E) bool) {
   313  	x := s.header.atomicLoadNext(0)
   314  	for x != nil {
   315  		if !x.flags.MGet(fullyLinked|marked, fullyLinked) {
   316  			x = x.atomicLoadNext(0)
   317  			continue
   318  		}
   319  		if !f(x.value) {
   320  			break
   321  		}
   322  		x = x.atomicLoadNext(0)
   323  	}
   324  }
   325  
   326  // Len return the length of this skip set.
   327  func (s *Set[E]) Len() int {
   328  	return int(atomic.LoadInt64(&s.length))
   329  }
   330  
   331  func (s *Set[E]) Empty() bool {
   332  	return s.Len() == 0
   333  }
   334  
   335  func (s *Set[E]) Size() int {
   336  	return s.Len()
   337  }
   338  
   339  func (s *Set[E]) Clear() {
   340  	var zero E
   341  	h := newNode[E](zero, maxLevel)
   342  	h.flags.SetTrue(fullyLinked)
   343  	s.header = h
   344  	s.highestLevel = defaultHighestLevel
   345  }
   346  
   347  func (s *Set[E]) Values() []E {
   348  	ln := s.Len()
   349  	vals := make([]E, 0, ln)
   350  	s.Range(func(e E) bool {
   351  		vals = append(vals, e)
   352  		return true
   353  	})
   354  	return vals
   355  }
   356  
   357  func (s *Set[E]) String() string {
   358  	b := strings.Builder{}
   359  	b.WriteString("SkipSet\n")
   360  	for val := range s.Values() {
   361  		b.WriteString(fmt.Sprintf("(key:%v) ", val))
   362  	}
   363  	return b.String()
   364  }