github.com/songzhibin97/go-baseutils@v0.0.2-0.20240302024150-487d8ce9c082/structure/trees/redblacktree/redblacktree.go (about)

     1  // Package redblacktree implements a red-black tree.
     2  //
     3  // Used by TreeSet and TreeMap.
     4  //
     5  // Structure is not thread safe.
     6  //
     7  // References: http://en.wikipedia.org/wiki/Red%E2%80%93black_tree
     8  package redblacktree
     9  
    10  import (
    11  	"encoding/json"
    12  	"fmt"
    13  	"github.com/songzhibin97/go-baseutils/base/bcomparator"
    14  	"github.com/songzhibin97/go-baseutils/structure/trees"
    15  )
    16  
    17  // Assert Tree implementation
    18  var _ trees.Tree[any] = (*Tree[any, any])(nil)
    19  
    20  type color bool
    21  
    22  const (
    23  	black, red color = true, false
    24  )
    25  
// Tree holds elements of the red-black tree.
//
// Fields:
//   - Root is the top node, or nil when the tree is empty (Clear resets it to
//     nil; Put treats a nil Root as the empty tree).
//   - size is the node count, incremented by Put on insertion, decremented by
//     Remove and reset by Clear; Size and Empty report it.
//   - Comparator orders keys: a negative, zero or positive result sends a
//     search left, stops it, or sends it right. Put/Get/Remove/Floor/Ceiling
//     all call it, so it must be set (the NewWith* constructors do this)
//     before the tree is used. Its Marshal/Unmarshal methods encode keys for
//     JSON.
//   - zeroV is never assigned in this file, so it always holds the zero value
//     of V; Get returns it when the key is absent.
//
// The tree is not safe for concurrent use.
type Tree[K, V any] struct {
	Root       *Node[K, V]
	size       int
	Comparator bcomparator.Comparator[K]
	zeroV      V
}
    33  
    34  // Node is a single element within the tree
    35  type Node[K, V any] struct {
    36  	Key    K
    37  	Value  V
    38  	color  color
    39  	Left   *Node[K, V]
    40  	Right  *Node[K, V]
    41  	Parent *Node[K, V]
    42  }
    43  
    44  // NewWith instantiates a red-black tree with the custom comparator.
    45  func NewWith[K, V any](comparator bcomparator.Comparator[K]) *Tree[K, V] {
    46  	return &Tree[K, V]{Comparator: comparator}
    47  }
    48  
    49  // NewWithIntComparator instantiates a red-black tree with the IntComparator, i.e. keys are of type int.
    50  func NewWithIntComparator[V any]() *Tree[int, V] {
    51  	return &Tree[int, V]{Comparator: bcomparator.IntComparator()}
    52  }
    53  
    54  // NewWithStringComparator instantiates a red-black tree with the StringComparator, i.e. keys are of type string.
    55  func NewWithStringComparator[V any]() *Tree[string, V] {
    56  	return &Tree[string, V]{Comparator: bcomparator.StringComparator()}
    57  }
    58  
    59  // Put inserts node into the tree.
    60  // Key should adhere to the comparator's type assertion, otherwise method panics.
    61  func (tree *Tree[K, V]) Put(key K, value V) {
    62  	var insertedNode *Node[K, V]
    63  	if tree.Root == nil {
    64  		// Assert key is of comparator's type for initial tree
    65  		tree.Comparator(key, key)
    66  		tree.Root = &Node[K, V]{Key: key, Value: value, color: red}
    67  		insertedNode = tree.Root
    68  	} else {
    69  		node := tree.Root
    70  		loop := true
    71  		for loop {
    72  			compare := tree.Comparator(key, node.Key)
    73  			switch {
    74  			case compare == 0:
    75  				node.Key = key
    76  				node.Value = value
    77  				return
    78  			case compare < 0:
    79  				if node.Left == nil {
    80  					node.Left = &Node[K, V]{Key: key, Value: value, color: red}
    81  					insertedNode = node.Left
    82  					loop = false
    83  				} else {
    84  					node = node.Left
    85  				}
    86  			case compare > 0:
    87  				if node.Right == nil {
    88  					node.Right = &Node[K, V]{Key: key, Value: value, color: red}
    89  					insertedNode = node.Right
    90  					loop = false
    91  				} else {
    92  					node = node.Right
    93  				}
    94  			}
    95  		}
    96  		insertedNode.Parent = node
    97  	}
    98  	tree.insertCase1(insertedNode)
    99  	tree.size++
   100  }
   101  
   102  // Get searches the node in the tree by key and returns its value or nil if key is not found in tree.
   103  // Second return parameter is true if key was found, otherwise false.
   104  // Key should adhere to the comparator's type assertion, otherwise method panics.
   105  func (tree *Tree[K, V]) Get(key K) (value V, found bool) {
   106  	node := tree.lookup(key)
   107  	if node != nil {
   108  		return node.Value, true
   109  	}
   110  	return tree.zeroV, false
   111  }
   112  
   113  // GetNode searches the node in the tree by key and returns its node or nil if key is not found in tree.
   114  // Key should adhere to the comparator's type assertion, otherwise method panics.
   115  func (tree *Tree[K, V]) GetNode(key K) *Node[K, V] {
   116  	return tree.lookup(key)
   117  }
   118  
   119  // Remove remove the node from the tree by key.
   120  // Key should adhere to the comparator's type assertion, otherwise method panics.
   121  func (tree *Tree[K, V]) Remove(key K) {
   122  	var child *Node[K, V]
   123  	node := tree.lookup(key)
   124  	if node == nil {
   125  		return
   126  	}
   127  	if node.Left != nil && node.Right != nil {
   128  		pred := node.Left.maximumNode()
   129  		node.Key = pred.Key
   130  		node.Value = pred.Value
   131  		node = pred
   132  	}
   133  	if node.Left == nil || node.Right == nil {
   134  		if node.Right == nil {
   135  			child = node.Left
   136  		} else {
   137  			child = node.Right
   138  		}
   139  		if node.color == black {
   140  			node.color = nodeColor(child)
   141  			tree.deleteCase1(node)
   142  		}
   143  		tree.replaceNode(node, child)
   144  		if node.Parent == nil && child != nil {
   145  			child.color = black
   146  		}
   147  	}
   148  	tree.size--
   149  }
   150  
   151  // Empty returns true if tree does not contain any nodes
   152  func (tree *Tree[K, V]) Empty() bool {
   153  	return tree.size == 0
   154  }
   155  
   156  // Size returns number of nodes in the tree.
   157  func (tree *Tree[K, V]) Size() int {
   158  	return tree.size
   159  }
   160  
   161  // Size returns the number of elements stored in the subtree.
   162  // Computed dynamically on each call, i.e. the subtree is traversed to count the number of the nodes.
   163  func (node *Node[K, V]) Size() int {
   164  	if node == nil {
   165  		return 0
   166  	}
   167  	size := 1
   168  	if node.Left != nil {
   169  		size += node.Left.Size()
   170  	}
   171  	if node.Right != nil {
   172  		size += node.Right.Size()
   173  	}
   174  	return size
   175  }
   176  
   177  // Keys returns all keys in-order
   178  func (tree *Tree[K, V]) Keys() []K {
   179  	keys := make([]K, tree.size)
   180  	it := tree.Iterator()
   181  	for i := 0; it.Next(); i++ {
   182  		keys[i] = it.Key()
   183  	}
   184  	return keys
   185  }
   186  
   187  // Values returns all values in-order based on the key.
   188  func (tree *Tree[K, V]) Values() []V {
   189  	values := make([]V, tree.size)
   190  	it := tree.Iterator()
   191  	for i := 0; it.Next(); i++ {
   192  		values[i] = it.Value()
   193  	}
   194  	return values
   195  }
   196  
   197  // Left returns the left-most (min) node or nil if tree is empty.
   198  func (tree *Tree[K, V]) Left() *Node[K, V] {
   199  	var parent *Node[K, V]
   200  	current := tree.Root
   201  	for current != nil {
   202  		parent = current
   203  		current = current.Left
   204  	}
   205  	return parent
   206  }
   207  
   208  // Right returns the right-most (max) node or nil if tree is empty.
   209  func (tree *Tree[K, V]) Right() *Node[K, V] {
   210  	var parent *Node[K, V]
   211  	current := tree.Root
   212  	for current != nil {
   213  		parent = current
   214  		current = current.Right
   215  	}
   216  	return parent
   217  }
   218  
   219  // Floor Finds floor node of the input key, return the floor node or nil if no floor is found.
   220  // Second return parameter is true if floor was found, otherwise false.
   221  //
   222  // Floor node is defined as the largest node that is smaller than or equal to the given node.
   223  // A floor node may not be found, either because the tree is empty, or because
   224  // all nodes in the tree are larger than the given node.
   225  //
   226  // Key should adhere to the comparator's type assertion, otherwise method panics.
   227  func (tree *Tree[K, V]) Floor(key K) (floor *Node[K, V], found bool) {
   228  	found = false
   229  	node := tree.Root
   230  	for node != nil {
   231  		compare := tree.Comparator(key, node.Key)
   232  		switch {
   233  		case compare == 0:
   234  			return node, true
   235  		case compare < 0:
   236  			node = node.Left
   237  		case compare > 0:
   238  			floor, found = node, true
   239  			node = node.Right
   240  		}
   241  	}
   242  	if found {
   243  		return floor, true
   244  	}
   245  	return nil, false
   246  }
   247  
   248  // Ceiling finds ceiling node of the input key, return the ceiling node or nil if no ceiling is found.
   249  // Second return parameter is true if ceiling was found, otherwise false.
   250  //
   251  // Ceiling node is defined as the smallest node that is larger than or equal to the given node.
   252  // A ceiling node may not be found, either because the tree is empty, or because
   253  // all nodes in the tree are smaller than the given node.
   254  //
   255  // Key should adhere to the comparator's type assertion, otherwise method panics.
   256  func (tree *Tree[K, V]) Ceiling(key K) (ceiling *Node[K, V], found bool) {
   257  	found = false
   258  	node := tree.Root
   259  	for node != nil {
   260  		compare := tree.Comparator(key, node.Key)
   261  		switch {
   262  		case compare == 0:
   263  			return node, true
   264  		case compare < 0:
   265  			ceiling, found = node, true
   266  			node = node.Left
   267  		case compare > 0:
   268  			node = node.Right
   269  		}
   270  	}
   271  	if found {
   272  		return ceiling, true
   273  	}
   274  	return nil, false
   275  }
   276  
   277  // Clear removes all nodes from the tree.
   278  func (tree *Tree[K, V]) Clear() {
   279  	tree.Root = nil
   280  	tree.size = 0
   281  }
   282  
   283  // String returns a string representation of container
   284  func (tree *Tree[K, V]) String() string {
   285  	str := "RedBlackTree\n"
   286  	if !tree.Empty() {
   287  		output(tree.Root, "", true, &str)
   288  	}
   289  	return str
   290  }
   291  
   292  func (node *Node[K, V]) String() string {
   293  	return fmt.Sprintf("%v", node.Key)
   294  }
   295  
   296  func output[K, V any](node *Node[K, V], prefix string, isTail bool, str *string) {
   297  	if node.Right != nil {
   298  		newPrefix := prefix
   299  		if isTail {
   300  			newPrefix += "│   "
   301  		} else {
   302  			newPrefix += "    "
   303  		}
   304  		output(node.Right, newPrefix, false, str)
   305  	}
   306  	*str += prefix
   307  	if isTail {
   308  		*str += "└── "
   309  	} else {
   310  		*str += "┌── "
   311  	}
   312  	*str += node.String() + "\n"
   313  	if node.Left != nil {
   314  		newPrefix := prefix
   315  		if isTail {
   316  			newPrefix += "    "
   317  		} else {
   318  			newPrefix += "│   "
   319  		}
   320  		output(node.Left, newPrefix, true, str)
   321  	}
   322  }
   323  
   324  func (tree *Tree[K, V]) lookup(key K) *Node[K, V] {
   325  	node := tree.Root
   326  	for node != nil {
   327  		compare := tree.Comparator(key, node.Key)
   328  		switch {
   329  		case compare == 0:
   330  			return node
   331  		case compare < 0:
   332  			node = node.Left
   333  		case compare > 0:
   334  			node = node.Right
   335  		}
   336  	}
   337  	return nil
   338  }
   339  
   340  func (node *Node[K, V]) grandparent() *Node[K, V] {
   341  	if node != nil && node.Parent != nil {
   342  		return node.Parent.Parent
   343  	}
   344  	return nil
   345  }
   346  
   347  func (node *Node[K, V]) uncle() *Node[K, V] {
   348  	if node == nil || node.Parent == nil || node.Parent.Parent == nil {
   349  		return nil
   350  	}
   351  	return node.Parent.sibling()
   352  }
   353  
   354  func (node *Node[K, V]) sibling() *Node[K, V] {
   355  	if node == nil || node.Parent == nil {
   356  		return nil
   357  	}
   358  	if node == node.Parent.Left {
   359  		return node.Parent.Right
   360  	}
   361  	return node.Parent.Left
   362  }
   363  
   364  func (tree *Tree[K, V]) rotateLeft(node *Node[K, V]) {
   365  	right := node.Right
   366  	tree.replaceNode(node, right)
   367  	node.Right = right.Left
   368  	if right.Left != nil {
   369  		right.Left.Parent = node
   370  	}
   371  	right.Left = node
   372  	node.Parent = right
   373  }
   374  
   375  func (tree *Tree[K, V]) rotateRight(node *Node[K, V]) {
   376  	left := node.Left
   377  	tree.replaceNode(node, left)
   378  	node.Left = left.Right
   379  	if left.Right != nil {
   380  		left.Right.Parent = node
   381  	}
   382  	left.Right = node
   383  	node.Parent = left
   384  }
   385  
   386  func (tree *Tree[K, V]) replaceNode(old *Node[K, V], new *Node[K, V]) {
   387  	if old.Parent == nil {
   388  		tree.Root = new
   389  	} else {
   390  		if old == old.Parent.Left {
   391  			old.Parent.Left = new
   392  		} else {
   393  			old.Parent.Right = new
   394  		}
   395  	}
   396  	if new != nil {
   397  		new.Parent = old.Parent
   398  	}
   399  }
   400  
   401  func (tree *Tree[K, V]) insertCase1(node *Node[K, V]) {
   402  	if node.Parent == nil {
   403  		node.color = black
   404  	} else {
   405  		tree.insertCase2(node)
   406  	}
   407  }
   408  
   409  func (tree *Tree[K, V]) insertCase2(node *Node[K, V]) {
   410  	if nodeColor(node.Parent) == black {
   411  		return
   412  	}
   413  	tree.insertCase3(node)
   414  }
   415  
   416  func (tree *Tree[K, V]) insertCase3(node *Node[K, V]) {
   417  	uncle := node.uncle()
   418  	if nodeColor(uncle) == red {
   419  		node.Parent.color = black
   420  		uncle.color = black
   421  		node.grandparent().color = red
   422  		tree.insertCase1(node.grandparent())
   423  	} else {
   424  		tree.insertCase4(node)
   425  	}
   426  }
   427  
   428  func (tree *Tree[K, V]) insertCase4(node *Node[K, V]) {
   429  	grandparent := node.grandparent()
   430  	if node == node.Parent.Right && node.Parent == grandparent.Left {
   431  		tree.rotateLeft(node.Parent)
   432  		node = node.Left
   433  	} else if node == node.Parent.Left && node.Parent == grandparent.Right {
   434  		tree.rotateRight(node.Parent)
   435  		node = node.Right
   436  	}
   437  	tree.insertCase5(node)
   438  }
   439  
   440  func (tree *Tree[K, V]) insertCase5(node *Node[K, V]) {
   441  	node.Parent.color = black
   442  	grandparent := node.grandparent()
   443  	grandparent.color = red
   444  	if node == node.Parent.Left && node.Parent == grandparent.Left {
   445  		tree.rotateRight(grandparent)
   446  	} else if node == node.Parent.Right && node.Parent == grandparent.Right {
   447  		tree.rotateLeft(grandparent)
   448  	}
   449  }
   450  
   451  func (node *Node[K, V]) maximumNode() *Node[K, V] {
   452  	if node == nil {
   453  		return nil
   454  	}
   455  	for node.Right != nil {
   456  		node = node.Right
   457  	}
   458  	return node
   459  }
   460  
   461  func (tree *Tree[K, V]) deleteCase1(node *Node[K, V]) {
   462  	if node.Parent == nil {
   463  		return
   464  	}
   465  	tree.deleteCase2(node)
   466  }
   467  
   468  func (tree *Tree[K, V]) deleteCase2(node *Node[K, V]) {
   469  	sibling := node.sibling()
   470  	if nodeColor(sibling) == red {
   471  		node.Parent.color = red
   472  		sibling.color = black
   473  		if node == node.Parent.Left {
   474  			tree.rotateLeft(node.Parent)
   475  		} else {
   476  			tree.rotateRight(node.Parent)
   477  		}
   478  	}
   479  	tree.deleteCase3(node)
   480  }
   481  
   482  func (tree *Tree[K, V]) deleteCase3(node *Node[K, V]) {
   483  	sibling := node.sibling()
   484  	if nodeColor(node.Parent) == black &&
   485  		nodeColor(sibling) == black &&
   486  		nodeColor(sibling.Left) == black &&
   487  		nodeColor(sibling.Right) == black {
   488  		sibling.color = red
   489  		tree.deleteCase1(node.Parent)
   490  	} else {
   491  		tree.deleteCase4(node)
   492  	}
   493  }
   494  
   495  func (tree *Tree[K, V]) deleteCase4(node *Node[K, V]) {
   496  	sibling := node.sibling()
   497  	if nodeColor(node.Parent) == red &&
   498  		nodeColor(sibling) == black &&
   499  		nodeColor(sibling.Left) == black &&
   500  		nodeColor(sibling.Right) == black {
   501  		sibling.color = red
   502  		node.Parent.color = black
   503  	} else {
   504  		tree.deleteCase5(node)
   505  	}
   506  }
   507  
