// Source: github.com/richardwilkes/toolbox@v1.121.0/collection/redblack/tree.go

     1  // Copyright (c) 2016-2024 by Richard A. Wilkes. All rights reserved.
     2  //
     3  // This Source Code Form is subject to the terms of the Mozilla Public
     4  // License, version 2.0. If a copy of the MPL was not distributed with
     5  // this file, You can obtain one at http://mozilla.org/MPL/2.0/.
     6  //
     7  // This Source Code Form is "Incompatible With Secondary Licenses", as
     8  // defined by the Mozilla Public License, version 2.0.
     9  
    10  package redblack
    11  
// Tree implements a Red-black Tree, as described here: https://en.wikipedia.org/wiki/Red–black_tree
//
// Create trees with New. Inserting into a non-empty tree calls the stored comparison function, so a zero-value Tree
// (whose compare is nil) panics on its second Insert. Keys need not be unique: Insert never checks for an existing key
// and always adds a new node.
type Tree[K, V any] struct {
	root    *node[K, V]      // topmost node; nil when the tree is empty
	compare func(a, b K) int // key ordering; a negative result means a sorts before b (see Insert)
	count   int              // number of nodes currently in the tree
}
    18  
    19  // New creates a new red-black tree.
    20  func New[K, V any](compareFunc func(a, b K) int) *Tree[K, V] {
    21  	return &Tree[K, V]{compare: compareFunc}
    22  }
    23  
    24  // Empty returns true if the tree is empty.
    25  func (t *Tree[K, V]) Empty() bool {
    26  	return t.count == 0
    27  }
    28  
    29  // Count returns the number of nodes in the tree.
    30  func (t *Tree[K, V]) Count() int {
    31  	return t.count
    32  }
    33  
    34  // Get returns the first value that matches the given key.
    35  func (t *Tree[K, V]) Get(key K) (value V, exists bool) {
    36  	if n := t.root.find(t.compare, key); n != nil {
    37  		return n.value, true
    38  	}
    39  	return value, false
    40  }
    41  
    42  // First returns the first value in the tree.
    43  func (t *Tree[K, V]) First() (value V, exists bool) {
    44  	n := t.root
    45  	if n == nil {
    46  		return value, false
    47  	}
    48  	for n.left != nil {
    49  		n = n.left
    50  	}
    51  	return n.value, true
    52  }
    53  
    54  // Last returns the last value in the tree.
    55  func (t *Tree[K, V]) Last() (value V, exists bool) {
    56  	n := t.root
    57  	if n == nil {
    58  		return value, false
    59  	}
    60  	for n.right != nil {
    61  		n = n.right
    62  	}
    63  	return n.value, true
    64  }
    65  
    66  // Dump a text version of the tree for debugging purposes.
    67  func (t *Tree[K, V]) Dump() {
    68  	t.root.dump(0, "")
    69  }
    70  
    71  // Traverse the tree, calling visitorFunc for each node, in order. If the visitorFunc returns false, the traversal will
    72  // be aborted.
    73  func (t *Tree[K, V]) Traverse(visitorFunc func(key K, value V) bool) {
    74  	t.root.traverse(visitorFunc)
    75  }
    76  
    77  // TraverseStartingAt traverses the tree starting with the first node whose key is equal to or greater than the given
    78  // key, calling visitorFunc for each node, in order. If the visitorFunc returns false, the traversal will be aborted.
    79  func (t *Tree[K, V]) TraverseStartingAt(key K, visitorFunc func(key K, value V) bool) {
    80  	t.root.traverseEqualOrGreater(t.compare, key, visitorFunc)
    81  }
    82  
    83  // ReverseTraverse traverses the tree, calling visitorFunc for each node, in reverse order. If the visitorFunc returns
    84  // false, the traversal will be aborted.
    85  func (t *Tree[K, V]) ReverseTraverse(visitorFunc func(key K, value V) bool) {
    86  	t.root.reverseTraverse(visitorFunc)
    87  }
    88  
    89  // ReverseTraverseStartingAt traverses the tree starting with the last node whose key is equal to or less than the given
    90  // key, calling visitorFunc for each node, in order. If the visitorFunc returns false, the traversal will be aborted.
    91  func (t *Tree[K, V]) ReverseTraverseStartingAt(key K, visitorFunc func(key K, value V) bool) {
    92  	t.root.traverseEqualOrLess(t.compare, key, visitorFunc)
    93  }
    94  
    95  // Insert a node into the tree.
    96  func (t *Tree[K, V]) Insert(key K, value V) {
    97  	n := &node[K, V]{key: key, value: value}
    98  	cur := t.root
    99  	n.parent = t.root
   100  	for cur != nil {
   101  		n.parent = cur
   102  		if t.compare(key, cur.key) < 0 {
   103  			cur = cur.left
   104  		} else {
   105  			cur = cur.right
   106  		}
   107  	}
   108  	if n.parent == nil {
   109  		t.root = n
   110  	} else {
   111  		if t.compare(key, n.parent.key) < 0 {
   112  			n.parent.left = n
   113  		} else {
   114  			n.parent.right = n
   115  		}
   116  	}
   117  	if n.parent != nil {
   118  		parent := n.parent
   119  		grandParent := parent.parent
   120  		for grandParent != nil && parent.isRed() {
   121  			if parent == grandParent.left {
   122  				uncle := grandParent.right
   123  				switch {
   124  				case uncle.isRed():
   125  					parent.black = true
   126  					uncle.black = true
   127  					grandParent.black = false
   128  					n = grandParent
   129  					parent = n.parent
   130  					if parent != nil {
   131  						grandParent = parent.parent
   132  					} else {
   133  						grandParent = nil
   134  					}
   135  				case n == parent.right:
   136  					n, parent = parent, n
   137  					t.rotateLeft(n)
   138  				default:
   139  					parent.black = true
   140  					grandParent.black = false
   141  					t.rotateRight(grandParent)
   142  				}
   143  			} else {
   144  				uncle := grandParent.left
   145  				switch {
   146  				case uncle.isRed():
   147  					parent.black = true
   148  					uncle.black = true
   149  					grandParent.black = false
   150  					n = grandParent
   151  					parent = n.parent
   152  					if parent != nil {
   153  						grandParent = parent.parent
   154  					} else {
   155  						grandParent = nil
   156  					}
   157  				case n == parent.left:
   158  					n, parent = parent, n
   159  					t.rotateRight(n)
   160  				default:
   161  					parent.black = true
   162  					grandParent.black = false
   163  					t.rotateLeft(grandParent)
   164  				}
   165  			}
   166  		}
   167  	}
   168  	t.root.black = true
   169  	t.count++
   170  }
   171  
   172  func (t *Tree[K, V]) rotateLeft(n *node[K, V]) {
   173  	right := n.right
   174  	n.right = right.left
   175  	if right.left != nil {
   176  		n.right.parent = n
   177  	}
   178  	right.parent = n.parent
   179  	if n.parent != nil {
   180  		if n.parent.left == n {
   181  			n.parent.left = right
   182  		} else {
   183  			n.parent.right = right
   184  		}
   185  	} else {
   186  		t.root = right
   187  	}
   188  	right.left = n
   189  	n.parent = right
   190  }
   191  
   192  func (t *Tree[K, V]) rotateRight(n *node[K, V]) {
   193  	left := n.left
   194  	n.left = left.right
   195  	if left.right != nil {
   196  		n.left.parent = n
   197  	}
   198  	left.parent = n.parent
   199  	if n.parent != nil {
   200  		if n.parent.right == n {
   201  			n.parent.right = left
   202  		} else {
   203  			n.parent.left = left
   204  		}
   205  	} else {
   206  		t.root = left
   207  	}
   208  	left.right = n
   209  	n.parent = left
   210  }
   211  
