github.com/blong14/gache@v0.0.0-20240124023949-89416fd8bbfa/internal/map/skiplist/map.go (about)

     1  package skiplist
     2  
     3  import (
     4  	"fmt"
     5  	"hash/maphash"
     6  	"strings"
     7  	"sync/atomic"
     8  	"unsafe"
     9  	_ "unsafe"
    10  )
    11  
// RandUint32 returns a pseudo-random uint32 from the Go runtime's internal
// fast generator, bound with go:linkname (this is why the file imports
// "unsafe"). It is not suitable for cryptographic use.
// NOTE(review): linknamed runtime symbols can change between Go releases;
// confirm the module's Go version still exposes runtime.fastrand.
//
//go:linkname RandUint32 runtime.fastrand
func RandUint32() uint32

// RandUint32n returns a pseudo-random value in [0, n) from the runtime's
// internal runtime.fastrandn, also bound with go:linkname.
// NOTE(review): same linkname portability caveat as RandUint32.
//
//go:linkname RandUint32n runtime.fastrandn
func RandUint32n(n uint32) uint32
    17  
    18  var seed = maphash.MakeSeed()
    19  
    20  func hash(key []byte) uint64 {
    21  	var hasher maphash.Hash
    22  	hasher.SetSeed(seed)
    23  	_, _ = hasher.Write(key)
    24  	return hasher.Sum64()
    25  }
    26  
    27  const maxHeight uint8 = 20
    28  
    29  func unlock(highestLocked int, preds []*node) {
    30  	if highestLocked < 0 || highestLocked >= len(preds) {
    31  		return
    32  	}
    33  	for i := 0; i <= highestLocked; i++ {
    34  		m := preds[i]
    35  		if m == nil {
    36  			continue
    37  		}
    38  		if m.lock != nil {
    39  			select {
    40  			case <-m.lock:
    41  			default:
    42  			}
    43  		}
    44  	}
    45  }
    46  
// node is one skip-list entry. Nodes are ordered and matched by hash,
// not by rawKey.
type node struct {
	// rawKey is the key passed to newNode; the slice is stored, not copied.
	rawKey      []byte
	value       []byte
	// hash is the ordering/lookup key (Set stores hash(rawKey) here).
	hash        uint64
	// topLayer is the highest layer index the node is linked into
	// (Set links layers 0..topLayer inclusive).
	topLayer    uint8
	// marked is checked by Get/Set as a logical-deletion flag; nothing in
	// this file currently sets it.
	// NOTE(review): marked and fullyLinked are plain bools read by other
	// goroutines without synchronization; consider atomic.Bool.
	marked      bool
	// fullyLinked is set by Set after the node is linked on all its layers.
	fullyLinked bool
	// nexts holds the successor on each layer; read through Next's atomic
	// load and published by Set with atomic pointer operations.
	nexts       [maxHeight]*node
	// lock is a one-slot channel used as a try-lock: a successful buffered
	// send acquires it and a receive (see unlock) releases it.
	lock        chan struct{}
}
    57  
    58  func newNode(k, v []byte) *node {
    59  	return &node{
    60  		lock:   make(chan struct{}, 1),
    61  		rawKey: k,
    62  		value:  v,
    63  		nexts:  [maxHeight]*node{},
    64  	}
    65  }
    66  
    67  func (n *node) Next(layer uint64) *node {
    68  	return (*node)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&n.nexts[layer]))))
    69  }
    70  
