go.charczuk.com@v0.0.0-20240327042549-bc490516bd1a/sdk/collections/binary_search_tree.go (about) 1 /* 2 3 Copyright (c) 2023 - Present. Will Charczuk. All rights reserved. 4 Use of this source code is governed by a MIT license that can be found in the LICENSE file at the root of the repository. 5 6 */ 7 8 package collections 9 10 import "cmp" 11 12 // BinarySearchTree is a AVL balanced tree which holds the properties 13 // that nodes are ordered left to right. 14 // 15 // The choice to use AVL to balance the tree means the use cases skew 16 // towards fast lookups at the expense of more costly mutations. 17 type BinarySearchTree[K cmp.Ordered, V any] struct { 18 root *BinarySearchTreeNode[K, V] 19 } 20 21 // Insert adds a new value to the binary search tree. 22 func (bst *BinarySearchTree[K, V]) Insert(k K, v V) { 23 bst.root = bst._insert(bst.root, k, v) 24 } 25 26 // Delete deletes a value from the tree, and returns if it existed. 27 func (bst *BinarySearchTree[K, V]) Delete(k K) { 28 bst.root = bst._delete(bst.root, k) 29 } 30 31 // Search searches for a node with a given key, returning the value 32 // and a boolean indicating the key was found. 33 func (bst *BinarySearchTree[K, V]) Search(k K) (v V, ok bool) { 34 v, ok = bst._search(bst.root, k) 35 return 36 } 37 38 // Min returns the minimum key and value. 39 func (bst *BinarySearchTree[K, V]) Min() (k K, v V, ok bool) { 40 if bst.root == nil { 41 return 42 } 43 k, v, ok = bst.root.Key, bst.root.Value, true 44 current := bst.root 45 for current.Left != nil { 46 current = current.Left 47 k, v = current.Key, current.Value 48 } 49 return 50 } 51 52 // Max returns the maximum key and value. 53 func (bst *BinarySearchTree[K, V]) Max() (k K, v V, ok bool) { 54 if bst.root == nil { 55 return 56 } 57 k, v, ok = bst.root.Key, bst.root.Value, true 58 current := bst.root 59 for current.Right != nil { 60 current = current.Right 61 k, v = current.Key, current.Value 62 } 63 return 64 } 65 66 // InOrder traversal returns the sorted values in the tree. 67 func (bst *BinarySearchTree[K, V]) InOrder(fn func(K, V)) { 68 bst._inOrder(bst.root, fn) 69 } 70 71 // PreOrder traversal returns the values in the tree in pre-order. 72 func (bst *BinarySearchTree[K, V]) PreOrder(fn func(K, V)) { 73 bst._preOrder(bst.root, fn) 74 } 75 76 // PostOrder traversal returns the values in the tree in post-order. 77 func (bst *BinarySearchTree[K, V]) PostOrder(fn func(K, V)) { 78 bst._postOrder(bst.root, fn) 79 } 80 81 // KeysEqual is a function that can be used to deeply compare two trees based on their keys. 82 // 83 // Values are _not_ considered because values are not comparable by design. 84 func (bst *BinarySearchTree[K, V]) KeysEqual(other *BinarySearchTree[K, V]) bool { 85 return bst._keysEqual(bst.root, other.root) 86 } 87 88 // 89 // internal methods 90 // 91 92 func (bst *BinarySearchTree[K, V]) _height(n *BinarySearchTreeNode[K, V]) int { 93 if n == nil { 94 return 0 95 } 96 return n.Height 97 } 98 99 func (bst *BinarySearchTree[K, V]) _inOrder(n *BinarySearchTreeNode[K, V], fn func(K, V)) { 100 if n == nil { 101 return 102 } 103 bst._inOrder(n.Left, fn) 104 fn(n.Key, n.Value) 105 bst._inOrder(n.Right, fn) 106 } 107 108 func (bst *BinarySearchTree[K, V]) _preOrder(n *BinarySearchTreeNode[K, V], fn func(K, V)) { 109 if n == nil { 110 return 111 } 112 fn(n.Key, n.Value) 113 bst._preOrder(n.Left, fn) 114 bst._preOrder(n.Right, fn) 115 } 116 117 func (bst *BinarySearchTree[K, V]) _postOrder(n *BinarySearchTreeNode[K, V], fn func(K, V)) { 118 if n == nil { 119 return 120 } 121 bst._postOrder(n.Left, fn) 122 bst._postOrder(n.Right, fn) 123 fn(n.Key, n.Value) 124 } 125 126 func (bst *BinarySearchTree[K, V]) _insert(n *BinarySearchTreeNode[K, V], k K, v V) *BinarySearchTreeNode[K, V] { 127 if n == nil { 128 return &BinarySearchTreeNode[K, V]{ 129 Key: k, 130 Value: v, 131 Height: 1, 132 } 133 } 134 135 if k < n.Key { 136 n.Left = bst._insert(n.Left, k, v) 137 } else if k > n.Key { 138 n.Right = bst._insert(n.Right, k, v) 139 } else { 140 n.Value = v 141 return n 142 } 143 144 n.Height = max(bst._height(n.Left), bst._height(n.Right)) + 1 145 146 balanceFactor := bst._getBalanceFactor(n) 147 if balanceFactor > 1 && k < n.Left.Key { 148 return bst._rotateRight(n) 149 } 150 if balanceFactor < -1 && k > n.Right.Key { 151 return bst._rotateLeft(n) 152 } 153 if balanceFactor > 1 && k > n.Left.Key { 154 n.Left = bst._rotateLeft(n.Left) 155 return bst._rotateRight(n) 156 } 157 if balanceFactor < -1 && k < n.Right.Key { 158 n.Right = bst._rotateRight(n.Right) 159 return bst._rotateLeft(n) 160 } 161 return n 162 } 163 164 func (bst *BinarySearchTree[K, V]) _delete(n *BinarySearchTreeNode[K, V], k K) *BinarySearchTreeNode[K, V] { 165 if n == nil { 166 return nil 167 } 168 169 if k < n.Key { 170 n.Left = bst._delete(n.Left, k) 171 } else if k > n.Key { 172 n.Right = bst._delete(n.Right, k) 173 } else { 174 if n.Left == nil || n.Right == nil { 175 var temp *BinarySearchTreeNode[K, V] 176 if n.Left == nil { 177 temp = n.Right 178 } else { 179 temp = n.Left 180 } 181 if temp == nil { 182 n = nil 183 } else { 184 n = temp 185 } 186 } else { 187 temp := bst._searchMin(n.Right) 188 n.Key, n.Value = temp.Key, temp.Value 189 n.Right = bst._delete(n.Right, temp.Key) 190 } 191 } 192 193 if n == nil { 194 return nil 195 } 196 197 n.Height = max(bst._height(n.Left), bst._height(n.Right)) + 1 198 199 balanceFactor := bst._getBalanceFactor(n) 200 if balanceFactor > 1 && bst._getBalanceFactor(n.Left) >= 0 { 201 return bst._rotateRight(n) 202 } 203 if balanceFactor > 1 && bst._getBalanceFactor(n.Left) < 0 { 204 n.Left = bst._rotateLeft(n.Left) 205 return bst._rotateRight(n) 206 } 207 208 if balanceFactor < -1 && bst._getBalanceFactor(n.Right) <= 0 { 209 return bst._rotateLeft(n) 210 } 211 212 if balanceFactor < -1 && bst._getBalanceFactor(n.Right) > 0 { 213 n.Right = bst._rotateRight(n.Right) 214 return bst._rotateLeft(n) 215 } 216 return n 217 } 218 219 func (bst *BinarySearchTree[K, V]) _searchMin(n *BinarySearchTreeNode[K, V]) (min *BinarySearchTreeNode[K, V]) { 220 min = n 221 for min.Left != nil { 222 min = min.Left 223 } 224 return 225 } 226 227 func (bst *BinarySearchTree[K, V]) _search(n *BinarySearchTreeNode[K, V], k K) (v V, ok bool) { 228 if n == nil { 229 return 230 } 231 if n.Key == k { 232 v = n.Value 233 ok = true 234 return 235 } 236 if k < n.Key { 237 v, ok = bst._search(n.Left, k) 238 return 239 } 240 v, ok = bst._search(n.Right, k) 241 return 242 } 243 244 func (bst *BinarySearchTree[K, V]) _rotateRight(y *BinarySearchTreeNode[K, V]) *BinarySearchTreeNode[K, V] { 245 if y.Left == nil { 246 return y 247 } 248 x := y.Left 249 t2 := x.Right 250 x.Right = y 251 y.Left = t2 252 y.Height = max(bst._height(y.Left), bst._height(y.Right)) + 1 253 x.Height = max(bst._height(x.Left), bst._height(x.Right)) + 1 254 return x 255 } 256 257 func (bst *BinarySearchTree[K, V]) _rotateLeft(x *BinarySearchTreeNode[K, V]) *BinarySearchTreeNode[K, V] { 258 if x.Right == nil { 259 return x 260 } 261 262 y := x.Right 263 t2 := y.Left 264 y.Left = x 265 x.Right = t2 266 x.Height = max(bst._height(x.Left), bst._height(x.Right)) + 1 267 y.Height = max(bst._height(y.Left), bst._height(y.Right)) + 1 268 return y 269 } 270 271 func (bst *BinarySearchTree[K, V]) _getBalanceFactor(n *BinarySearchTreeNode[K, V]) int { 272 if n == nil { 273 return 0 274 } 275 return bst._height(n.Left) - bst._height(n.Right) 276 } 277 278 func (bst *BinarySearchTree[K, V]) _keysEqual(a, b *BinarySearchTreeNode[K, V]) bool { 279 if a == nil && b == nil { 280 return true 281 } 282 if a != nil && b == nil { 283 return false 284 } 285 if a == nil && b != nil { 286 return false 287 } 288 if a.Key != b.Key { 289 return false 290 } 291 if a.Height != b.Height { 292 return false 293 } 294 return bst._keysEqual(a.Left, b.Left) && bst._keysEqual(a.Right, b.Right) 295 } 296 297 // BinarySearchTreeNode is a node in a BinarySearchTree. 298 type BinarySearchTreeNode[K cmp.Ordered, V any] struct { 299 Key K 300 Value V 301 Left *BinarySearchTreeNode[K, V] 302 Right *BinarySearchTreeNode[K, V] 303 Height int 304 } 305 306 func max[K cmp.Ordered](keys ...K) (k K) { 307 if len(keys) == 0 { 308 return 309 } 310 k = keys[0] 311 for x := 1; x < len(keys); x++ { 312 if keys[x] > k { 313 k = keys[x] 314 } 315 } 316 return 317 }