github.com/songzhibin97/go-baseutils@v0.0.2-0.20240302024150-487d8ce9c082/structure/trees/avltree/avltree.go (about) 1 // Package avltree implements an AVL balanced binary tree. 2 // 3 // Structure is not thread safe. 4 // 5 // References: https://en.wikipedia.org/wiki/AVL_tree 6 package avltree 7 8 import ( 9 "encoding/json" 10 "fmt" 11 "strings" 12 13 "github.com/songzhibin97/go-baseutils/base/bcomparator" 14 "github.com/songzhibin97/go-baseutils/structure/trees" 15 ) 16 17 // Assert Tree implementation 18 var _ trees.Tree[any] = new(Tree[any, any]) 19 20 // Tree holds elements of the AVL tree. 21 type Tree[K any, V any] struct { 22 Root *Node[K, V] // Root node 23 Comparator bcomparator.Comparator[K] // Key comparator 24 size int // Total number of keys in the tree 25 zeroV V 26 } 27 28 // Node is a single element within the tree 29 type Node[K, V any] struct { 30 Key K 31 Value V 32 Parent *Node[K, V] // Parent node 33 Children [2]*Node[K, V] // Children nodes 34 b int8 35 } 36 37 // NewWith instantiates an AVL tree with the custom comparator. 38 func NewWith[K, V any](comparator bcomparator.Comparator[K]) *Tree[K, V] { 39 return &Tree[K, V]{Comparator: comparator} 40 } 41 42 // NewWithIntComparator instantiates an AVL tree with the IntComparator, i.e. keys are of type int. 43 func NewWithIntComparator[V any]() *Tree[int, V] { 44 return &Tree[int, V]{Comparator: bcomparator.IntComparator()} 45 } 46 47 // NewWithStringComparator instantiates an AVL tree with the StringComparator, i.e. keys are of type string. 48 func NewWithStringComparator[V any]() *Tree[string, V] { 49 return &Tree[string, V]{Comparator: bcomparator.StringComparator()} 50 } 51 52 // Put inserts node into the tree. 53 // Key should adhere to the comparator's type assertion, otherwise method panics. 
54 func (tree *Tree[K, V]) Put(key K, value V) { 55 tree.put(key, value, nil, &tree.Root) 56 } 57 58 // Get searches the node in the tree by key and returns its value or nil if key is not found in tree. 59 // Second return parameter is true if key was found, otherwise false. 60 // Key should adhere to the comparator's type assertion, otherwise method panics. 61 func (tree *Tree[K, V]) Get(key K) (value V, found bool) { 62 n := tree.GetNode(key) 63 if n != nil { 64 return n.Value, true 65 } 66 return tree.zeroV, false 67 } 68 69 // GetNode searches the node in the tree by key and returns its node or nil if key is not found in tree. 70 // Key should adhere to the comparator's type assertion, otherwise method panics. 71 func (tree *Tree[K, V]) GetNode(key K) *Node[K, V] { 72 n := tree.Root 73 for n != nil { 74 cmp := tree.Comparator(key, n.Key) 75 switch { 76 case cmp == 0: 77 return n 78 case cmp < 0: 79 n = n.Children[0] 80 case cmp > 0: 81 n = n.Children[1] 82 } 83 } 84 return n 85 } 86 87 // Remove remove the node from the tree by key. 88 // Key should adhere to the comparator's type assertion, otherwise method panics. 89 func (tree *Tree[K, V]) Remove(key K) { 90 tree.remove(key, &tree.Root) 91 } 92 93 // Empty returns true if tree does not contain any nodes. 94 func (tree *Tree[K, V]) Empty() bool { 95 return tree.size == 0 96 } 97 98 // Size returns the number of elements stored in the tree. 99 func (tree *Tree[K, V]) Size() int { 100 return tree.size 101 } 102 103 // Size returns the number of elements stored in the subtree. 104 // Computed dynamically on each call, i.e. the subtree is traversed to count the number of the nodes. 
105 func (n *Node[K, V]) Size() int { 106 if n == nil { 107 return 0 108 } 109 size := 1 110 if n.Children[0] != nil { 111 size += n.Children[0].Size() 112 } 113 if n.Children[1] != nil { 114 size += n.Children[1].Size() 115 } 116 return size 117 } 118 119 // Keys returns all keys in-order 120 func (tree *Tree[K, V]) Keys() []K { 121 keys := make([]K, tree.size) 122 it := tree.Iterator() 123 for i := 0; it.Next(); i++ { 124 keys[i] = it.Key() 125 } 126 return keys 127 } 128 129 // Values returns all values in-order based on the key. 130 func (tree *Tree[K, V]) Values() []V { 131 values := make([]V, tree.size) 132 it := tree.Iterator() 133 for i := 0; it.Next(); i++ { 134 values[i] = it.Value() 135 } 136 return values 137 } 138 139 // Left returns the minimum element of the AVL tree 140 // or nil if the tree is empty. 141 func (tree *Tree[K, V]) Left() *Node[K, V] { 142 return tree.bottom(0) 143 } 144 145 // Right returns the maximum element of the AVL tree 146 // or nil if the tree is empty. 147 func (tree *Tree[K, V]) Right() *Node[K, V] { 148 return tree.bottom(1) 149 } 150 151 // Floor Finds floor node of the input key, return the floor node or nil if no floor is found. 152 // Second return parameter is true if floor was found, otherwise false. 153 // 154 // Floor node is defined as the largest node that is smaller than or equal to the given node. 155 // A floor node may not be found, either because the tree is empty, or because 156 // all nodes in the tree is larger than the given node. 157 // 158 // Key should adhere to the comparator's type assertion, otherwise method panics. 
159 func (tree *Tree[K, V]) Floor(key K) (floor *Node[K, V], found bool) { 160 found = false 161 n := tree.Root 162 for n != nil { 163 c := tree.Comparator(key, n.Key) 164 switch { 165 case c == 0: 166 return n, true 167 case c < 0: 168 n = n.Children[0] 169 case c > 0: 170 floor, found = n, true 171 n = n.Children[1] 172 } 173 } 174 if found { 175 return 176 } 177 return nil, false 178 } 179 180 // Ceiling finds ceiling node of the input key, return the ceiling node or nil if no ceiling is found. 181 // Second return parameter is true if ceiling was found, otherwise false. 182 // 183 // Ceiling node is defined as the smallest node that is larger than or equal to the given node. 184 // A ceiling node may not be found, either because the tree is empty, or because 185 // all nodes in the tree is smaller than the given node. 186 // 187 // Key should adhere to the comparator's type assertion, otherwise method panics. 188 func (tree *Tree[K, V]) Ceiling(key K) (floor *Node[K, V], found bool) { 189 found = false 190 n := tree.Root 191 for n != nil { 192 c := tree.Comparator(key, n.Key) 193 switch { 194 case c == 0: 195 return n, true 196 case c < 0: 197 floor, found = n, true 198 n = n.Children[0] 199 case c > 0: 200 n = n.Children[1] 201 } 202 } 203 if found { 204 return 205 } 206 return nil, false 207 } 208 209 // Clear removes all nodes from the tree. 
func (tree *Tree[K, V]) Clear() {
	tree.Root = nil
	tree.size = 0
}

