github.com/chenjiandongx/go-queue@v0.0.0-20191023082232-e2a36f382f84/avl_tree.go (about)

     1  package collections
     2  
     3  type avlNode struct {
     4  	h     int
     5  	value int
     6  	left  *avlNode
     7  	right *avlNode
     8  }
     9  
    10  type AVLTree struct {
    11  	tree *avlNode
    12  }
    13  
    14  // 生成 AVL 树
    15  func NewAVLTree() *AVLTree {
    16  	return &AVLTree{&avlNode{h: -2}}
    17  }
    18  
    19  // 插入节点
    20  func (a *AVLTree) Insert(v int) {
    21  	a.tree = insert(v, a.tree)
    22  }
    23  
    24  // 搜索节点
    25  func (a *AVLTree) Search(v int) bool {
    26  	return a.tree.search(v)
    27  }
    28  
    29  // 删除节点
    30  func (a *AVLTree) Delete(v int) bool {
    31  	if a.tree.search(v) {
    32  		a.tree.delete(v)
    33  		return true
    34  	}
    35  	return false
    36  }
    37  
    38  // 获取所有节点中的最大值
    39  func (a *AVLTree) GetMaxValue() int {
    40  	return a.tree.maxNode().value
    41  }
    42  
    43  // 获取所有节点中的最小值
    44  func (a *AVLTree) GetMinValue() int {
    45  	return a.tree.minNode().value
    46  }
    47  
    48  // 返回排序后所有值
    49  func (a *AVLTree) AllValues() []int {
    50  	return a.tree.values()
    51  }
    52  
    53  func max(a, b int) int {
    54  	if a > b {
    55  		return a
    56  	}
    57  	return b
    58  }
    59  
    60  func insert(v int, t *avlNode) *avlNode {
    61  	if t == nil {
    62  		return &avlNode{value: v}
    63  	}
    64  	if t.h == -2 {
    65  		t.value = v
    66  		t.h = 0
    67  		return t
    68  	}
    69  
    70  	cmp := v - t.value
    71  	if cmp > 0 {
    72  		// 将节点插入到右子树中
    73  		t.right = insert(v, t.right)
    74  	} else if cmp < 0 {
    75  		// 将节点插入到左子树中
    76  		t.left = insert(v, t.left)
    77  	}
    78  	// 维持树平衡
    79  	t = t.keepBalance(v)
    80  	t.h = max(t.left.height(), t.right.height()) + 1
    81  	return t
    82  }
    83  
    84  func (t *avlNode) search(v int) bool {
    85  	if t == nil {
    86  		return false
    87  	}
    88  	cmp := v - t.value
    89  	if cmp > 0 {
    90  		// 如果 v 大于当前节点值,继续从右子树中寻找
    91  		return t.right.search(v)
    92  	} else if cmp < 0 {
    93  		// 如果 v 小于当前节点值,继续从左子树中寻找
    94  		return t.left.search(v)
    95  	} else {
    96  		// 相等则表示找到
    97  		return true
    98  	}
    99  }
   100  
   101  func (t *avlNode) delete(v int) *avlNode {
   102  	if t == nil {
   103  		return t
   104  	}
   105  	cmp := v - t.value
   106  	if cmp > 0 {
   107  		// 如果 v 大于当前节点值,继续从右子树中删除
   108  		t.right = t.right.delete(v)
   109  	} else if cmp < 0 {
   110  		// 如果 v 小于当前节点值,继续从左子树中删除
   111  		t.left = t.left.delete(v)
   112  	} else {
   113  		// 找到 v
   114  		if t.left != nil && t.right != nil {
   115  			// 如果该节点既有左子树又有右子树
   116  			// 使用右子树中的最小节点取代删除节点,然后删除右子树中的最小节点
   117  			t.value = t.right.minNode().value
   118  			t.right = t.right.delete(t.value)
   119  		} else if t.left != nil {
   120  			// 如果只有左子树,则直接删除节点
   121  			t = t.left
   122  		} else {
   123  			// 只有右子树或空树
   124  			t = t.right
   125  		}
   126  	}
   127  
   128  	if t != nil {
   129  		t.h = max(t.left.height(), t.right.height()) + 1
   130  		t = t.keepBalance(v)
   131  	}
   132  	return t
   133  }
   134  
   135  func (t *avlNode) minNode() *avlNode {
   136  	if t == nil {
   137  		return nil
   138  	}
   139  	// 整棵树的最左边节点就是值最小的节点
   140  	if t.left == nil {
   141  		return t
   142  	} else {
   143  		return t.left.minNode()
   144  	}
   145  }
   146  
   147  func (t *avlNode) maxNode() *avlNode {
   148  	if t == nil {
   149  		return nil
   150  	}
   151  	// 整棵树的最右边节点就是值最大的节点
   152  	if t.right == nil {
   153  		return t
   154  	} else {
   155  		return t.right.maxNode()
   156  	}
   157  }
   158  
   159  /*
   160  左左情况:右旋
   161  		*
   162  	   *
   163  	  *
   164  */
   165  func (t *avlNode) llRotate() *avlNode {
   166  	node := t.left
   167  	t.left = node.right
   168  	node.right = t
   169  
   170  	node.h = max(node.left.height(), node.right.height()) + 1
   171  	t.h = max(t.left.height(), t.right.height()) + 1
   172  	return node
   173  }
   174  
   175  /*
   176  右右情况:左旋
   177  		*
   178  	     *
   179  	      *
   180  */
   181  func (t *avlNode) rrRotate() *avlNode {
   182  	node := t.right
   183  	t.right = node.left
   184  	node.left = t
   185  
   186  	node.h = max(node.left.height(), node.right.height()) + 1
   187  	t.h = max(t.left.height(), t.right.height()) + 1
   188  	return node
   189  }
   190  
   191  /*
   192  左右情况:先左旋 后右旋
   193  		*
   194  	   *
   195  	    *
   196  */
   197  func (t *avlNode) lrRotate() *avlNode {
   198  	t.left = t.left.rrRotate()
   199  	return t.llRotate()
   200  }
   201  
   202  /*
   203  右左情况:先右旋 后左旋
   204  		*
   205  	     *
   206          *
   207  */
   208  func (t *avlNode) rlRotate() *avlNode {
   209  	t.right = t.right.llRotate()
   210  	return t.rrRotate()
   211  }
   212  
   213  func (t *avlNode) keepBalance(v int) *avlNode {
   214  	// 左子树失衡
   215  	if t.left.height()-t.right.height() == 2 {
   216  		if v-t.left.value < 0 {
   217  			// 当插入的节点在失衡节点的左子树的左子树中,直接右旋
   218  			t = t.llRotate()
   219  		} else {
   220  			// 当插入的节点在失衡节点的左子树的右子树中,先左旋后右旋
   221  			t = t.lrRotate()
   222  		}
   223  	} else if t.right.height()-t.left.height() == 2 {
   224  		if t.right.right.height() > t.right.left.height() {
   225  			// 当插入的节点在失衡节点的右子树的右子树中,直接左旋
   226  			t = t.rrRotate()
   227  		} else {
   228  			// 当插入的节点在失衡节点的右子树的左子树中,先右旋后左旋
   229  			t = t.rlRotate()
   230  		}
   231  	}
   232  	// 调整树高度
   233  	t.h = max(t.left.height(), t.right.height()) + 1
   234  	return t
   235  }
   236  
   237  func (t *avlNode) height() int {
   238  	if t != nil {
   239  		return t.h
   240  	}
   241  	return -1
   242  }
   243  
   244  // 中序遍历按顺序获取所有值
   245  func appendValue(values []int, t *avlNode) []int {
   246  	if t != nil {
   247  		values = appendValue(values, t.left)
   248  		values = append(values, t.value)
   249  		values = appendValue(values, t.right)
   250  	}
   251  	return values
   252  }
   253  
   254  func (t *avlNode) values() []int {
   255  	values := make([]int, 0)
   256  	return appendValue(values, t)
   257  }