// Remove a node from the tree. If no node matches key, the tree is left untouched. Note that if the key is not unique
// within the tree, the first key that matches on traversal will be chosen as the one to remove.
//
// NOTE(review): which duplicate is removed is whatever node the find helper returns. Confirm that find yields the
// in-order-first match; a plain top-down search may not once rotations have moved equal keys around.
func (t *Tree[K, V]) Remove(key K) {
	n := t.root.find(t.compare, key)
	if n == nil {
		return
	}
	// splice is the node that actually gets unlinked. When n has two children, splice is n's in-order successor, the
	// leftmost node of its right subtree, which therefore has no left child. Its key/value are moved into n below.
	splice := n
	if n.left != nil && n.right != nil {
		splice = n.right
		for splice.left != nil {
			splice = splice.left
		}
	}
	// splice has at most one child; that child (possibly nil) takes splice's place.
	var child *node[K, V]
	if splice.left != nil {
		child = splice.left
	} else {
		child = splice.right
	}
	if child != nil {
		child.parent = splice.parent
	}
	if splice.parent != nil {
		left := false // whether splice hung from parent's left side
		parent := splice.parent
		if splice == parent.left {
			parent.left = child
			left = true
		} else {
			parent.right = child
		}
		if splice != n {
			// n keeps its position in the tree but takes over the successor's entry; the discarded splice node
			// receives n's old entry.
			n.key, splice.key = splice.key, n.key
			n.value, splice.value = splice.value, n.value
		}
		// Unlinking a red node changes no black height, so only a black splice needs rebalancing.
		if splice.black {
			if child != nil {
				// recolor only paints a red child black; a black child starts the full fix-up loop.
				t.recolor(child)
			} else {
				// There is no real node to start the fix-up from. Temporarily re-attach the childless, black
				// splice node in the vacated slot as a stand-in, rebalance from it, then unlink it again. None of
				// the rotations recolor performs replace parent's child on this side, so the stand-in can be
				// detached through the same parent and side afterwards.
				child = splice
				child.parent = parent
				child.left = nil // already nil here (child was nil); cleared defensively
				child.right = nil
				if left {
					parent.left = child
				} else {
					parent.right = child
				}
				t.recolor(child)
				if left {
					parent.left = nil
				} else {
					parent.right = nil
				}
			}
		}
	} else {
		// splice was the root. It can only be n itself here, since a successor always has a parent.
		t.root = child
	}
	if t.root != nil {
		t.root.black = true
	}
	t.count--
}
   277  
   278  func (t *Tree[K, V]) recolor(n *node[K, V]) {
   279  	for n != t.root && n.isBlack() {
   280  		parent := n.parent
   281  		switch {
   282  		case parent.left == n:
   283  			if sibling := parent.right; sibling != nil {
   284  				if sibling.isRed() {
   285  					sibling.black = true
   286  					parent.black = false
   287  					t.rotateLeft(parent)
   288  					parent = n.parent
   289  					sibling = parent.right
   290  				}
   291  				if sibling.left.isBlack() && sibling.right.isBlack() {
   292  					sibling.black = false
   293  					n = n.parent
   294  				} else {
   295  					if sibling.right.isBlack() {
   296  						sibling.left.black = true
   297  						sibling.black = false
   298  						t.rotateRight(sibling)
   299  						sibling = parent.right
   300  					}
   301  					sibling.black = parent.black
   302  					parent.black = true
   303  					sibling.right.black = true
   304  					t.rotateLeft(parent)
   305  					n = t.root
   306  				}
   307  			}
   308  		case parent.right == n:
   309  			if sibling := parent.left; sibling != nil {
   310  				if sibling.isRed() {
   311  					sibling.black = true
   312  					parent.black = false
   313  					t.rotateRight(parent)
   314  					parent = n.parent
   315  					sibling = parent.left
   316  				}
   317  				if sibling.right.isBlack() && sibling.left.isBlack() {
   318  					sibling.black = false
   319  					n = n.parent
   320  				} else {
   321  					if sibling.left.isBlack() {
   322  						sibling.right.black = true
   323  						sibling.black = false
   324  						t.rotateLeft(sibling)
   325  						sibling = parent.left
   326  					}
   327  					sibling.black = parent.black
   328  					parent.black = true
   329  					sibling.left.black = true
   330  					t.rotateRight(parent)
   331  					n = t.root
   332  				}
   333  			}
   334  		default:
   335  			parent.black = true
   336  		}
   337  	}
   338  	n.black = true
   339  }