github.com/alexandrestein/gods@v1.0.1/trees/redblacktree/redblacktree.go (about)

     1  // Copyright (c) 2015, Emir Pasic. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package redblacktree implements a red-black tree.
     6  //
     7  // Used by TreeSet and TreeMap.
     8  //
     9  // Structure is not thread safe.
    10  //
    11  // References: http://en.wikipedia.org/wiki/Red%E2%80%93black_tree
    12  package redblacktree
    13  
    14  import (
    15  	"fmt"
    16  	"github.com/alexandrestein/gods/trees"
    17  	"github.com/alexandrestein/gods/utils"
    18  )
    19  
    20  func assertTreeImplementation() {
    21  	var _ trees.Tree = (*Tree)(nil)
    22  }
    23  
    24  type color bool
    25  
    26  const (
    27  	black, red color = true, false
    28  )
    29  
// Tree holds elements of the red-black tree.
type Tree struct {
	// Root is the top node of the tree; it is nil when the tree is empty.
	Root *Node
	// size is the node count, incremented by Put on insertion, decremented
	// by Remove, and reset by Clear.
	size int
	// Comparator orders keys; every search (Put, Get, Remove, Floor,
	// Ceiling) calls it as Comparator(searchKey, nodeKey).
	Comparator utils.Comparator
}
    36  
    37  // Node is a single element within the tree
    38  type Node struct {
    39  	Key    interface{}
    40  	Value  interface{}
    41  	color  color
    42  	Left   *Node
    43  	Right  *Node
    44  	Parent *Node
    45  }
    46  
    47  // NewWith instantiates a red-black tree with the custom comparator.
    48  func NewWith(comparator utils.Comparator) *Tree {
    49  	return &Tree{Comparator: comparator}
    50  }
    51  
    52  // NewWithIntComparator instantiates a red-black tree with the IntComparator, i.e. keys are of type int.
    53  func NewWithIntComparator() *Tree {
    54  	return &Tree{Comparator: utils.IntComparator}
    55  }
    56  
    57  // NewWithStringComparator instantiates a red-black tree with the StringComparator, i.e. keys are of type string.
    58  func NewWithStringComparator() *Tree {
    59  	return &Tree{Comparator: utils.StringComparator}
    60  }
    61  
    62  // Put inserts node into the tree.
    63  // Key should adhere to the comparator's type assertion, otherwise method panics.
    64  func (tree *Tree) Put(key interface{}, value interface{}) {
    65  	var insertedNode *Node
    66  	if tree.Root == nil {
    67  		tree.Root = &Node{Key: key, Value: value, color: red}
    68  		insertedNode = tree.Root
    69  	} else {
    70  		node := tree.Root
    71  		loop := true
    72  		for loop {
    73  			compare := tree.Comparator(key, node.Key)
    74  			switch {
    75  			case compare == 0:
    76  				node.Key = key
    77  				node.Value = value
    78  				return
    79  			case compare < 0:
    80  				if node.Left == nil {
    81  					node.Left = &Node{Key: key, Value: value, color: red}
    82  					insertedNode = node.Left
    83  					loop = false
    84  				} else {
    85  					node = node.Left
    86  				}
    87  			case compare > 0:
    88  				if node.Right == nil {
    89  					node.Right = &Node{Key: key, Value: value, color: red}
    90  					insertedNode = node.Right
    91  					loop = false
    92  				} else {
    93  					node = node.Right
    94  				}
    95  			}
    96  		}
    97  		insertedNode.Parent = node
    98  	}
    99  	tree.insertCase1(insertedNode)
   100  	tree.size++
   101  }
   102  
   103  // Get searches the node in the tree by key and returns its value or nil if key is not found in tree.
   104  // Second return parameter is true if key was found, otherwise false.
   105  // Key should adhere to the comparator's type assertion, otherwise method panics.
   106  func (tree *Tree) Get(key interface{}) (value interface{}, found bool) {
   107  	node := tree.lookup(key)
   108  	if node != nil {
   109  		return node.Value, true
   110  	}
   111  	return nil, false
   112  }
   113  