// String returns a string representation of container: the header line
// "AVLTree" followed, for a non-empty tree, by a drawing produced by output.
func (tree *Tree[K, V]) String() string {
	b := strings.Builder{}
	b.WriteString("AVLTree\n")
	if !tree.Empty() {
		output(tree.Root, "", true, &b)
	}
	return b.String()
}

// String formats the node's key with %v.
func (n *Node[K, V]) String() string {
	return fmt.Sprintf("%v", n.Key)
}

// put inserts key/value into the subtree whose root is stored at *qp.
// p is the parent to record on a newly created node. qp is a pointer to the
// slot (tree.Root or a parent's Children entry) so that a rebalancing
// rotation can replace the subtree root in place.
// It returns true when the subtree rooted at *qp became taller, which tells
// the caller to rebalance its own node; an equal key only overwrites the
// stored key and value and returns false.
func (tree *Tree[K, V]) put(key K, value V, p *Node[K, V], qp **Node[K, V]) bool {
	q := *qp
	if q == nil {
		tree.size++
		*qp = &Node[K, V]{Key: key, Value: value, Parent: p}
		return true
	}

	c := tree.Comparator(key, q.Key)
	if c == 0 {
		q.Key = key
		q.Value = value
		return false
	}

	// Normalize the comparison to a direction: -1 = left, +1 = right.
	if c < 0 {
		c = -1
	} else {
		c = 1
	}
	// Map the direction to a child index: -1 -> 0, +1 -> 1.
	a := (c + 1) / 2
	var fix bool
	fix = tree.put(key, value, q, &q.Children[a])
	if fix {
		return putFix(int8(c), qp)
	}
	return false
}

