github.com/blong14/gache@v0.0.0-20240124023949-89416fd8bbfa/internal/db/memtable/map.go (about)

     1  package memtable
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"strings"
     7  	"sync/atomic"
     8  	"unsafe"
     9  )
    10  
// RandUint32 returns a pseudo-random uint32 taken from the Go runtime's
// internal generator. The declaration has no Go body: the go:linkname
// directive below binds it to the runtime symbol runtime.fastrand. Set uses
// it to decide whether a newly inserted node gets an index tower and how tall
// that tower is.
//
// NOTE(review): a body-less declaration normally needs an (empty) .s file in
// the package to avoid "missing function body"; presumably one exists —
// confirm. Presumably not cryptographically secure; do not use it for
// anything security-sensitive.
//
//go:linkname RandUint32 runtime.fastrand
func RandUint32() uint32
    15  
    16  func hash(key []byte) uint64 {
    17  	var h uint64
    18  	for _, b := range key {
    19  		h = uint64(b) + (h << 6) + (h << 16) - h
    20  	}
    21  	return h
    22  }
    23  
    24  type SkipList struct {
    25  	head  *index
    26  	count uint64
    27  }
    28  
    29  func NewSkipList() *SkipList {
    30  	return &SkipList{}
    31  }
    32  
    33  func (sk *SkipList) top() *index {
    34  	if sk == nil {
    35  		return nil
    36  	}
    37  	return (*index)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&sk.head))))
    38  }
    39  
    40  func (sk *SkipList) findPredecessor(key uint64) *node {
    41  	q := sk.top()
    42  	for q != nil {
    43  		r := q.Right()
    44  	loop:
    45  		for r != nil {
    46  			p := r.Node()
    47  			switch {
    48  			case p == nil || p.hash == 0 || p.val == nil:
    49  				atomic.CompareAndSwapPointer(
    50  					(*unsafe.Pointer)(unsafe.Pointer(&q.right)),
    51  					unsafe.Pointer(r),
    52  					unsafe.Pointer(r.Right()),
    53  				)
    54  			case key > p.hash:
    55  				q = r
    56  				r = q.Right()
    57  			default:
    58  				break loop
    59  			}
    60  		}
    61  		d := q.Down()
    62  		if d == nil {
    63  			return q.Node()
    64  		}
    65  		q = d
    66  	}
    67  	return nil
    68  }
    69  
    70  func (sk *SkipList) findNode(key uint64) *node {
    71  	r := sk.findPredecessor(key)
    72  	for r != nil {
    73  		n := r.Next()
    74  		for n != nil {
    75  			switch {
    76  			case key > n.hash:
    77  				r = n
    78  				n = r.Next()
    79  			case key == n.hash:
    80  				return n
    81  			default:
    82  				return nil
    83  			}
    84  		}
    85  	}
    86  	return nil
    87  }
    88  
    89  func (sk *SkipList) addIndices(q *index, skips int, x *index) bool {
    90  	if x != nil && q != nil {
    91  		z := x.Node()
    92  		key := z.hash
    93  		if key == 0 {
    94  			return false
    95  		}
    96  		var retrying bool
    97  	loop:
    98  		for {
    99  			c := -1
   100  			r := q.Right()
   101  			if r != nil {
   102  				p := r.Node()
   103  				switch {
   104  				case p == nil || p.hash == 0 || p.val == nil:
   105  					atomic.CompareAndSwapPointer(
   106  						(*unsafe.Pointer)(unsafe.Pointer(&q.right)),
   107  						unsafe.Pointer(r),
   108  						unsafe.Pointer(r.Right()),
   109  					)
   110  					c = 0
   111  				case key > p.hash:
   112  					q = r
   113  					r = q.Right()
   114  					c = 1
   115  				case key == p.hash:
   116  					c = 0
   117  				default:
   118  				}
   119  				if c == 0 {
   120  					break
   121  				}
   122  			} else {
   123  				c = -1
   124  			}
   125  			if c < 0 {
   126  				d := q.Down()
   127  				switch {
   128  				case d != nil && skips > 0:
   129  					skips -= 1
   130  					q = d
   131  				case d != nil && !retrying && !sk.addIndices(d, 0, x.Down()):
   132  					break loop
   133  				default:
   134  					x.right = r
   135  					if atomic.CompareAndSwapPointer(
   136  						(*unsafe.Pointer)(unsafe.Pointer(&q.right)),
   137  						unsafe.Pointer(r),
   138  						unsafe.Pointer(x),
   139  					) {
   140  						return true
   141  					} else {
   142  						retrying = true
   143  					}
   144  				}
   145  			}
   146  		}
   147  	}
   148  	return false
   149  }
   150  
   151  func (sk *SkipList) Get(key []byte) ([]byte, bool) {
   152  	hashedValue := hash(key)
   153  	if hashedValue == 0 {
   154  		return nil, false
   155  	}
   156  	q := sk.top()
   157  	for q != nil {
   158  		r := q.Right()
   159  	loop:
   160  		for r != nil {
   161  			p := r.Node()
   162  			switch {
   163  			case p == nil || p.hash == 0 || p.val == nil:
   164  				atomic.CompareAndSwapPointer(
   165  					(*unsafe.Pointer)(unsafe.Pointer(&q.right)),
   166  					unsafe.Pointer(r),
   167  					unsafe.Pointer(r.Right()),
   168  				)
   169  			case hashedValue > p.hash:
   170  				q = r
   171  				r = q.Right()
   172  			case hashedValue == p.hash:
   173  				return p.val, true
   174  			default:
   175  				break loop
   176  			}
   177  		}
   178  		d := q.Down()
   179  		if d != nil {
   180  			q = d
   181  		} else {
   182  			b := q.Node()
   183  			if b != nil {
   184  				n := b.Next()
   185  				for n != nil {
   186  					if n.val == nil || n.hash == 0 || hashedValue > n.hash {
   187  						b = n
   188  						n = b.Next()
   189  					} else {
   190  						if hashedValue == n.hash {
   191  							return n.val, true
   192  						}
   193  						break
   194  					}
   195  				}
   196  			}
   197  			break
   198  		}
   199  	}
   200  	return nil, false
   201  }
   202  
   203  func (sk *SkipList) Set(key, value []byte) error {
   204  	if key == nil {
   205  		return errors.New("missing key")
   206  	}
   207  	var b *node
   208  	hashedKey := hash(key)
   209  	for {
   210  		levels := 0
   211  		h := sk.top()
   212  		if h == nil {
   213  			base := newNode(0, nil, nil, nil)
   214  			nh := newIndex(base, nil, nil)
   215  			if atomic.CompareAndSwapPointer(
   216  				(*unsafe.Pointer)(unsafe.Pointer(&sk.head)),
   217  				unsafe.Pointer(h),
   218  				unsafe.Pointer(nh),
   219  			) {
   220  				b = base
   221  				h = nh
   222  			} else {
   223  				b = nil
   224  			}
   225  		} else {
   226  			q := h
   227  			for q != nil {
   228  				r := q.Right()
   229  			loop:
   230  				for r != nil {
   231  					p := r.Node()
   232  					switch {
   233  					case p == nil || p.hash == 0 || p.val == nil:
   234  						atomic.CompareAndSwapPointer(
   235  							(*unsafe.Pointer)(unsafe.Pointer(&q.right)),
   236  							unsafe.Pointer(r),
   237  							unsafe.Pointer(r.Right()),
   238  						)
   239  					case hashedKey > p.hash:
   240  						q = r
   241  						r = q.Right()
   242  					default:
   243  						break loop
   244  					}
   245  				}
   246  				if q != nil {
   247  					d := q.Down()
   248  					if d != nil {
   249  						levels += 1
   250  						q = d
   251  					} else {
   252  						b = q.Node()
   253  						break
   254  					}
   255  				}
   256  			}
   257  		}
   258  		if b != nil {
   259  			var z *node
   260  			var p *node
   261  			for {
   262  				c := -1
   263  				n := b.Next()
   264  				switch {
   265  				case n == nil:
   266  					c = -1
   267  				case n.hash == 0:
   268  					break
   269  				case n.val == nil:
   270  					// unlinkNode(b, n)
   271  					c = 1
   272  				case hashedKey > n.hash:
   273  					b = n
   274  					c = 1
   275  				case hashedKey == n.hash:
   276  					c = 0
   277  				default:
   278  				}
   279  				if c == 0 {
   280  					// already in list
   281  					return nil
   282  				}
   283  				if c < 0 {
   284  					if p == nil {
   285  						p = newNode(hashedKey, key, value, nil)
   286  					}
   287  					p.next = n
   288  					if atomic.CompareAndSwapPointer(
   289  						(*unsafe.Pointer)(unsafe.Pointer(&b.next)),
   290  						unsafe.Pointer(n),
   291  						unsafe.Pointer(p),
   292  					) {
   293  						z = p
   294  						break
   295  					}
   296  				}
   297  			}
   298  			if z != nil {
   299  				lr := uint64(RandUint32())
   300  				if (lr & 0x3) == 0 {
   301  					hr := uint64(RandUint32())
   302  					rnd := hr<<32 | lr&0xffffffff
   303  					skips := levels
   304  					var x *index
   305  					for {
   306  						skips -= 1
   307  						x = newIndex(z, x, nil)
   308  						if rnd <= 0 || skips < 0 {
   309  							break
   310  						} else {
   311  							rnd >>= 1
   312  						}
   313  					}
   314  					if sk.addIndices(h, skips, x) && skips < 0 && sk.top() == h {
   315  						hx := newIndex(z, x, nil)
   316  						nh := newIndex(h.Node(), h, hx)
   317  						atomic.CompareAndSwapPointer(
   318  							(*unsafe.Pointer)(unsafe.Pointer(&sk.head)),
   319  							unsafe.Pointer(h),
   320  							unsafe.Pointer(nh),
   321  						)
   322  					}
   323  					if z.val == nil {
   324  						sk.findPredecessor(hashedKey)
   325  					}
   326  				}
   327  				atomic.AddUint64(&sk.count, 1)
   328  				return nil
   329  			}
   330  		}
   331  	}
   332  }
   333  
   334  func (sk *SkipList) Remove(_ uint64) ([]byte, bool) {
   335  	return nil, true
   336  }
   337  
   338  func (sk *SkipList) Range(f func(k, v []byte) bool) {
   339  	h := sk.top()
   340  	if h == nil || h.Node() == nil {
   341  		return
   342  	}
   343  	b := h.Node()
   344  	if b != nil {
   345  		n := b.Next()
   346  		for n != nil {
   347  			if n.val != nil {
   348  				ok := f(n.key, n.val)
   349  				if !ok {
   350  					break
   351  				}
   352  			}
   353  			b = n
   354  			n = b.Next()
   355  		}
   356  	}
   357  }
   358  
   359  type iter struct {
   360  	sk           *SkipList
   361  	lastReturned *node
   362  	nxt          *node
   363  	start        *uint64
   364  	end          *uint64
   365  }
   366  
   367  func newIter(sk *SkipList, start, end []byte) *iter {
   368  	var s *uint64
   369  	if start != nil {
   370  		h := hash(start)
   371  		s = &h
   372  	}
   373  	var e *uint64
   374  	if end != nil {
   375  		h := hash(end)
   376  		e = &h
   377  	}
   378  	i := &iter{sk: sk, start: s, end: e}
   379  	h := i.sk.top()
   380  	if h != nil {
   381  		n := h.Node()
   382  		i.advance(n)
   383  	}
   384  	return i
   385  }
   386  
   387  func (i *iter) advance(b *node) {
   388  	var n *node
   389  	i.lastReturned = b
   390  	if i.lastReturned != nil {
   391  		for n = b.Next(); n != nil && n.val == nil; {
   392  			b = n
   393  			n = b.Next()
   394  		}
   395  	}
   396  	if i.start != nil && n != nil && *i.start > n.hash {
   397  		n = i.sk.findNode(*i.start)
   398  	}
   399  	i.nxt = n
   400  }
   401  
   402  func (i *iter) hasNext() bool {
   403  	if i.end == nil {
   404  		return i.nxt != nil
   405  	}
   406  	return i.nxt != nil && i.nxt.hash <= *i.end
   407  }
   408  
   409  func (i *iter) next() *node {
   410  	n := i.nxt
   411  	i.advance(n)
   412  	return n
   413  }
   414  
   415  func (sk *SkipList) Scan(start, end []byte, f func(k, v []byte) bool) {
   416  	itr := newIter(sk, start, end)
   417  	for itr.hasNext() {
   418  		n := itr.next()
   419  		if ok := f(n.key, n.val); !ok {
   420  			return
   421  		}
   422  	}
   423  }
   424  
   425  func (sk *SkipList) Print() {
   426  	out := strings.Builder{}
   427  	out.Reset()
   428  	curr := sk.top()
   429  	d := curr.Down()
   430  	for curr != nil {
   431  		r := curr.Right()
   432  		for r != nil {
   433  			n := r.Node()
   434  			out.WriteString(fmt.Sprintf("[%d - %s->]\t", n.hash, n.key))
   435  			curr = r
   436  			r = curr.Right()
   437  		}
   438  		if d.Down() != nil {
   439  			curr = d
   440  			d = d.Down()
   441  			out.WriteString("\n")
   442  		} else {
   443  			out.WriteString("\n")
   444  			curr = d
   445  			for curr != nil {
   446  				n := curr.Node()
   447  				for n != nil {
   448  					if n.hash == curr.Node().hash {
   449  						out.WriteString(fmt.Sprintf("[%d-%s->] ", n.hash, n.key))
   450  					} else {
   451  						out.WriteString(fmt.Sprintf("%s-> ", n.key))
   452  					}
   453  					n = n.Next()
   454  				}
   455  				curr = r
   456  				if curr != nil {
   457  					r = curr.Right()
   458  				}
   459  			}
   460  			break
   461  		}
   462  	}
   463  	fmt.Println(out.String())
   464  }
   465  
   466  func (sk *SkipList) Count() uint64 {
   467  	return atomic.LoadUint64(&sk.count)
   468  }