// SkipList is a skip list keyed by the maphash of each key, using per-node
// channel try-locks for inserts and atomic pointer loads for traversal.
// Create one with New; the zero value has a nil sentinel and is unusable.
type SkipList struct {
	// Sentinal is the key-less head node; its nexts point at the first node
	// on each layer. (Spelling kept because the field is exported.)
	Sentinal  *node
	// MaxHeight is the number of layers search walks; New sets it to
	// maxHeight.
	// NOTE(review): values of 0 or above maxHeight make search index out of
	// range; consider making this field unexported.
	MaxHeight uint8
	// height is the largest topLayer inserted so far; atomic access only.
	height    uint64
	// count is the number of successful inserts; atomic access only.
	count     uint64
}
    77  
    78  func New() *SkipList {
    79  	return &SkipList{
    80  		Sentinal:  newNode(nil, nil),
    81  		MaxHeight: maxHeight,
    82  		height:    uint64(0),
    83  		count:     uint64(0),
    84  	}
    85  }
    86  
    87  func (sl *SkipList) search(key uint64, preds, succs []*node) int {
    88  	var curr *node
    89  	pred := sl.Sentinal
    90  	layer := int(sl.MaxHeight - 1)
    91  oloop:
    92  	curr = pred.Next(uint64(layer))
    93  iloop:
    94  	// d := -1
    95  	if curr != nil {
    96  		// d = bytes.Compare(key, curr.rawKey)
    97  		if key > curr.hash {
    98  			pred = curr
    99  			curr = pred.Next(uint64(layer))
   100  			goto iloop
   101  		}
   102  	}
   103  	preds[layer] = pred
   104  	succs[layer] = curr
   105  	if curr != nil && key == curr.hash {
   106  		return layer
   107  	}
   108  	layer--
   109  	if layer >= 0 {
   110  		goto oloop
   111  	}
   112  	return -1
   113  }
   114  
   115  func (sl *SkipList) Get(key []byte) ([]byte, bool) {
   116  	preds := make([]*node, maxHeight)
   117  	succs := make([]*node, maxHeight)
   118  	lFound := sl.search(hash(key), preds, succs)
   119  	if lFound != -1 && succs[lFound].fullyLinked && !succs[lFound].marked {
   120  		return succs[lFound].value, true
   121  	}
   122  	return nil, false
   123  }
   124  
   125  func (sl *SkipList) Set(key, value []byte) error {
   126  	topLayer := RandUint32n(uint32(maxHeight))
   127  	if topLayer == 0 {
   128  		topLayer = 1
   129  	}
   130  	k := hash(key)
   131  loop:
   132  	for {
   133  		valid := true
   134  		highestLocked := -1
   135  		preds := make([]*node, maxHeight)
   136  		succs := make([]*node, maxHeight)
   137  		locks := make([]*node, maxHeight)
   138  		lFound := sl.search(k, preds, succs)
   139  		if lFound != -1 {
   140  			nodeFound := succs[lFound]
   141  			if nodeFound != nil && !nodeFound.marked {
   142  				// item already in the list return early
   143  				return nil
   144  			}
   145  			continue
   146  		}
   147  		var prevPred *node
   148  		height := sl.Height()
   149  		for layer := uint64(0); valid && (layer <= height); layer++ {
   150  			pred := preds[layer]
   151  			if pred != nil && pred != prevPred {
   152  				select {
   153  				case pred.lock <- struct{}{}:
   154  					locks[layer] = pred
   155  					highestLocked = int(layer)
   156  					prevPred = pred
   157  				default:
   158  					unlock(highestLocked, locks)
   159  					continue loop
   160  				}
   161  			}
   162  			succ := succs[layer]
   163  			if succ != nil {
   164  				valid = !pred.marked && !succ.marked && pred.Next(layer) == succ
   165  			}
   166  		}
   167  		if !valid {
   168  			// validation failed; try again
   169  			// validation = for each layer, i <= topNodeLayer, preds[i], succs[i]
   170  			// are still adjacent at layer i and that neither is marked
   171  			unlock(highestLocked, locks)
   172  			continue
   173  		}
   174  		// at this point; this thread holds all locks
   175  		// safe to create a new node
   176  		node := newNode(key, value)
   177  		node.hash = k
   178  		node.topLayer = uint8(topLayer)
   179  		for layer := uint64(0); layer <= uint64(topLayer); layer++ {
   180  			node.nexts[layer] = succs[layer]
   181  			oldNext := preds[layer].Next(layer)
   182  			atomic.CompareAndSwapPointer(
   183  				(*unsafe.Pointer)(unsafe.Pointer(&preds[layer].nexts[layer])),
   184  				unsafe.Pointer(oldNext),
   185  				unsafe.Pointer(node),
   186  			)
   187  		}
   188  		node.fullyLinked = true
   189  		atomic.AddUint64(&sl.count, 1)
   190  		height = sl.Height()
   191  		for uint64(topLayer) > height {
   192  			if atomic.CompareAndSwapUint64(&sl.height, height, uint64(topLayer)) {
   193  				break
   194  			}
   195  			height = sl.Height()
   196  		}
   197  		unlock(highestLocked, locks)
   198  		return nil
   199  	}
   200  }
   201  
   202  func (sl *SkipList) Remove(k uint64) ([]byte, bool) {
   203  	return nil, true
   204  }
   205  
   206  func (sl *SkipList) Print() {
   207  	out := strings.Builder{}
   208  	out.Reset()
   209  	curr := sl.Sentinal
   210  	for curr != nil {
   211  		for i := uint8(0); i < sl.MaxHeight; i++ {
   212  			n := curr.Next(uint64(i))
   213  			if n != nil {
   214  				out.WriteString(fmt.Sprintf("\t(%d, %s)", n.hash, n.rawKey))
   215  			}
   216  		}
   217  		curr = curr.Next(0)
   218  		out.WriteString("\n")
   219  	}
   220  	fmt.Println(out.String())
   221  }
   222  
   223  func (sl *SkipList) Range(f func(k, v []byte) bool) {
   224  	curr := sl.Sentinal.nexts[0]
   225  	for curr != nil {
   226  		ok := f(curr.rawKey, curr.value)
   227  		curr = curr.Next(0)
   228  		if !ok || curr == nil {
   229  			break
   230  		}
   231  	}
   232  }
   233  
   234  func (sl *SkipList) Count() uint64 {
   235  	return atomic.LoadUint64(&sl.count)
   236  }
   237  
   238  func (sl *SkipList) Height() uint64 {
   239  	return atomic.LoadUint64(&sl.height)
   240  }