// remove deletes the node with the given key from the subtree stored at *qp.
// It returns true when that subtree became shorter, so the caller must
// rebalance; a missing key returns false and leaves size unchanged.
func (tree *Tree[K, V]) remove(key K, qp **Node[K, V]) bool {
	q := *qp
	if q == nil {
		return false
	}

	c := tree.Comparator(key, q.Key)
	if c == 0 {
		tree.size--
		// No right child: splice the (possibly nil) left child into q's slot.
		if q.Children[1] == nil {
			if q.Children[0] != nil {
				q.Children[0].Parent = q.Parent
			}
			*qp = q.Children[0]
			return true
		}
		// Otherwise move the in-order successor's key/value into q and
		// detach the successor from the right subtree instead.
		fix := removeMin(&q.Children[1], &q.Key, &q.Value)
		if fix {
			// Right side shrank, so q now leans left (-1).
			return removeFix(-1, qp)
		}
		return false
	}

	if c < 0 {
		c = -1
	} else {
		c = 1
	}
	a := (c + 1) / 2
	fix := tree.remove(key, &q.Children[a])
	if fix {
		// Side c shrank, so q now leans toward the opposite side (-c).
		return removeFix(int8(-c), qp)
	}
	return false
}

// removeMin detaches the leftmost node of the subtree stored at *qp, copying
// its key and value into *minKey and *minVal. It returns true when the
// subtree became shorter.
func removeMin[K, V any](qp **Node[K, V], minKey *K, minVal *V) bool {
	q := *qp
	if q.Children[0] == nil {
		*minKey = q.Key
		*minVal = q.Value
		if q.Children[1] != nil {
			q.Children[1].Parent = q.Parent
		}
		*qp = q.Children[1]
		return true
	}
	fix := removeMin(&q.Children[0], minKey, minVal)
	if fix {
		// Left side shrank, so this node now leans right (+1).
		return removeFix(1, qp)
	}
	return false
}

// putFix rebalances the node stored at *t after its subtree on side c
// (-1 = left, +1 = right) grew by one level during an insertion.
// It returns true when the height of the subtree at *t increased.
func putFix[K, V any](c int8, t **Node[K, V]) bool {
	s := *t
	// Was balanced: now leans toward c and is one level taller.
	if s.b == 0 {
		s.b = c
		return true
	}

	// Leaned the other way: growth on side c restores balance, height unchanged.
	if s.b == -c {
		s.b = 0
		return false
	}

	// Already leaned toward c: now out of balance by two. A child leaning the
	// same way needs a single rotation; one leaning the other way needs a double.
	if s.Children[(c+1)/2].b == c {
		s = singlerot(c, s)
	} else {
		s = doublerot(c, s)
	}
	*t = s
	return false
}

// removeFix rebalances the node stored at *t after a deletion left side c
// (-1 = left, +1 = right) relatively taller by one level.
// It returns true when the height of the subtree at *t decreased.
func removeFix[K, V any](c int8, t **Node[K, V]) bool {
	s := *t
	// Was balanced: now leans toward c, overall height unchanged.
	if s.b == 0 {
		s.b = c
		return false
	}

	// Leaned the other way: now balanced, one level shorter.
	if s.b == -c {
		s.b = 0
		return true
	}

	a := (c + 1) / 2
	// Taller child is balanced: one rotation fixes s without changing height.
	// The old root keeps its balance c; the new root leans -c.
	if s.Children[a].b == 0 {
		s = rotate(c, s)
		s.b = -c
		*t = s
		return false
	}

	if s.Children[a].b == c {
		s = singlerot(c, s)
	} else {
		s = doublerot(c, s)
	}
	*t = s
	return true
}

// singlerot lifts the side-c child of s into s's position and marks both
// nodes involved as balanced. It returns the new subtree root.
func singlerot[K, V any](c int8, s *Node[K, V]) *Node[K, V] {
	s.b = 0
	s = rotate(c, s)
	s.b = 0
	return s
}