// deleteCase5 handles a black sibling whose "inner" child (the one nearer to
// node) is red while its outer child is black. It paints the sibling red and
// that inner child black, then rotates at the sibling so that the former inner
// child becomes node's new sibling, with the old (now red) sibling as its
// outer child. If neither mirrored configuration matches, nothing is changed.
// In every case it continues with deleteCase6.
func (tree *Tree[K, V]) deleteCase5(node *Node[K, V]) {
	sibling := node.sibling()
	if node == node.Parent.Left &&
		nodeColor(sibling) == black &&
		nodeColor(sibling.Left) == red &&
		nodeColor(sibling.Right) == black {
		// node is a left child, so sibling.Left is the inner child.
		sibling.color = red
		sibling.Left.color = black
		tree.rotateRight(sibling)
	} else if node == node.Parent.Right &&
		nodeColor(sibling) == black &&
		nodeColor(sibling.Right) == red &&
		nodeColor(sibling.Left) == black {
		// Mirror image: node is a right child, so sibling.Right is the inner child.
		sibling.color = red
		sibling.Right.color = black
		tree.rotateLeft(sibling)
	}
	tree.deleteCase6(node)
}
   527  
// deleteCase6 finishes the removal fix-up. The sibling takes the parent's
// color and the parent becomes black. The sibling's red outer child is then
// painted black, and the parent is rotated toward node (left when node is a
// left child, right otherwise).
func (tree *Tree[K, V]) deleteCase6(node *Node[K, V]) {
	sibling := node.sibling()
	// Recolor before rotating: the rotation moves sibling into the parent's place.
	sibling.color = nodeColor(node.Parent)
	node.Parent.color = black
	if node == node.Parent.Left && nodeColor(sibling.Right) == red {
		sibling.Right.color = black
		tree.rotateLeft(node.Parent)
	} else if nodeColor(sibling.Left) == red {
		// NOTE(review): this branch does not re-check that node is a right
		// child; it relies on deleteCase5 having left a left child's sibling
		// with a red right child.
		sibling.Left.color = black
		tree.rotateRight(node.Parent)
	}
}
   540  
   541  func nodeColor[K, V any](node *Node[K, V]) color {
   542  	if node == nil {
   543  		return black
   544  	}
   545  	return node.color
   546  }
   547  
   548  // UnmarshalJSON @implements json.Unmarshaler
   549  func (tree *Tree[K, V]) UnmarshalJSON(bytes []byte) error {
   550  	elements := make(map[string]V)
   551  	err := json.Unmarshal(bytes, &elements)
   552  	if err == nil {
   553  		tree.Clear()
   554  		for key, value := range elements {
   555  			var nk K
   556  			err = tree.Comparator.Unmarshal([]byte(key), &nk)
   557  			if err != nil {
   558  				return err
   559  			}
   560  			tree.Put(nk, value)
   561  		}
   562  	}
   563  	return err
   564  }
   565  
   566  // MarshalJSON @implements json.Marshaler
   567  func (tree *Tree[K, V]) MarshalJSON() ([]byte, error) {
   568  	elements := make(map[string]V)
   569  	it := tree.Iterator()
   570  	for it.Next() {
   571  		k, err := tree.Comparator.Marshal(it.Key())
   572  		if err != nil {
   573  			return nil, err
   574  		}
   575  		elements[string(k)] = it.Value()
   576  	}
   577  	return json.Marshal(&elements)
   578  }