github.com/blong14/gache@v0.0.0-20240124023949-89416fd8bbfa/internal/db/memtable/map.go (about) 1 package memtable 2 3 import ( 4 "errors" 5 "fmt" 6 "strings" 7 "sync/atomic" 8 "unsafe" 9 ) 10 11 // RandUint32 returns a lock free uint32 value. 12 // 13 //go:linkname RandUint32 runtime.fastrand 14 func RandUint32() uint32 15 16 func hash(key []byte) uint64 { 17 var h uint64 18 for _, b := range key { 19 h = uint64(b) + (h << 6) + (h << 16) - h 20 } 21 return h 22 } 23 24 type SkipList struct { 25 head *index 26 count uint64 27 } 28 29 func NewSkipList() *SkipList { 30 return &SkipList{} 31 } 32 33 func (sk *SkipList) top() *index { 34 if sk == nil { 35 return nil 36 } 37 return (*index)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&sk.head)))) 38 } 39 40 func (sk *SkipList) findPredecessor(key uint64) *node { 41 q := sk.top() 42 for q != nil { 43 r := q.Right() 44 loop: 45 for r != nil { 46 p := r.Node() 47 switch { 48 case p == nil || p.hash == 0 || p.val == nil: 49 atomic.CompareAndSwapPointer( 50 (*unsafe.Pointer)(unsafe.Pointer(&q.right)), 51 unsafe.Pointer(r), 52 unsafe.Pointer(r.Right()), 53 ) 54 case key > p.hash: 55 q = r 56 r = q.Right() 57 default: 58 break loop 59 } 60 } 61 d := q.Down() 62 if d == nil { 63 return q.Node() 64 } 65 q = d 66 } 67 return nil 68 } 69 70 func (sk *SkipList) findNode(key uint64) *node { 71 r := sk.findPredecessor(key) 72 for r != nil { 73 n := r.Next() 74 for n != nil { 75 switch { 76 case key > n.hash: 77 r = n 78 n = r.Next() 79 case key == n.hash: 80 return n 81 default: 82 return nil 83 } 84 } 85 } 86 return nil 87 } 88 89 func (sk *SkipList) addIndices(q *index, skips int, x *index) bool { 90 if x != nil && q != nil { 91 z := x.Node() 92 key := z.hash 93 if key == 0 { 94 return false 95 } 96 var retrying bool 97 loop: 98 for { 99 c := -1 100 r := q.Right() 101 if r != nil { 102 p := r.Node() 103 switch { 104 case p == nil || p.hash == 0 || p.val == nil: 105 atomic.CompareAndSwapPointer( 106 (*unsafe.Pointer)(unsafe.Pointer(&q.right)), 107 unsafe.Pointer(r), 108 unsafe.Pointer(r.Right()), 109 ) 110 c = 0 111 case key > p.hash: 112 q = r 113 r = q.Right() 114 c = 1 115 case key == p.hash: 116 c = 0 117 default: 118 } 119 if c == 0 { 120 break 121 } 122 } else { 123 c = -1 124 } 125 if c < 0 { 126 d := q.Down() 127 switch { 128 case d != nil && skips > 0: 129 skips -= 1 130 q = d 131 case d != nil && !retrying && !sk.addIndices(d, 0, x.Down()): 132 break loop 133 default: 134 x.right = r 135 if atomic.CompareAndSwapPointer( 136 (*unsafe.Pointer)(unsafe.Pointer(&q.right)), 137 unsafe.Pointer(r), 138 unsafe.Pointer(x), 139 ) { 140 return true 141 } else { 142 retrying = true 143 } 144 } 145 } 146 } 147 } 148 return false 149 } 150 151 func (sk *SkipList) Get(key []byte) ([]byte, bool) { 152 hashedValue := hash(key) 153 if hashedValue == 0 { 154 return nil, false 155 } 156 q := sk.top() 157 for q != nil { 158 r := q.Right() 159 loop: 160 for r != nil { 161 p := r.Node() 162 switch { 163 case p == nil || p.hash == 0 || p.val == nil: 164 atomic.CompareAndSwapPointer( 165 (*unsafe.Pointer)(unsafe.Pointer(&q.right)), 166 unsafe.Pointer(r), 167 unsafe.Pointer(r.Right()), 168 ) 169 case hashedValue > p.hash: 170 q = r 171 r = q.Right() 172 case hashedValue == p.hash: 173 return p.val, true 174 default: 175 break loop 176 } 177 } 178 d := q.Down() 179 if d != nil { 180 q = d 181 } else { 182 b := q.Node() 183 if b != nil { 184 n := b.Next() 185 for n != nil { 186 if n.val == nil || n.hash == 0 || hashedValue > n.hash { 187 b = n 188 n = b.Next() 189 } else { 190 if hashedValue == n.hash { 191 return n.val, true 192 } 193 break 194 } 195 } 196 } 197 break 198 } 199 } 200 return nil, false 201 } 202 203 func (sk *SkipList) Set(key, value []byte) error { 204 if key == nil { 205 return errors.New("missing key") 206 } 207 var b *node 208 hashedKey := hash(key) 209 for { 210 levels := 0 211 h := sk.top() 212 if h == nil { 213 base := newNode(0, nil, nil, nil) 214 nh := newIndex(base, nil, nil) 215 if atomic.CompareAndSwapPointer( 216 (*unsafe.Pointer)(unsafe.Pointer(&sk.head)), 217 unsafe.Pointer(h), 218 unsafe.Pointer(nh), 219 ) { 220 b = base 221 h = nh 222 } else { 223 b = nil 224 } 225 } else { 226 q := h 227 for q != nil { 228 r := q.Right() 229 loop: 230 for r != nil { 231 p := r.Node() 232 switch { 233 case p == nil || p.hash == 0 || p.val == nil: 234 atomic.CompareAndSwapPointer( 235 (*unsafe.Pointer)(unsafe.Pointer(&q.right)), 236 unsafe.Pointer(r), 237 unsafe.Pointer(r.Right()), 238 ) 239 case hashedKey > p.hash: 240 q = r 241 r = q.Right() 242 default: 243 break loop 244 } 245 } 246 if q != nil { 247 d := q.Down() 248 if d != nil { 249 levels += 1 250 q = d 251 } else { 252 b = q.Node() 253 break 254 } 255 } 256 } 257 } 258 if b != nil { 259 var z *node 260 var p *node 261 for { 262 c := -1 263 n := b.Next() 264 switch { 265 case n == nil: 266 c = -1 267 case n.hash == 0: 268 break 269 case n.val == nil: 270 // unlinkNode(b, n) 271 c = 1 272 case hashedKey > n.hash: 273 b = n 274 c = 1 275 case hashedKey == n.hash: 276 c = 0 277 default: 278 } 279 if c == 0 { 280 // already in list 281 return nil 282 } 283 if c < 0 { 284 if p == nil { 285 p = newNode(hashedKey, key, value, nil) 286 } 287 p.next = n 288 if atomic.CompareAndSwapPointer( 289 (*unsafe.Pointer)(unsafe.Pointer(&b.next)), 290 unsafe.Pointer(n), 291 unsafe.Pointer(p), 292 ) { 293 z = p 294 break 295 } 296 } 297 } 298 if z != nil { 299 lr := uint64(RandUint32()) 300 if (lr & 0x3) == 0 { 301 hr := uint64(RandUint32()) 302 rnd := hr<<32 | lr&0xffffffff 303 skips := levels 304 var x *index 305 for { 306 skips -= 1 307 x = newIndex(z, x, nil) 308 if rnd <= 0 || skips < 0 { 309 break 310 } else { 311 rnd >>= 1 312 } 313 } 314 if sk.addIndices(h, skips, x) && skips < 0 && sk.top() == h { 315 hx := newIndex(z, x, nil) 316 nh := newIndex(h.Node(), h, hx) 317 atomic.CompareAndSwapPointer( 318 (*unsafe.Pointer)(unsafe.Pointer(&sk.head)), 319 unsafe.Pointer(h), 320 unsafe.Pointer(nh), 321 ) 322 } 323 if z.val == nil { 324 sk.findPredecessor(hashedKey) 325 } 326 } 327 atomic.AddUint64(&sk.count, 1) 328 return nil 329 } 330 } 331 } 332 } 333 334 func (sk *SkipList) Remove(_ uint64) ([]byte, bool) { 335 return nil, true 336 } 337 338 func (sk *SkipList) Range(f func(k, v []byte) bool) { 339 h := sk.top() 340 if h == nil || h.Node() == nil { 341 return 342 } 343 b := h.Node() 344 if b != nil { 345 n := b.Next() 346 for n != nil { 347 if n.val != nil { 348 ok := f(n.key, n.val) 349 if !ok { 350 break 351 } 352 } 353 b = n 354 n = b.Next() 355 } 356 } 357 } 358 359 type iter struct { 360 sk *SkipList 361 lastReturned *node 362 nxt *node 363 start *uint64 364 end *uint64 365 } 366 367 func newIter(sk *SkipList, start, end []byte) *iter { 368 var s *uint64 369 if start != nil { 370 h := hash(start) 371 s = &h 372 } 373 var e *uint64 374 if end != nil { 375 h := hash(end) 376 e = &h 377 } 378 i := &iter{sk: sk, start: s, end: e} 379 h := i.sk.top() 380 if h != nil { 381 n := h.Node() 382 i.advance(n) 383 } 384 return i 385 } 386 387 func (i *iter) advance(b *node) { 388 var n *node 389 i.lastReturned = b 390 if i.lastReturned != nil { 391 for n = b.Next(); n != nil && n.val == nil; { 392 b = n 393 n = b.Next() 394 } 395 } 396 if i.start != nil && n != nil && *i.start > n.hash { 397 n = i.sk.findNode(*i.start) 398 } 399 i.nxt = n 400 } 401 402 func (i *iter) hasNext() bool { 403 if i.end == nil { 404 return i.nxt != nil 405 } 406 return i.nxt != nil && i.nxt.hash <= *i.end 407 } 408 409 func (i *iter) next() *node { 410 n := i.nxt 411 i.advance(n) 412 return n 413 } 414 415 func (sk *SkipList) Scan(start, end []byte, f func(k, v []byte) bool) { 416 itr := newIter(sk, start, end) 417 for itr.hasNext() { 418 n := itr.next() 419 if ok := f(n.key, n.val); !ok { 420 return 421 } 422 } 423 } 424 425 func (sk *SkipList) Print() { 426 out := strings.Builder{} 427 out.Reset() 428 curr := sk.top() 429 d := curr.Down() 430 for curr != nil { 431 r := curr.Right() 432 for r != nil { 433 n := r.Node() 434 out.WriteString(fmt.Sprintf("[%d - %s->]\t", n.hash, n.key)) 435 curr = r 436 r = curr.Right() 437 } 438 if d.Down() != nil { 439 curr = d 440 d = d.Down() 441 out.WriteString("\n") 442 } else { 443 out.WriteString("\n") 444 curr = d 445 for curr != nil { 446 n := curr.Node() 447 for n != nil { 448 if n.hash == curr.Node().hash { 449 out.WriteString(fmt.Sprintf("[%d-%s->] ", n.hash, n.key)) 450 } else { 451 out.WriteString(fmt.Sprintf("%s-> ", n.key)) 452 } 453 n = n.Next() 454 } 455 curr = r 456 if curr != nil { 457 r = curr.Right() 458 } 459 } 460 break 461 } 462 } 463 fmt.Println(out.String()) 464 } 465 466 func (sk *SkipList) Count() uint64 { 467 return atomic.LoadUint64(&sk.count) 468 }