// doublerot performs a double rotation for the case where s leans toward c
// while its side-c child r leans toward -c. The grandchild (r's side -c child)
// becomes the new subtree root p. It returns p.
func doublerot[K, V any](c int8, s *Node[K, V]) *Node[K, V] {
	a := (c + 1) / 2
	r := s.Children[a]
	s.Children[a] = rotate(-c, s.Children[a])
	p := rotate(c, s)

	// rotate leaves balance factors untouched, so p.b is still the
	// grandchild's old balance; it determines the new balances of s and r.
	// (default is taken when p.b == 0, even though it is listed first.)
	switch {
	default:
		s.b = 0
		r.b = 0
	case p.b == c:
		s.b = -c
		r.b = 0
	case p.b == -c:
		s.b = 0
		r.b = c
	}

	p.b = 0
	return p
}

// rotate lifts s's side-c child r (index (c+1)/2) into s's position: r's
// inner child moves under s, and s becomes r's child on the opposite side.
// Parent pointers of the moved nodes are updated, but the slot that held s in
// its parent is not; callers store the returned r themselves. Balance factors
// are not touched.
func rotate[K, V any](c int8, s *Node[K, V]) *Node[K, V] {
	a := (c + 1) / 2
	r := s.Children[a]
	s.Children[a] = r.Children[a^1]
	if s.Children[a] != nil {
		s.Children[a].Parent = s
	}
	r.Children[a^1] = s
	r.Parent = s.Parent
	s.Parent = r
	return r
}

// bottom follows Children[d] from the root as far as possible and returns the
// last node reached: d == 0 yields the minimum, d == 1 the maximum. It returns
// nil for an empty tree.
func (tree *Tree[K, V]) bottom(d int) *Node[K, V] {
	n := tree.Root
	if n == nil {
		return nil
	}

	for c := n.Children[d]; c != nil; c = n.Children[d] {
		n = c
	}
	return n
}

// Prev returns the previous element in an inorder
// walk of the AVL tree, or nil if n is the first element (or nil).
418 func (n *Node[K, V]) Prev() *Node[K, V] { 419 return n.walk1(0) 420 } 421 422 // Next returns the next element in an inorder 423 // walk of the AVL tree. 424 func (n *Node[K, V]) Next() *Node[K, V] { 425 return n.walk1(1) 426 } 427 428 func (n *Node[K, V]) walk1(a int) *Node[K, V] { 429 if n == nil { 430 return nil 431 } 432 433 if n.Children[a] != nil { 434 n = n.Children[a] 435 for n.Children[a^1] != nil { 436 n = n.Children[a^1] 437 } 438 return n 439 } 440 441 p := n.Parent 442 for p != nil && p.Children[a] == n { 443 n = p 444 p = p.Parent 445 } 446 return p 447 } 448 449 func output[K, V any](node *Node[K, V], prefix string, isTail bool, builder *strings.Builder) { 450 if node.Children[1] != nil { 451 newPrefix := prefix 452 if isTail { 453 newPrefix += "│ " 454 } else { 455 newPrefix += " " 456 } 457 output(node.Children[1], newPrefix, false, builder) 458 } 459 builder.WriteString(prefix) 460 if isTail { 461 builder.WriteString("└── ") 462 } else { 463 builder.WriteString("┌── ") 464 } 465 builder.WriteString(node.String() + "\n") 466 if node.Children[0] != nil { 467 newPrefix := prefix 468 if isTail { 469 newPrefix += " " 470 } else { 471 newPrefix += "│ " 472 } 473 output(node.Children[0], newPrefix, true, builder) 474 } 475 } 476 477 // UnmarshalJSON @implements json.Unmarshaler 478 func (tree *Tree[K, V]) UnmarshalJSON(data []byte) error { 479 elements := make(map[string]V) 480 err := json.Unmarshal(data, &elements) 481 if err == nil { 482 tree.Clear() 483 for key, value := range elements { 484 var nk K 485 err = tree.Comparator.Unmarshal([]byte(key), &nk) 486 if err != nil { 487 return err 488 } 489 tree.Put(nk, value) 490 } 491 } 492 return err 493 } 494 495 // MarshalJSON @implements json.Marshaler 496 func (tree *Tree[K, V]) MarshalJSON() ([]byte, error) { 497 elements := make(map[string]V) 498 it := tree.Iterator() 499 for it.Next() { 500 k, err := tree.Comparator.Marshal(it.Key()) 501 if err != nil { 502 return nil, err 503 } 504 
elements[string(k)] = it.Value() 505 } 506 return json.Marshal(&elements) 507 }