github.com/qiuhoude/go-web@v0.0.0-20220223060959-ab545e78f20d/algorithm/datastructures/tree/avl/avl_tree.go (about)

     1  package avl
     2  
     3  import (
     4  	"container/list"
     5  	"fmt"
     6  	"strings"
     7  )
     8  
     9  // 图像化 https://www.cs.usfca.edu/~galles/visualization/Algorithms.html
    10  //
    11  
    12  /*
    13  平衡二叉树:
    14  对于任意一个节点,左子树和右子树的高度差不能超过1
    15  
    16  平衡二叉树的高度和节点数量之间的关系也是 O(logn)
    17  所以需要 记录高度 和 计算平衡因子
    18  */
    19  
    20  // 节点
    21  type avlNode struct {
    22  	left, right *avlNode
    23  	height      int // 当前节点的高度
    24  	v           interface{}
    25  }
    26  
    27  func (n *avlNode) String() string {
    28  	return fmt.Sprintf("h:%v, v:%+v", n.height, n.v)
    29  }
    30  
    31  func newAvlNode(v interface{}) *avlNode {
    32  	return &avlNode{
    33  		v:      v,
    34  		height: 1,
    35  	}
    36  }
    37  
    38  //获得节点node的高度
    39  func getHeight(n *avlNode) int {
    40  	if n == nil {
    41  		return 0
    42  	}
    43  	return n.height
    44  }
    45  
    46  // 获得节点node的平衡因子,左子树高度 - 右子树高度
    47  func getBalanceFactor(n *avlNode) int {
    48  	if n == nil {
    49  		return 0
    50  	}
    51  	return getHeight(n.left) - getHeight(n.right)
    52  }
    53  
    54  // avl树
    55  type AVLTree struct {
    56  	root        *avlNode
    57  	size        int
    58  	compareFunc CompareFunc
    59  }
    60  
    61  // 比较函数,v表示要操作的的值
    62  type CompareFunc func(v, nodeV interface{}) int
    63  
    64  func NewAVLTree(cfunc CompareFunc) *AVLTree {
    65  	return &AVLTree{
    66  		compareFunc: cfunc,
    67  	}
    68  }
    69  
    70  func (t *AVLTree) InOrder(f func(interface{})) {
    71  	inOrder(t.root, f)
    72  }
    73  
    74  func inOrder(n *avlNode, f func(interface{})) {
    75  	if n == nil {
    76  		return
    77  	}
    78  	inOrder(n.left, f)
    79  	f(n.v)
    80  	inOrder(n.right, f)
    81  }
    82  
    83  func (t *AVLTree) Add(v interface{}) {
    84  	t.root, _ = t.add(t.root, v)
    85  }
    86  
    87  func (t *AVLTree) add(n *avlNode, v interface{}) (*avlNode, bool) {
    88  	if n == nil {
    89  		t.size++
    90  		return newAvlNode(v), true
    91  	}
    92  	cmp := t.compareFunc(v, n.v)
    93  	var addSuc bool
    94  	if cmp == 0 {
    95  		return n, false
    96  	} else if cmp > 0 {
    97  		n.right, addSuc = t.add(n.right, v)
    98  	} else { //cmp < 0
    99  		n.left, addSuc = t.add(n.left, v)
   100  	}
   101  
   102  	// 更新height
   103  	n.height = max(getHeight(n.left), getHeight(n.right)) + 1
   104  
   105  	// 计算平衡因子
   106  	// balanceFactor > 0 说明 左边高 右边低
   107  	// balanceFactor < 0 说明 左边低 右边高
   108  	balanceFactor := getBalanceFactor(n)
   109  
   110  	//if abs(balanceFactor) > 1 { // 平衡因子大1 需要处理
   111  	//	fmt.Println("unbalanced:", balanceFactor)
   112  	//}
   113  
   114  	// 维护平衡操作
   115  	if balanceFactor > 1 && getBalanceFactor(n.left) >= 0 { //LL
   116  		// LL 型,在父节点的左孩子的左子树添加了新节点,导致根节点的平衡因子变为 +2,二叉树失去平衡
   117  		// 右旋一次即可
   118  		n = t.rightRotate(n)
   119  	} else if balanceFactor < -1 && getBalanceFactor(n.right) <= 0 { //RR
   120  		//
   121  		// RR 型 同理
   122  		n = t.leftRotate(n)
   123  	} else if balanceFactor > 1 && getBalanceFactor(n.left) < 0 { //LR
   124  		//LR 就是将新的节点插入到了 n 的左孩子的右子树上导致的不平衡的情况。
   125  		// 这时我们需要的是先对 左孩子进行一次左旋再对 自己 进行一次右旋
   126  		n.left = t.leftRotate(n.left)
   127  		n = t.rightRotate(n)
   128  
   129  	} else if balanceFactor < -1 && getBalanceFactor(n.right) > 0 { //RL
   130  		n.right = t.rightRotate(n.right)
   131  		n = t.leftRotate(n)
   132  	}
   133  
   134  	return n, addSuc
   135  }
   136  
   137  // 计算数高
   138  func treeHeight(n *avlNode) int {
   139  	if n == nil {
   140  		return 0
   141  	}
   142  	return max(treeHeight(n.left), treeHeight(n.right)) + 1
   143  }
   144  
   145  // 对节点y进行向右旋转操作,返回旋转后新的根节点x
   146  //        y                              x
   147  //       / \                           /   \
   148  //      x   T4     向右旋转 (y)        z     y
   149  //     / \       - - - - - - - ->    / \   / \
   150  //    z   T3                       T1  T2 T3 T4
   151  //   / \
   152  // T1   T2
   153  func (t *AVLTree) rightRotate(y *avlNode) *avlNode {
   154  	if y == nil {
   155  		return nil
   156  	}
   157  	x := y.left
   158  	t3 := x.right
   159  
   160  	x.right = y
   161  	y.left = t3
   162  
   163  	// 更新height
   164  	y.height = max(getHeight(y.left), getHeight(y.right)) + 1
   165  	x.height = max(getHeight(x.left), getHeight(x.right)) + 1
   166  	return x
   167  }
   168  
   169  // 对节点y进行向左旋转操作,返回旋转后新的根节点x
   170  //    y                             x
   171  //  /  \                          /   \
   172  // T1   x      向左旋转 (y)       y     z
   173  //     / \   - - - - - - - ->   / \   / \
   174  //   T2  z                     T1 T2 T3 T4
   175  //      / \
   176  //     T3 T4
   177  func (t *AVLTree) leftRotate(y *avlNode) *avlNode {
   178  	if y == nil {
   179  		return nil
   180  	}
   181  	x := y.right
   182  	t2 := x.left
   183  
   184  	x.left = y
   185  	y.right = t2
   186  
   187  	// 更新height
   188  	y.height = max(getHeight(y.left), getHeight(y.right)) + 1
   189  	x.height = max(getHeight(x.left), getHeight(x.right)) + 1
   190  	return x
   191  }
   192  
   193  func (t *AVLTree) findNode(n *avlNode, v interface{}) *avlNode {
   194  	if n == nil {
   195  		return nil
   196  	}
   197  	// 递归方式查找
   198  	cmp := t.compareFunc(v, n.v)
   199  	if cmp > 0 {
   200  		return t.findNode(n.right, v)
   201  	} else if cmp < 0 {
   202  		return t.findNode(n.left, v)
   203  	} else { // ==
   204  		return n
   205  	}
   206  
   207  	/*
   208  		//非递归方式
   209  		findN := n
   210  		for findN != nil {
   211  			cmp := t.compareFunc(v, findN.v)
   212  			if cmp > 0 {
   213  				findN = n.right
   214  			} else if cmp < 0 {
   215  				findN = n.left
   216  			} else {
   217  				break
   218  			}
   219  		}
   220  		return findN
   221  	*/
   222  }
   223  
   224  func (t *AVLTree) Contains(v interface{}) bool {
   225  	return t.findNode(t.root, v) != nil
   226  }
   227  
   228  func (t *AVLTree) Remove(v interface{}) bool {
   229  	if !t.Contains(v) { // 不存在删除失败
   230  		return false
   231  	}
   232  	n := t.remove(t.root, v)
   233  	t.root = n
   234  	return true
   235  }
   236  
   237  // 移除对应元素,返回节点将挂载到调用者的子节点上
   238  func (t *AVLTree) remove(n *avlNode, v interface{}) *avlNode {
   239  	if n == nil {
   240  		return nil
   241  	}
   242  	cmp := t.compareFunc(v, n.v)
   243  	var retNode *avlNode // 返回父节点
   244  	if cmp < 0 {
   245  		n.left = t.remove(n.left, v)
   246  		retNode = n
   247  	} else if cmp > 0 {
   248  		n.right = t.remove(n.right, v)
   249  		retNode = n
   250  	} else { // ==
   251  		if n.left == nil && n.right == nil { // 待删除节点左右都为空
   252  			// 返回nil 给他的父节点
   253  			retNode = nil
   254  		} else if n.left == nil { // 1. 待删除节点左子树为空的情况
   255  			// 将右子树的数据反给上层
   256  			tn := n.right
   257  			n.left = nil
   258  			t.size--
   259  			retNode = tn
   260  		} else if n.right == nil { // 2. 待删除节点右子树为空的情况
   261  			tn := n.left
   262  			n.left = nil
   263  			t.size--
   264  			retNode = tn
   265  		} else { // 3. 左右都有数据
   266  			// 找到比待删除节点大的最小节点, 即待删除节点右子树的最小节点
   267  			// 用这个节点顶替待删除节点的位置
   268  			successor := t.minimum(n.right)
   269  			//removeMin中已经 size--,外面不需要 --
   270  			successor.right = t.removeMin(n.right)
   271  			successor.left = n.left
   272  			n.left = nil
   273  			n.right = nil
   274  			retNode = successor
   275  		}
   276  	}
   277  	if retNode == nil {
   278  		return nil
   279  	}
   280  
   281  	// 维护更新height
   282  	retNode.height = max(getHeight(retNode.left), getHeight(retNode.right)) + 1
   283  
   284  	// 计算平衡因子
   285  	balanceFactor := getBalanceFactor(retNode)
   286  
   287  	// 维护平衡操作
   288  	if balanceFactor > 1 && getBalanceFactor(retNode.left) >= 0 { //LL
   289  		retNode = t.rightRotate(retNode)
   290  	} else if balanceFactor < -1 && getBalanceFactor(retNode.right) <= 0 { //RR
   291  		retNode = t.leftRotate(retNode)
   292  	} else if balanceFactor > 1 && getBalanceFactor(retNode.left) < 0 { //LR
   293  		retNode.left = t.leftRotate(retNode.left)
   294  		retNode = t.rightRotate(retNode)
   295  	} else if balanceFactor < -1 && getBalanceFactor(retNode.right) > 0 { //RL
   296  		retNode.right = t.rightRotate(retNode.right)
   297  		retNode = t.leftRotate(retNode)
   298  	}
   299  	return retNode
   300  
   301  }
   302  
   303  //  返回以node为根的二分搜索树的最小值所在的节点
   304  func (t *AVLTree) minimum(n *avlNode) *avlNode {
   305  	if n.left == nil {
   306  		return n
   307  	}
   308  	return t.minimum(n.left)
   309  }
   310  
   311  // 删除掉以node为根的二分搜索树中的最小节点
   312  // 返回删除节点后新的二分搜索树的根 和 是否删除成功
   313  func (t *AVLTree) removeMin(n *avlNode) *avlNode {
   314  	if n.left == nil {
   315  		// 将要删除的右节点挂载父节点上,通过返回值返给服节点
   316  		tn := n.right
   317  		n.right = nil // 置空 gc回收
   318  		t.size--
   319  		return tn
   320  	}
   321  	n.left = t.removeMin(n.left)
   322  	return n
   323  }
   324  
   325  func (t *AVLTree) String() string {
   326  	var sb strings.Builder
   327  	generateBSTLevelString(t.root, &sb)
   328  	return sb.String()
   329  }
   330  
   331  func generateBSTLevelString(n *avlNode, sb *strings.Builder) {
   332  	l := list.New()
   333  	l.PushBack(n)
   334  
   335  	curNodeCnt := 1  // 当前行的节点
   336  	nextNodeCnt := 0 // 下一行的节点
   337  	for l.Len() != 0 {
   338  		n, _ := l.Remove(l.Front()).(*avlNode)
   339  
   340  		sb.WriteString(fmt.Sprintf("%v ", n))
   341  		curNodeCnt--
   342  		if n.left != nil {
   343  			l.PushBack(n.left)
   344  			nextNodeCnt++
   345  		}
   346  		if n.right != nil {
   347  			l.PushBack(n.right)
   348  			nextNodeCnt++
   349  		}
   350  		if curNodeCnt == 0 { // 当前行打印完了
   351  			sb.WriteRune('\n')
   352  			curNodeCnt = nextNodeCnt
   353  			nextNodeCnt = 0
   354  		}
   355  	}
   356  }