github.com/qiuhoude/go-web@v0.0.0-20220223060959-ab545e78f20d/algorithm/datastructures/tree/avl/avl_tree.go (about) 1 package avl 2 3 import ( 4 "container/list" 5 "fmt" 6 "strings" 7 ) 8 9 // 图像化 https://www.cs.usfca.edu/~galles/visualization/Algorithms.html 10 // 11 12 /* 13 平衡二叉树: 14 对于任意一个节点,左子树和右子树的高度差不能超过1 15 16 平衡二叉树的高度和节点数量之间的关系也是 O(logn) 17 所以需要 记录高度 和 计算平衡因子 18 */ 19 20 // 节点 21 type avlNode struct { 22 left, right *avlNode 23 height int // 当前节点的高度 24 v interface{} 25 } 26 27 func (n *avlNode) String() string { 28 return fmt.Sprintf("h:%v, v:%+v", n.height, n.v) 29 } 30 31 func newAvlNode(v interface{}) *avlNode { 32 return &avlNode{ 33 v: v, 34 height: 1, 35 } 36 } 37 38 //获得节点node的高度 39 func getHeight(n *avlNode) int { 40 if n == nil { 41 return 0 42 } 43 return n.height 44 } 45 46 // 获得节点node的平衡因子,左子树高度 - 右子树高度 47 func getBalanceFactor(n *avlNode) int { 48 if n == nil { 49 return 0 50 } 51 return getHeight(n.left) - getHeight(n.right) 52 } 53 54 // avl树 55 type AVLTree struct { 56 root *avlNode 57 size int 58 compareFunc CompareFunc 59 } 60 61 // 比较函数,v表示要操作的的值 62 type CompareFunc func(v, nodeV interface{}) int 63 64 func NewAVLTree(cfunc CompareFunc) *AVLTree { 65 return &AVLTree{ 66 compareFunc: cfunc, 67 } 68 } 69 70 func (t *AVLTree) InOrder(f func(interface{})) { 71 inOrder(t.root, f) 72 } 73 74 func inOrder(n *avlNode, f func(interface{})) { 75 if n == nil { 76 return 77 } 78 inOrder(n.left, f) 79 f(n.v) 80 inOrder(n.right, f) 81 } 82 83 func (t *AVLTree) Add(v interface{}) { 84 t.root, _ = t.add(t.root, v) 85 } 86 87 func (t *AVLTree) add(n *avlNode, v interface{}) (*avlNode, bool) { 88 if n == nil { 89 t.size++ 90 return newAvlNode(v), true 91 } 92 cmp := t.compareFunc(v, n.v) 93 var addSuc bool 94 if cmp == 0 { 95 return n, false 96 } else if cmp > 0 { 97 n.right, addSuc = t.add(n.right, v) 98 } else { //cmp < 0 99 n.left, addSuc = t.add(n.left, v) 100 } 101 102 // 更新height 103 n.height = max(getHeight(n.left), getHeight(n.right)) + 1 104 105 // 计算平衡因子 106 // balanceFactor > 0 说明 左边高 右边低 107 // balanceFactor < 0 说明 左边低 右边高 108 balanceFactor := getBalanceFactor(n) 109 110 //if abs(balanceFactor) > 1 { // 平衡因子大1 需要处理 111 // fmt.Println("unbalanced:", balanceFactor) 112 //} 113 114 // 维护平衡操作 115 if balanceFactor > 1 && getBalanceFactor(n.left) >= 0 { //LL 116 // LL 型,在父节点的左孩子的左子树添加了新节点,导致根节点的平衡因子变为 +2,二叉树失去平衡 117 // 右旋一次即可 118 n = t.rightRotate(n) 119 } else if balanceFactor < -1 && getBalanceFactor(n.right) <= 0 { //RR 120 // 121 // RR 型 同理 122 n = t.leftRotate(n) 123 } else if balanceFactor > 1 && getBalanceFactor(n.left) < 0 { //LR 124 //LR 就是将新的节点插入到了 n 的左孩子的右子树上导致的不平衡的情况。 125 // 这时我们需要的是先对 左孩子进行一次左旋再对 自己 进行一次右旋 126 n.left = t.leftRotate(n.left) 127 n = t.rightRotate(n) 128 129 } else if balanceFactor < -1 && getBalanceFactor(n.right) > 0 { //RL 130 n.right = t.rightRotate(n.right) 131 n = t.leftRotate(n) 132 } 133 134 return n, addSuc 135 } 136 137 // 计算数高 138 func treeHeight(n *avlNode) int { 139 if n == nil { 140 return 0 141 } 142 return max(treeHeight(n.left), treeHeight(n.right)) + 1 143 } 144 145 // 对节点y进行向右旋转操作,返回旋转后新的根节点x 146 // y x 147 // / \ / \ 148 // x T4 向右旋转 (y) z y 149 // / \ - - - - - - - -> / \ / \ 150 // z T3 T1 T2 T3 T4 151 // / \ 152 // T1 T2 153 func (t *AVLTree) rightRotate(y *avlNode) *avlNode { 154 if y == nil { 155 return nil 156 } 157 x := y.left 158 t3 := x.right 159 160 x.right = y 161 y.left = t3 162 163 // 更新height 164 y.height = max(getHeight(y.left), getHeight(y.right)) + 1 165 x.height = max(getHeight(x.left), getHeight(x.right)) + 1 166 return x 167 } 168 169 // 对节点y进行向左旋转操作,返回旋转后新的根节点x 170 // y x 171 // / \ / \ 172 // T1 x 向左旋转 (y) y z 173 // / \ - - - - - - - -> / \ / \ 174 // T2 z T1 T2 T3 T4 175 // / \ 176 // T3 T4 177 func (t *AVLTree) leftRotate(y *avlNode) *avlNode { 178 if y == nil { 179 return nil 180 } 181 x := y.right 182 t2 := x.left 183 184 x.left = y 185 y.right = t2 186 187 // 更新height 188 y.height = max(getHeight(y.left), getHeight(y.right)) + 1 189 x.height = max(getHeight(x.left), getHeight(x.right)) + 1 190 return x 191 } 192 193 func (t *AVLTree) findNode(n *avlNode, v interface{}) *avlNode { 194 if n == nil { 195 return nil 196 } 197 // 递归方式查找 198 cmp := t.compareFunc(v, n.v) 199 if cmp > 0 { 200 return t.findNode(n.right, v) 201 } else if cmp < 0 { 202 return t.findNode(n.left, v) 203 } else { // == 204 return n 205 } 206 207 /* 208 //非递归方式 209 findN := n 210 for findN != nil { 211 cmp := t.compareFunc(v, findN.v) 212 if cmp > 0 { 213 findN = n.right 214 } else if cmp < 0 { 215 findN = n.left 216 } else { 217 break 218 } 219 } 220 return findN 221 */ 222 } 223 224 func (t *AVLTree) Contains(v interface{}) bool { 225 return t.findNode(t.root, v) != nil 226 } 227 228 func (t *AVLTree) Remove(v interface{}) bool { 229 if !t.Contains(v) { // 不存在删除失败 230 return false 231 } 232 n := t.remove(t.root, v) 233 t.root = n 234 return true 235 } 236 237 // 移除对应元素,返回节点将挂载到调用者的子节点上 238 func (t *AVLTree) remove(n *avlNode, v interface{}) *avlNode { 239 if n == nil { 240 return nil 241 } 242 cmp := t.compareFunc(v, n.v) 243 var retNode *avlNode // 返回父节点 244 if cmp < 0 { 245 n.left = t.remove(n.left, v) 246 retNode = n 247 } else if cmp > 0 { 248 n.right = t.remove(n.right, v) 249 retNode = n 250 } else { // == 251 if n.left == nil && n.right == nil { // 待删除节点左右都为空 252 // 返回nil 给他的父节点 253 retNode = nil 254 } else if n.left == nil { // 1. 待删除节点左子树为空的情况 255 // 将右子树的数据反给上层 256 tn := n.right 257 n.left = nil 258 t.size-- 259 retNode = tn 260 } else if n.right == nil { // 2. 待删除节点右子树为空的情况 261 tn := n.left 262 n.left = nil 263 t.size-- 264 retNode = tn 265 } else { // 3. 左右都有数据 266 // 找到比待删除节点大的最小节点, 即待删除节点右子树的最小节点 267 // 用这个节点顶替待删除节点的位置 268 successor := t.minimum(n.right) 269 //removeMin中已经 size--,外面不需要 -- 270 successor.right = t.removeMin(n.right) 271 successor.left = n.left 272 n.left = nil 273 n.right = nil 274 retNode = successor 275 } 276 } 277 if retNode == nil { 278 return nil 279 } 280 281 // 维护更新height 282 retNode.height = max(getHeight(retNode.left), getHeight(retNode.right)) + 1 283 284 // 计算平衡因子 285 balanceFactor := getBalanceFactor(retNode) 286 287 // 维护平衡操作 288 if balanceFactor > 1 && getBalanceFactor(retNode.left) >= 0 { //LL 289 retNode = t.rightRotate(retNode) 290 } else if balanceFactor < -1 && getBalanceFactor(retNode.right) <= 0 { //RR 291 retNode = t.leftRotate(retNode) 292 } else if balanceFactor > 1 && getBalanceFactor(retNode.left) < 0 { //LR 293 retNode.left = t.leftRotate(retNode.left) 294 retNode = t.rightRotate(retNode) 295 } else if balanceFactor < -1 && getBalanceFactor(retNode.right) > 0 { //RL 296 retNode.right = t.rightRotate(retNode.right) 297 retNode = t.leftRotate(retNode) 298 } 299 return retNode 300 301 } 302 303 // 返回以node为根的二分搜索树的最小值所在的节点 304 func (t *AVLTree) minimum(n *avlNode) *avlNode { 305 if n.left == nil { 306 return n 307 } 308 return t.minimum(n.left) 309 } 310 311 // 删除掉以node为根的二分搜索树中的最小节点 312 // 返回删除节点后新的二分搜索树的根 和 是否删除成功 313 func (t *AVLTree) removeMin(n *avlNode) *avlNode { 314 if n.left == nil { 315 // 将要删除的右节点挂载父节点上,通过返回值返给服节点 316 tn := n.right 317 n.right = nil // 置空 gc回收 318 t.size-- 319 return tn 320 } 321 n.left = t.removeMin(n.left) 322 return n 323 } 324 325 func (t *AVLTree) String() string { 326 var sb strings.Builder 327 generateBSTLevelString(t.root, &sb) 328 return sb.String() 329 } 330 331 func generateBSTLevelString(n *avlNode, sb *strings.Builder) { 332 l := list.New() 333 l.PushBack(n) 334 335 curNodeCnt := 1 // 当前行的节点 336 nextNodeCnt := 0 // 下一行的节点 337 for l.Len() != 0 { 338 n, _ := l.Remove(l.Front()).(*avlNode) 339 340 sb.WriteString(fmt.Sprintf("%v ", n)) 341 curNodeCnt-- 342 if n.left != nil { 343 l.PushBack(n.left) 344 nextNodeCnt++ 345 } 346 if n.right != nil { 347 l.PushBack(n.right) 348 nextNodeCnt++ 349 } 350 if curNodeCnt == 0 { // 当前行打印完了 351 sb.WriteRune('\n') 352 curNodeCnt = nextNodeCnt 353 nextNodeCnt = 0 354 } 355 } 356 }