// Remove removes the node with the given key from the tree, if present.
// If no node has an equal key, the tree and its size are left unchanged.
// Key should adhere to the comparator's type assertion, otherwise method panics.
func (tree *Tree) Remove(key interface{}) {
	var child *Node
	node := tree.lookup(key)
	if node == nil {
		return
	}
	if node.Left != nil && node.Right != nil {
		// Two children: copy the in-order predecessor (maximum of the left
		// subtree) into this node and delete the predecessor instead. The
		// predecessor has no right child, so it has at most one child.
		pred := node.Left.maximumNode()
		node.Key = pred.Key
		node.Value = pred.Value
		node = pred
	}
	// Always true at this point: node has at most one child and can be
	// spliced out directly.
	if node.Left == nil || node.Right == nil {
		if node.Right == nil {
			child = node.Left
		} else {
			child = node.Right
		}
		if node.color == black {
			// Removing a black node would shorten black heights on this
			// path, so rebalance first, while node is still linked in and
			// carries its child's colour.
			node.color = nodeColor(child)
			tree.deleteCase1(node)
		}
		tree.replaceNode(node, child)
		if node.Parent == nil && child != nil {
			// child has become the root; keep the root black.
			child.color = black
		}
	}
	tree.size--
}
   145  
   146  // Empty returns true if tree does not contain any nodes
   147  func (tree *Tree) Empty() bool {
   148  	return tree.size == 0
   149  }
   150  
   151  // Size returns number of nodes in the tree.
   152  func (tree *Tree) Size() int {
   153  	return tree.size
   154  }
   155  
   156  // Keys returns all keys in-order
   157  func (tree *Tree) Keys() []interface{} {
   158  	keys := make([]interface{}, tree.size)
   159  	it := tree.Iterator()
   160  	for i := 0; it.Next(); i++ {
   161  		keys[i] = it.Key()
   162  	}
   163  	return keys
   164  }
   165  
   166  // Values returns all values in-order based on the key.
   167  func (tree *Tree) Values() []interface{} {
   168  	values := make([]interface{}, tree.size)
   169  	it := tree.Iterator()
   170  	for i := 0; it.Next(); i++ {
   171  		values[i] = it.Value()
   172  	}
   173  	return values
   174  }
   175  
   176  // Left returns the left-most (min) node or nil if tree is empty.
   177  func (tree *Tree) Left() *Node {
   178  	var parent *Node
   179  	current := tree.Root
   180  	for current != nil {
   181  		parent = current
   182  		current = current.Left
   183  	}
   184  	return parent
   185  }
   186  
   187  // Right returns the right-most (max) node or nil if tree is empty.
   188  func (tree *Tree) Right() *Node {
   189  	var parent *Node
   190  	current := tree.Root
   191  	for current != nil {
   192  		parent = current
   193  		current = current.Right
   194  	}
   195  	return parent
   196  }
   197  
   198  // Floor Finds floor node of the input key, return the floor node or nil if no ceiling is found.
   199  // Second return parameter is true if floor was found, otherwise false.
   200  //
   201  // Floor node is defined as the largest node that is smaller than or equal to the given node.
   202  // A floor node may not be found, either because the tree is empty, or because
   203  // all nodes in the tree is larger than the given node.
   204  //
   205  // Key should adhere to the comparator's type assertion, otherwise method panics.
   206  func (tree *Tree) Floor(key interface{}) (floor *Node, found bool) {
   207  	found = false
   208  	node := tree.Root
   209  	for node != nil {
   210  		compare := tree.Comparator(key, node.Key)
   211  		switch {
   212  		case compare == 0:
   213  			return node, true
   214  		case compare < 0:
   215  			node = node.Left
   216  		case compare > 0:
   217  			floor, found = node, true
   218  			node = node.Right
   219  		}
   220  	}
   221  	if found {
   222  		return floor, true
   223  	}
   224  	return nil, false
   225  }
   226  
   227  // Ceiling finds ceiling node of the input key, return the ceiling node or nil if no ceiling is found.
   228  // Second return parameter is true if ceiling was found, otherwise false.
   229  //
   230  // Ceiling node is defined as the smallest node that is larger than or equal to the given node.
   231  // A ceiling node may not be found, either because the tree is empty, or because
   232  // all nodes in the tree is smaller than the given node.
   233  //
   234  // Key should adhere to the comparator's type assertion, otherwise method panics.
   235  func (tree *Tree) Ceiling(key interface{}) (ceiling *Node, found bool) {
   236  	found = false
   237  	node := tree.Root
   238  	for node != nil {
   239  		compare := tree.Comparator(key, node.Key)
   240  		switch {
   241  		case compare == 0:
   242  			return node, true
   243  		case compare < 0:
   244  			ceiling, found = node, true
   245  			node = node.Left
   246  		case compare > 0:
   247  			node = node.Right
   248  		}
   249  	}
   250  	if found {
   251  		return ceiling, true
   252  	}
   253  	return nil, false
   254  }
   255  
   256  // Clear removes all nodes from the tree.
   257  func (tree *Tree) Clear() {
   258  	tree.Root = nil
   259  	tree.size = 0
   260  }
   261  
   262  // String returns a string representation of container
   263  func (tree *Tree) String() string {
   264  	str := "RedBlackTree\n"
   265  	if !tree.Empty() {
   266  		output(tree.Root, "", true, &str)
   267  	}
   268  	return str
   269  }
   270  
   271  func (node *Node) String() string {
   272  	return fmt.Sprintf("%v", node.Key)
   273  }
   274  
   275  func output(node *Node, prefix string, isTail bool, str *string) {
   276  	if node.Right != nil {
   277  		newPrefix := prefix
   278  		if isTail {
   279  			newPrefix += "│   "
   280  		} else {
   281  			newPrefix += "    "
   282  		}
   283  		output(node.Right, newPrefix, false, str)
   284  	}
   285  	*str += prefix
   286  	if isTail {
   287  		*str += "└── "
   288  	} else {
   289  		*str += "┌── "
   290  	}
   291  	*str += node.String() + "\n"
   292  	if node.Left != nil {
   293  		newPrefix := prefix
   294  		if isTail {
   295  			newPrefix += "    "
   296  		} else {
   297  			newPrefix += "│   "
   298  		}
   299  		output(node.Left, newPrefix, true, str)
   300  	}
   301  }
   302  
   303  func (tree *Tree) lookup(key interface{}) *Node {
   304  	node := tree.Root
   305  	for node != nil {
   306  		compare := tree.Comparator(key, node.Key)
   307  		switch {
   308  		case compare == 0:
   309  			return node
   310  		case compare < 0:
   311  			node = node.Left
   312  		case compare > 0:
   313  			node = node.Right
   314  		}
   315  	}
   316  	return nil
   317  }
   318  
   319  func (node *Node) grandparent() *Node {
   320  	if node != nil && node.Parent != nil {
   321  		return node.Parent.Parent
   322  	}
   323  	return nil
   324  }
   325  
   326  func (node *Node) uncle() *Node {
   327  	if node == nil || node.Parent == nil || node.Parent.Parent == nil {
   328  		return nil
   329  	}
   330  	return node.Parent.sibling()
   331  }
   332  
   333  func (node *Node) sibling() *Node {
   334  	if node == nil || node.Parent == nil {
   335  		return nil
   336  	}
   337  	if node == node.Parent.Left {
   338  		return node.Parent.Right
   339  	}
   340  	return node.Parent.Left
   341  }
   342  
   343  func (tree *Tree) rotateLeft(node *Node) {
   344  	right := node.Right
   345  	tree.replaceNode(node, right)
   346  	node.Right = right.Left
   347  	if right.Left != nil {
   348  		right.Left.Parent = node
   349  	}
   350  	right.Left = node
   351  	node.Parent = right
   352  }
   353  
   354  func (tree *Tree) rotateRight(node *Node) {
   355  	left := node.Left
   356  	tree.replaceNode(node, left)
   357  	node.Left = left.Right
   358  	if left.Right != nil {
   359  		left.Right.Parent = node
   360  	}
   361  	left.Right = node
   362  	node.Parent = left
   363  }
   364  
   365  func (tree *Tree) replaceNode(old *Node, new *Node) {
   366  	if old.Parent == nil {
   367  		tree.Root = new
   368  	} else {
   369  		if old == old.Parent.Left {
   370  			old.Parent.Left = new
   371  		} else {
   372  			old.Parent.Right = new
   373  		}
   374  	}
   375  	if new != nil {
   376  		new.Parent = old.Parent
   377  	}
   378  }
   379  
   380  func (tree *Tree) insertCase1(node *Node) {
   381  	if node.Parent == nil {
   382  		node.color = black
   383  	} else {
   384  		tree.insertCase2(node)
   385  	}
   386  }
   387  
   388  func (tree *Tree) insertCase2(node *Node) {
   389  	if nodeColor(node.Parent) == black {
   390  		return
   391  	}
   392  	tree.insertCase3(node)
   393  }
   394  
   395  func (tree *Tree) insertCase3(node *Node) {
   396  	uncle := node.uncle()
   397  	if nodeColor(uncle) == red {
   398  		node.Parent.color = black
   399  		uncle.color = black
   400  		node.grandparent().color = red
   401  		tree.insertCase1(node.grandparent())
   402  	} else {
   403  		tree.insertCase4(node)
   404  	}
   405  }
   406  
   407  func (tree *Tree) insertCase4(node *Node) {
   408  	grandparent := node.grandparent()
   409  	if node == node.Parent.Right && node.Parent == grandparent.Left {
   410  		tree.rotateLeft(node.Parent)
   411  		node = node.Left
   412  	} else if node == node.Parent.Left && node.Parent == grandparent.Right {
   413  		tree.rotateRight(node.Parent)
   414  		node = node.Right
   415  	}
   416  	tree.insertCase5(node)
   417  }
   418  
   419  func (tree *Tree) insertCase5(node *Node) {
   420  	node.Parent.color = black
   421  	grandparent := node.grandparent()
   422  	grandparent.color = red
   423  	if node == node.Parent.Left && node.Parent == grandparent.Left {
   424  		tree.rotateRight(grandparent)
   425  	} else if node == node.Parent.Right && node.Parent == grandparent.Right {
   426  		tree.rotateLeft(grandparent)
   427  	}
   428  }
   429  
   430  func (node *Node) maximumNode() *Node {
   431  	if node == nil {
   432  		return nil
   433  	}
   434  	for node.Right != nil {
   435  		node = node.Right
   436  	}
   437  	return node
   438  }
   439  
   440  func (tree *Tree) deleteCase1(node *Node) {
   441  	if node.Parent == nil {
   442  		return
   443  	}
   444  	tree.deleteCase2(node)
   445  }
   446  
   447  func (tree *Tree) deleteCase2(node *Node) {
   448  	sibling := node.sibling()
   449  	if nodeColor(sibling) == red {
   450  		node.Parent.color = red
   451  		sibling.color = black
   452  		if node == node.Parent.Left {
   453  			tree.rotateLeft(node.Parent)
   454  		} else {
   455  			tree.rotateRight(node.Parent)
   456  		}
   457  	}
   458  	tree.deleteCase3(node)
   459  }
   460  
   461  func (tree *Tree) deleteCase3(node *Node) {
   462  	sibling := node.sibling()
   463  	if nodeColor(node.Parent) == black &&
   464  		nodeColor(sibling) == black &&
   465  		nodeColor(sibling.Left) == black &&
   466  		nodeColor(sibling.Right) == black {
   467  		sibling.color = red
   468  		tree.deleteCase1(node.Parent)
   469  	} else {
   470  		tree.deleteCase4(node)
   471  	}
   472  }
   473  
   474  func (tree *Tree) deleteCase4(node *Node) {
   475  	sibling := node.sibling()
   476  	if nodeColor(node.Parent) == red &&
   477  		nodeColor(sibling) == black &&
   478  		nodeColor(sibling.Left) == black &&
   479  		nodeColor(sibling.Right) == black {
   480  		sibling.color = red
   481  		node.Parent.color = black
   482  	} else {
   483  		tree.deleteCase5(node)
   484  	}
   485  }
   486  
   487  func (tree *Tree) deleteCase5(node *Node) {
   488  	sibling := node.sibling()
   489  	if node == node.Parent.Left &&
   490  		nodeColor(sibling) == black &&
   491  		nodeColor(sibling.Left) == red &&
   492  		nodeColor(sibling.Right) == black {
   493  		sibling.color = red
   494  		sibling.Left.color = black
   495  		tree.rotateRight(sibling)
   496  	} else if node == node.Parent.Right &&
   497  		nodeColor(sibling) == black &&
   498  		nodeColor(sibling.Right) == red &&
   499  		nodeColor(sibling.Left) == black {
   500  		sibling.color = red
   501  		sibling.Right.color = black
   502  		tree.rotateLeft(sibling)
   503  	}
   504  	tree.deleteCase6(node)
   505  }
   506  
   507  func (tree *Tree) deleteCase6(node *Node) {
   508  	sibling := node.sibling()
   509  	sibling.color = nodeColor(node.Parent)
   510  	node.Parent.color = black
   511  	if node == node.Parent.Left && nodeColor(sibling.Right) == red {
   512  		sibling.Right.color = black
   513  		tree.rotateLeft(node.Parent)
   514  	} else if nodeColor(sibling.Left) == red {
   515  		sibling.Left.color = black
   516  		tree.rotateRight(node.Parent)
   517  	}
   518  }
   519  
   520  func nodeColor(node *Node) color {
   521  	if node == nil {
   522  		return black
   523  	}
   524  	return node.color
   525  }