go.charczuk.com@v0.0.0-20240327042549-bc490516bd1a/sdk/collections/binary_search_tree.go (about)

     1  /*
     2  
     3  Copyright (c) 2023 - Present. Will Charczuk. All rights reserved.
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file at the root of the repository.
     5  
     6  */
     7  
     8  package collections
     9  
    10  import "cmp"
    11  
    12  // BinarySearchTree is a AVL balanced tree which holds the properties
    13  // that nodes are ordered left to right.
    14  //
    15  // The choice to use AVL to balance the tree means the use cases skew
    16  // towards fast lookups at the expense of more costly mutations.
    17  type BinarySearchTree[K cmp.Ordered, V any] struct {
    18  	root *BinarySearchTreeNode[K, V]
    19  }
    20  
    21  // Insert adds a new value to the binary search tree.
    22  func (bst *BinarySearchTree[K, V]) Insert(k K, v V) {
    23  	bst.root = bst._insert(bst.root, k, v)
    24  }
    25  
    26  // Delete deletes a value from the tree, and returns if it existed.
    27  func (bst *BinarySearchTree[K, V]) Delete(k K) {
    28  	bst.root = bst._delete(bst.root, k)
    29  }
    30  
    31  // Search searches for a node with a given key, returning the value
    32  // and a boolean indicating the key was found.
    33  func (bst *BinarySearchTree[K, V]) Search(k K) (v V, ok bool) {
    34  	v, ok = bst._search(bst.root, k)
    35  	return
    36  }
    37  
    38  // Min returns the minimum key and value.
    39  func (bst *BinarySearchTree[K, V]) Min() (k K, v V, ok bool) {
    40  	if bst.root == nil {
    41  		return
    42  	}
    43  	k, v, ok = bst.root.Key, bst.root.Value, true
    44  	current := bst.root
    45  	for current.Left != nil {
    46  		current = current.Left
    47  		k, v = current.Key, current.Value
    48  	}
    49  	return
    50  }
    51  
    52  // Max returns the maximum key and value.
    53  func (bst *BinarySearchTree[K, V]) Max() (k K, v V, ok bool) {
    54  	if bst.root == nil {
    55  		return
    56  	}
    57  	k, v, ok = bst.root.Key, bst.root.Value, true
    58  	current := bst.root
    59  	for current.Right != nil {
    60  		current = current.Right
    61  		k, v = current.Key, current.Value
    62  	}
    63  	return
    64  }
    65  
    66  // InOrder traversal returns the sorted values in the tree.
    67  func (bst *BinarySearchTree[K, V]) InOrder(fn func(K, V)) {
    68  	bst._inOrder(bst.root, fn)
    69  }
    70  
    71  // PreOrder traversal returns the values in the tree in pre-order.
    72  func (bst *BinarySearchTree[K, V]) PreOrder(fn func(K, V)) {
    73  	bst._preOrder(bst.root, fn)
    74  }
    75  
    76  // PostOrder traversal returns the values in the tree in post-order.
    77  func (bst *BinarySearchTree[K, V]) PostOrder(fn func(K, V)) {
    78  	bst._postOrder(bst.root, fn)
    79  }
    80  
    81  // KeysEqual is a function that can be used to deeply compare two trees based on their keys.
    82  //
    83  // Values are _not_ considered because values are not comparable by design.
    84  func (bst *BinarySearchTree[K, V]) KeysEqual(other *BinarySearchTree[K, V]) bool {
    85  	return bst._keysEqual(bst.root, other.root)
    86  }
    87  
    88  //
    89  // internal methods
    90  //
    91  
    92  func (bst *BinarySearchTree[K, V]) _height(n *BinarySearchTreeNode[K, V]) int {
    93  	if n == nil {
    94  		return 0
    95  	}
    96  	return n.Height
    97  }
    98  
    99  func (bst *BinarySearchTree[K, V]) _inOrder(n *BinarySearchTreeNode[K, V], fn func(K, V)) {
   100  	if n == nil {
   101  		return
   102  	}
   103  	bst._inOrder(n.Left, fn)
   104  	fn(n.Key, n.Value)
   105  	bst._inOrder(n.Right, fn)
   106  }
   107  
   108  func (bst *BinarySearchTree[K, V]) _preOrder(n *BinarySearchTreeNode[K, V], fn func(K, V)) {
   109  	if n == nil {
   110  		return
   111  	}
   112  	fn(n.Key, n.Value)
   113  	bst._preOrder(n.Left, fn)
   114  	bst._preOrder(n.Right, fn)
   115  }
   116  
   117  func (bst *BinarySearchTree[K, V]) _postOrder(n *BinarySearchTreeNode[K, V], fn func(K, V)) {
   118  	if n == nil {
   119  		return
   120  	}
   121  	bst._postOrder(n.Left, fn)
   122  	bst._postOrder(n.Right, fn)
   123  	fn(n.Key, n.Value)
   124  }
   125  
   126  func (bst *BinarySearchTree[K, V]) _insert(n *BinarySearchTreeNode[K, V], k K, v V) *BinarySearchTreeNode[K, V] {
   127  	if n == nil {
   128  		return &BinarySearchTreeNode[K, V]{
   129  			Key:    k,
   130  			Value:  v,
   131  			Height: 1,
   132  		}
   133  	}
   134  
   135  	if k < n.Key {
   136  		n.Left = bst._insert(n.Left, k, v)
   137  	} else if k > n.Key {
   138  		n.Right = bst._insert(n.Right, k, v)
   139  	} else {
   140  		n.Value = v
   141  		return n
   142  	}
   143  
   144  	n.Height = max(bst._height(n.Left), bst._height(n.Right)) + 1
   145  
   146  	balanceFactor := bst._getBalanceFactor(n)
   147  	if balanceFactor > 1 && k < n.Left.Key {
   148  		return bst._rotateRight(n)
   149  	}
   150  	if balanceFactor < -1 && k > n.Right.Key {
   151  		return bst._rotateLeft(n)
   152  	}
   153  	if balanceFactor > 1 && k > n.Left.Key {
   154  		n.Left = bst._rotateLeft(n.Left)
   155  		return bst._rotateRight(n)
   156  	}
   157  	if balanceFactor < -1 && k < n.Right.Key {
   158  		n.Right = bst._rotateRight(n.Right)
   159  		return bst._rotateLeft(n)
   160  	}
   161  	return n
   162  }
   163  
   164  func (bst *BinarySearchTree[K, V]) _delete(n *BinarySearchTreeNode[K, V], k K) *BinarySearchTreeNode[K, V] {
   165  	if n == nil {
   166  		return nil
   167  	}
   168  
   169  	if k < n.Key {
   170  		n.Left = bst._delete(n.Left, k)
   171  	} else if k > n.Key {
   172  		n.Right = bst._delete(n.Right, k)
   173  	} else {
   174  		if n.Left == nil || n.Right == nil {
   175  			var temp *BinarySearchTreeNode[K, V]
   176  			if n.Left == nil {
   177  				temp = n.Right
   178  			} else {
   179  				temp = n.Left
   180  			}
   181  			if temp == nil {
   182  				n = nil
   183  			} else {
   184  				n = temp
   185  			}
   186  		} else {
   187  			temp := bst._searchMin(n.Right)
   188  			n.Key, n.Value = temp.Key, temp.Value
   189  			n.Right = bst._delete(n.Right, temp.Key)
   190  		}
   191  	}
   192  
   193  	if n == nil {
   194  		return nil
   195  	}
   196  
   197  	n.Height = max(bst._height(n.Left), bst._height(n.Right)) + 1
   198  
   199  	balanceFactor := bst._getBalanceFactor(n)
   200  	if balanceFactor > 1 && bst._getBalanceFactor(n.Left) >= 0 {
   201  		return bst._rotateRight(n)
   202  	}
   203  	if balanceFactor > 1 && bst._getBalanceFactor(n.Left) < 0 {
   204  		n.Left = bst._rotateLeft(n.Left)
   205  		return bst._rotateRight(n)
   206  	}
   207  
   208  	if balanceFactor < -1 && bst._getBalanceFactor(n.Right) <= 0 {
   209  		return bst._rotateLeft(n)
   210  	}
   211  
   212  	if balanceFactor < -1 && bst._getBalanceFactor(n.Right) > 0 {
   213  		n.Right = bst._rotateRight(n.Right)
   214  		return bst._rotateLeft(n)
   215  	}
   216  	return n
   217  }
   218  
   219  func (bst *BinarySearchTree[K, V]) _searchMin(n *BinarySearchTreeNode[K, V]) (min *BinarySearchTreeNode[K, V]) {
   220  	min = n
   221  	for min.Left != nil {
   222  		min = min.Left
   223  	}
   224  	return
   225  }
   226  
   227  func (bst *BinarySearchTree[K, V]) _search(n *BinarySearchTreeNode[K, V], k K) (v V, ok bool) {
   228  	if n == nil {
   229  		return
   230  	}
   231  	if n.Key == k {
   232  		v = n.Value
   233  		ok = true
   234  		return
   235  	}
   236  	if k < n.Key {
   237  		v, ok = bst._search(n.Left, k)
   238  		return
   239  	}
   240  	v, ok = bst._search(n.Right, k)
   241  	return
   242  }
   243  
   244  func (bst *BinarySearchTree[K, V]) _rotateRight(y *BinarySearchTreeNode[K, V]) *BinarySearchTreeNode[K, V] {
   245  	if y.Left == nil {
   246  		return y
   247  	}
   248  	x := y.Left
   249  	t2 := x.Right
   250  	x.Right = y
   251  	y.Left = t2
   252  	y.Height = max(bst._height(y.Left), bst._height(y.Right)) + 1
   253  	x.Height = max(bst._height(x.Left), bst._height(x.Right)) + 1
   254  	return x
   255  }
   256  
   257  func (bst *BinarySearchTree[K, V]) _rotateLeft(x *BinarySearchTreeNode[K, V]) *BinarySearchTreeNode[K, V] {
   258  	if x.Right == nil {
   259  		return x
   260  	}
   261  
   262  	y := x.Right
   263  	t2 := y.Left
   264  	y.Left = x
   265  	x.Right = t2
   266  	x.Height = max(bst._height(x.Left), bst._height(x.Right)) + 1
   267  	y.Height = max(bst._height(y.Left), bst._height(y.Right)) + 1
   268  	return y
   269  }
   270  
   271  func (bst *BinarySearchTree[K, V]) _getBalanceFactor(n *BinarySearchTreeNode[K, V]) int {
   272  	if n == nil {
   273  		return 0
   274  	}
   275  	return bst._height(n.Left) - bst._height(n.Right)
   276  }
   277  
   278  func (bst *BinarySearchTree[K, V]) _keysEqual(a, b *BinarySearchTreeNode[K, V]) bool {
   279  	if a == nil && b == nil {
   280  		return true
   281  	}
   282  	if a != nil && b == nil {
   283  		return false
   284  	}
   285  	if a == nil && b != nil {
   286  		return false
   287  	}
   288  	if a.Key != b.Key {
   289  		return false
   290  	}
   291  	if a.Height != b.Height {
   292  		return false
   293  	}
   294  	return bst._keysEqual(a.Left, b.Left) && bst._keysEqual(a.Right, b.Right)
   295  }
   296  
   297  // BinarySearchTreeNode is a node in a BinarySearchTree.
   298  type BinarySearchTreeNode[K cmp.Ordered, V any] struct {
   299  	Key    K
   300  	Value  V
   301  	Left   *BinarySearchTreeNode[K, V]
   302  	Right  *BinarySearchTreeNode[K, V]
   303  	Height int
   304  }
   305  
   306  func max[K cmp.Ordered](keys ...K) (k K) {
   307  	if len(keys) == 0 {
   308  		return
   309  	}
   310  	k = keys[0]
   311  	for x := 1; x < len(keys); x++ {
   312  		if keys[x] > k {
   313  			k = keys[x]
   314  		}
   315  	}
   316  	return
   317  }