github.com/apernet/quic-go@v0.43.1-0.20240515053213-5e9e635fd9f0/internal/utils/tree/tree.go (about) 1 // Originated from https://github.com/ross-oreto/go-tree/blob/master/btree.go with the following changes: 2 // 1. Genericized the code 3 // 2. Added Match function for our frame sorter use case 4 // 3. Fixed a bug in deleteNode where in some cases the deleted flag was not set to true 5 6 /* 7 Copyright (c) 2017 Ross Oreto 8 9 Permission is hereby granted, free of charge, to any person obtaining a copy 10 of this software and associated documentation files (the "Software"), to deal 11 in the Software without restriction, including without limitation the rights 12 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 13 copies of the Software, and to permit persons to whom the Software is 14 furnished to do so, subject to the following conditions: 15 16 The above copyright notice and this permission notice shall be included in all 17 copies or substantial portions of the Software. 18 19 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 20 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 21 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 22 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 23 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 24 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 25 SOFTWARE. 26 */ 27 28 package tree 29 30 import ( 31 "fmt" 32 ) 33 34 type Val[T any] interface { 35 Comp(val T) int8 // returns 1 if > val, -1 if < val, 0 if equals to val 36 Match(cond T) int8 // returns 1 if > cond, -1 if < cond, 0 if matches cond 37 } 38 39 // Btree represents an AVL tree 40 type Btree[T Val[T]] struct { 41 root *Node[T] 42 values []T 43 len int 44 } 45 46 // Node represents a node in the tree with a value, left and right children, and a height/balance of the node. 47 type Node[T Val[T]] struct { 48 Value T 49 left, right *Node[T] 50 height int8 51 } 52 53 // New returns a new btree 54 func New[T Val[T]]() *Btree[T] { return new(Btree[T]).Init() } 55 56 // Init initializes all values/clears the tree and returns the tree pointer 57 func (t *Btree[T]) Init() *Btree[T] { 58 t.root = nil 59 t.values = nil 60 t.len = 0 61 return t 62 } 63 64 // String returns a string representation of the tree values 65 func (t *Btree[T]) String() string { 66 return fmt.Sprint(t.Values()) 67 } 68 69 // Empty returns true if the tree is empty 70 func (t *Btree[T]) Empty() bool { 71 return t.root == nil 72 } 73 74 // NotEmpty returns true if the tree is not empty 75 func (t *Btree[T]) NotEmpty() bool { 76 return t.root != nil 77 } 78 79 // Insert inserts a new value into the tree and returns the tree pointer 80 func (t *Btree[T]) Insert(value T) *Btree[T] { 81 added := false 82 t.root = insert(t.root, value, &added) 83 if added { 84 t.len++ 85 } 86 t.values = nil 87 return t 88 } 89 90 func insert[T Val[T]](n *Node[T], value T, added *bool) *Node[T] { 91 if n == nil { 92 *added = true 93 return (&Node[T]{Value: value}).Init() 94 } 95 c := value.Comp(n.Value) 96 if c > 0 { 97 n.right = insert(n.right, value, added) 98 } else if c < 0 { 99 n.left = insert(n.left, value, added) 100 } else { 101 n.Value = value 102 *added = false 103 return n 104 } 105 106 n.height = n.maxHeight() + 1 107 c = balance(n) 108 109 if c > 1 { 110 c = value.Comp(n.left.Value) 111 if c < 0 { 112 return n.rotateRight() 113 } else if c > 0 { 114 n.left = n.left.rotateLeft() 115 return n.rotateRight() 116 } 117 } else if c < -1 { 118 c = value.Comp(n.right.Value) 119 if c > 0 { 120 return n.rotateLeft() 121 } else if c < 0 { 122 n.right = n.right.rotateRight() 123 return n.rotateLeft() 124 } 125 } 126 return n 127 } 128 129 // InsertAll inserts all the values into the tree and returns the tree pointer 130 func (t *Btree[T]) InsertAll(values []T) *Btree[T] { 131 for _, v := range values { 132 t.Insert(v) 133 } 134 return t 135 } 136 137 // Contains returns true if the tree contains the specified value 138 func (t *Btree[T]) Contains(value T) bool { 139 return t.Get(value) != nil 140 } 141 142 // ContainsAny returns true if the tree contains any of the values 143 func (t *Btree[T]) ContainsAny(values []T) bool { 144 for _, v := range values { 145 if t.Contains(v) { 146 return true 147 } 148 } 149 return false 150 } 151 152 // ContainsAll returns true if the tree contains all of the values 153 func (t *Btree[T]) ContainsAll(values []T) bool { 154 for _, v := range values { 155 if !t.Contains(v) { 156 return false 157 } 158 } 159 return true 160 } 161 162 // Get returns the node value associated with the search value 163 func (t *Btree[T]) Get(value T) *T { 164 var node *Node[T] 165 if t.root != nil { 166 node = t.root.get(value) 167 } 168 if node != nil { 169 return &node.Value 170 } 171 return nil 172 } 173 174 func (t *Btree[T]) Match(cond T) []T { 175 var matches []T 176 if t.root != nil { 177 t.root.match(cond, &matches) 178 } 179 return matches 180 } 181 182 // Len return the number of nodes in the tree 183 func (t *Btree[T]) Len() int { 184 return t.len 185 } 186 187 // Head returns the first value in the tree 188 func (t *Btree[T]) Head() *T { 189 if t.root == nil { 190 return nil 191 } 192 beginning := t.root 193 for beginning.left != nil { 194 beginning = beginning.left 195 } 196 if beginning == nil { 197 for beginning.right != nil { 198 beginning = beginning.right 199 } 200 } 201 if beginning != nil { 202 return &beginning.Value 203 } 204 return nil 205 } 206 207 // Tail returns the last value in the tree 208 func (t *Btree[T]) Tail() *T { 209 if t.root == nil { 210 return nil 211 } 212 beginning := t.root 213 for beginning.right != nil { 214 beginning = beginning.right 215 } 216 if beginning == nil { 217 for beginning.left != nil { 218 beginning = beginning.left 219 } 220 } 221 if beginning != nil { 222 return &beginning.Value 223 } 224 return nil 225 } 226 227 // Values returns a slice of all the values in tree in order 228 func (t *Btree[T]) Values() []T { 229 if t.values == nil { 230 t.values = make([]T, t.len) 231 t.Ascend(func(n *Node[T], i int) bool { 232 t.values[i] = n.Value 233 return true 234 }) 235 } 236 return t.values 237 } 238 239 // Delete deletes the node from the tree associated with the search value 240 func (t *Btree[T]) Delete(value T) *Btree[T] { 241 deleted := false 242 t.root = deleteNode(t.root, value, &deleted) 243 if deleted { 244 t.len-- 245 } 246 t.values = nil 247 return t 248 } 249 250 // DeleteAll deletes the nodes from the tree associated with the search values 251 func (t *Btree[T]) DeleteAll(values []T) *Btree[T] { 252 for _, v := range values { 253 t.Delete(v) 254 } 255 return t 256 } 257 258 func deleteNode[T Val[T]](n *Node[T], value T, deleted *bool) *Node[T] { 259 if n == nil { 260 return n 261 } 262 263 c := value.Comp(n.Value) 264 265 if c < 0 { 266 n.left = deleteNode(n.left, value, deleted) 267 } else if c > 0 { 268 n.right = deleteNode(n.right, value, deleted) 269 } else { 270 if n.left == nil { 271 t := n.right 272 n.Init() 273 *deleted = true 274 return t 275 } else if n.right == nil { 276 t := n.left 277 n.Init() 278 *deleted = true 279 return t 280 } 281 t := n.right.min() 282 n.Value = t.Value 283 n.right = deleteNode(n.right, t.Value, deleted) 284 *deleted = true 285 } 286 287 // re-balance 288 if n == nil { 289 return n 290 } 291 n.height = n.maxHeight() + 1 292 bal := balance(n) 293 if bal > 1 { 294 if balance(n.left) >= 0 { 295 return n.rotateRight() 296 } 297 n.left = n.left.rotateLeft() 298 return n.rotateRight() 299 } else if bal < -1 { 300 if balance(n.right) <= 0 { 301 return n.rotateLeft() 302 } 303 n.right = n.right.rotateRight() 304 return n.rotateLeft() 305 } 306 307 return n 308 } 309 310 // Pop deletes the last node from the tree and returns its value 311 func (t *Btree[T]) Pop() *T { 312 value := t.Tail() 313 if value != nil { 314 t.Delete(*value) 315 } 316 return value 317 } 318 319 // Pull deletes the first node from the tree and returns its value 320 func (t *Btree[T]) Pull() *T { 321 value := t.Head() 322 if value != nil { 323 t.Delete(*value) 324 } 325 return value 326 } 327 328 // NodeIterator expresses the iterator function used for traversals 329 type NodeIterator[T Val[T]] func(n *Node[T], i int) bool 330 331 // Ascend performs an ascending order traversal of the tree calling the iterator function on each node 332 // the iterator will continue as long as the NodeIterator returns true 333 func (t *Btree[T]) Ascend(iterator NodeIterator[T]) { 334 var i int 335 if t.root != nil { 336 t.root.iterate(iterator, &i, true) 337 } 338 } 339 340 // Descend performs a descending order traversal of the tree using the iterator 341 // the iterator will continue as long as the NodeIterator returns true 342 func (t *Btree[T]) Descend(iterator NodeIterator[T]) { 343 var i int 344 if t.root != nil { 345 t.root.rIterate(iterator, &i, true) 346 } 347 } 348 349 // Debug prints out useful debug information about the tree for debugging purposes 350 func (t *Btree[T]) Debug() { 351 fmt.Println("----------------------------------------------------------------------------------------------") 352 if t.Empty() { 353 fmt.Println("tree is empty") 354 } else { 355 fmt.Println(t.Len(), "elements") 356 } 357 358 t.Ascend(func(n *Node[T], i int) bool { 359 if t.root.Value.Comp(n.Value) == 0 { 360 fmt.Print("ROOT ** ") 361 } 362 n.Debug() 363 return true 364 }) 365 fmt.Println("----------------------------------------------------------------------------------------------") 366 } 367 368 // Init initializes the values of the node or clears the node and returns the node pointer 369 func (n *Node[T]) Init() *Node[T] { 370 n.height = 1 371 n.left = nil 372 n.right = nil 373 return n 374 } 375 376 // String returns a string representing the node 377 func (n *Node[T]) String() string { 378 return fmt.Sprint(n.Value) 379 } 380 381 // Debug prints out useful debug information about the tree node for debugging purposes 382 func (n *Node[T]) Debug() { 383 var children string 384 if n.left == nil && n.right == nil { 385 children = "no children |" 386 } else if n.left != nil && n.right != nil { 387 children = fmt.Sprint("left child:", n.left.String(), " right child:", n.right.String()) 388 } else if n.right != nil { 389 children = fmt.Sprint("right child:", n.right.String()) 390 } else { 391 children = fmt.Sprint("left child:", n.left.String()) 392 } 393 394 fmt.Println(n.String(), "|", "height", n.height, "|", "balance", balance(n), "|", children) 395 } 396 397 func height[T Val[T]](n *Node[T]) int8 { 398 if n != nil { 399 return n.height 400 } 401 return 0 402 } 403 404 func balance[T Val[T]](n *Node[T]) int8 { 405 if n == nil { 406 return 0 407 } 408 return height(n.left) - height(n.right) 409 } 410 411 func (n *Node[T]) get(val T) *Node[T] { 412 var node *Node[T] 413 c := val.Comp(n.Value) 414 if c < 0 { 415 if n.left != nil { 416 node = n.left.get(val) 417 } 418 } else if c > 0 { 419 if n.right != nil { 420 node = n.right.get(val) 421 } 422 } else { 423 node = n 424 } 425 return node 426 } 427 428 func (n *Node[T]) match(cond T, results *[]T) { 429 c := n.Value.Match(cond) 430 if c > 0 { 431 if n.left != nil { 432 n.left.match(cond, results) 433 } 434 } else if c < 0 { 435 if n.right != nil { 436 n.right.match(cond, results) 437 } 438 } else { 439 // other matching nodes could be on both sides 440 if n.left != nil { 441 n.left.match(cond, results) 442 } 443 *results = append(*results, n.Value) 444 if n.right != nil { 445 n.right.match(cond, results) 446 } 447 } 448 } 449 450 func (n *Node[T]) rotateRight() *Node[T] { 451 l := n.left 452 // Rotation 453 l.right, n.left = n, l.right 454 455 // update heights 456 n.height = n.maxHeight() + 1 457 l.height = l.maxHeight() + 1 458 459 return l 460 } 461 462 func (n *Node[T]) rotateLeft() *Node[T] { 463 r := n.right 464 // Rotation 465 r.left, n.right = n, r.left 466 467 // update heights 468 n.height = n.maxHeight() + 1 469 r.height = r.maxHeight() + 1 470 471 return r 472 } 473 474 func (n *Node[T]) iterate(iterator NodeIterator[T], i *int, cont bool) { 475 if n != nil && cont { 476 n.left.iterate(iterator, i, cont) 477 cont = iterator(n, *i) 478 *i++ 479 n.right.iterate(iterator, i, cont) 480 } 481 } 482 483 func (n *Node[T]) rIterate(iterator NodeIterator[T], i *int, cont bool) { 484 if n != nil && cont { 485 n.right.iterate(iterator, i, cont) 486 cont = iterator(n, *i) 487 *i++ 488 n.left.iterate(iterator, i, cont) 489 } 490 } 491 492 func (n *Node[T]) min() *Node[T] { 493 current := n 494 for current.left != nil { 495 current = current.left 496 } 497 return current 498 } 499 500 func (n *Node[T]) maxHeight() int8 { 501 rh := height(n.right) 502 lh := height(n.left) 503 if rh > lh { 504 return rh 505 } 506 return lh 507 }