github.com/15mga/kiwi@v0.0.2-0.20240324021231-b95d5c3ac751/ds/rb_tree.go (about)

     1  package ds
     2  
     3  import (
     4  	"github.com/15mga/kiwi/util"
     5  )
     6  
     7  const (
     8  	red   = true
     9  	black = false
    10  )
    11  
    12  type RBTreeNode[KT comparable, VT any] struct {
    13  	Key    KT
    14  	Value  VT
    15  	color  bool
    16  	left   *RBTreeNode[KT, VT]
    17  	right  *RBTreeNode[KT, VT]
    18  	parent *RBTreeNode[KT, VT]
    19  }
    20  
    21  func (n *RBTreeNode[KT, VT]) isRed() bool {
    22  	if n == nil {
    23  		return false
    24  	}
    25  	return n.color
    26  }
    27  
    28  func (n *RBTreeNode[KT, VT]) grandparent() *RBTreeNode[KT, VT] {
    29  	if n != nil && n.parent != nil {
    30  		return n.parent.parent
    31  	}
    32  	return nil
    33  }
    34  
    35  func (n *RBTreeNode[KT, VT]) uncle() *RBTreeNode[KT, VT] {
    36  	if n == nil || n.parent == nil || n.parent.parent == nil {
    37  		return nil
    38  	}
    39  	return n.parent.sibling()
    40  }
    41  
    42  func (n *RBTreeNode[KT, VT]) maximumNode() *RBTreeNode[KT, VT] {
    43  	if n == nil {
    44  		return nil
    45  	}
    46  	for n.right != nil {
    47  		return n.right
    48  	}
    49  	return n
    50  }
    51  
    52  func (n *RBTreeNode[KT, VT]) sibling() *RBTreeNode[KT, VT] {
    53  	if n == nil || n.parent == nil {
    54  		return nil
    55  	}
    56  	if n == n.parent.left {
    57  		return n.parent.right
    58  	}
    59  	return n.parent.left
    60  }
    61  
    62  type RBTree[KT comparable, VT any] struct {
    63  	root    *RBTreeNode[KT, VT]
    64  	size    int
    65  	compare util.Compare[KT]
    66  }
    67  
    68  func NewRBTree[KT comparable, VT any](compare util.Compare[KT]) *RBTree[KT, VT] {
    69  	return &RBTree[KT, VT]{
    70  		compare: compare,
    71  	}
    72  }
    73  
    74  func NewRBTreeFromM[KT comparable, VT any](m map[KT]VT, compare util.Compare[KT]) *RBTree[KT, VT] {
    75  	tree := NewRBTree[KT, VT](compare)
    76  	for k, v := range m {
    77  		tree.Set(k, v)
    78  	}
    79  	return tree
    80  }
    81  
    82  func (t *RBTree[KT, VT]) Set(k KT, v VT) {
    83  	var nn *RBTreeNode[KT, VT]
    84  	if t.root == nil {
    85  		nn = &RBTreeNode[KT, VT]{
    86  			Key:   k,
    87  			Value: v,
    88  			color: red,
    89  		}
    90  		t.root = nn
    91  	} else {
    92  		ok := true
    93  		n := t.root
    94  		for ok {
    95  			cpr := t.compare(k, n.Key)
    96  			switch {
    97  			case cpr == 0:
    98  				n.Value = v
    99  				return
   100  			case cpr < 0:
   101  				if n.left == nil {
   102  					nn = &RBTreeNode[KT, VT]{
   103  						Key:   k,
   104  						Value: v,
   105  						color: true,
   106  					}
   107  					n.left = nn
   108  					ok = false
   109  				} else {
   110  					n = n.left
   111  				}
   112  			case cpr > 0:
   113  				if n.right == nil {
   114  					nn = &RBTreeNode[KT, VT]{
   115  						Key:   k,
   116  						Value: v,
   117  						color: red,
   118  					}
   119  					n.right = nn
   120  					ok = false
   121  				} else {
   122  					n = n.right
   123  				}
   124  			}
   125  		}
   126  		nn.parent = n
   127  	}
   128  	t.insert1(nn)
   129  	t.size++
   130  }
   131  
   132  func (t *RBTree[KT, VT]) Reset() {
   133  	t.root = nil
   134  	t.size = 0
   135  }
   136  
   137  func (t *RBTree[KT, VT]) Update(m map[KT]VT) {
   138  	for k, v := range m {
   139  		t.Set(k, v)
   140  	}
   141  }
   142  
   143  func (t *RBTree[KT, VT]) Get(k KT) (VT, bool) {
   144  	node, ok := t.getNode(k)
   145  	if ok {
   146  		return node.Value, true
   147  	}
   148  	return util.Default[VT](), false
   149  }
   150  
// Del removes key from the tree and returns the value that was stored under
// it. When key is not present the tree is left unchanged and the zero value
// of VT is returned; callers cannot distinguish that from a stored zero value.
func (t *RBTree[KT, VT]) Del(key KT) (value VT) {
	child := (*RBTreeNode[KT, VT])(nil)
	node, ok := t.getNode(key)
	if !ok {
		return
	}
	value = node.Value
	// A node with two children is not unlinked directly. The key/value of its
	// in-order predecessor (the maximum of the left subtree) is copied into
	// it, and the predecessor node is removed instead.
	if node.left != nil && node.right != nil {
		p := node.left.maximumNode()
		node.Key = p.Key
		node.Value = p.Value
		node = p
	}
	// NOTE(review): this branch relies on maximumNode returning the right-most
	// node of the left subtree (which has no right child). If it did not, this
	// branch would be skipped and size would still be decremented below.
	if node.left == nil || node.right == nil {
		// node has at most one child; that child (possibly nil) replaces it.
		if node.right == nil {
			child = node.left
		} else {
			child = node.right
		}
		// Removing a black node needs rebalancing. The fixup runs while node is
		// still linked into the tree; node.color is overwritten with the child's
		// color first (node itself is unlinked right after).
		if !node.color {
			node.color = child.isRed()
			t.deleteCase1(node)
		}
		t.updateNode(node, child)
		// node was the root: its replacement is the new root and is painted black.
		if node.parent == nil && child != nil {
			child.color = black
		}
	}
	t.size--
	return
}
   182  
   183  func (t *RBTree[KT, VT]) AnyAsc(fn func(KT, VT) bool) bool {
   184  	return t.anyAsc(t.leftNode(), fn)
   185  }
   186  
   187  func (t *RBTree[KT, VT]) AnyAscFrom(k KT, fn func(KT, VT) bool) (match bool, ok bool) {
   188  	node, ok := t.getNode(k)
   189  	if !ok {
   190  		return
   191  	}
   192  	match = true
   193  	ok = t.anyAsc(node, fn)
   194  	return
   195  }
   196  
   197  func (t *RBTree[KT, VT]) AnyDesc(fn func(KT, VT) bool) bool {
   198  	return t.anyAsc(t.rightNode(), fn)
   199  }
   200  
   201  func (t *RBTree[KT, VT]) AnyDescFrom(k KT, fn func(KT, VT) bool) (match bool, ok bool) {
   202  	node, ok := t.getNode(k)
   203  	if !ok {
   204  		return
   205  	}
   206  	match = true
   207  	ok = t.anyDesc(node, fn)
   208  	return
   209  }
   210  
   211  func (t *RBTree[KT, VT]) M() map[KT]VT {
   212  	m := make(map[KT]VT, t.size)
   213  	_ = t.AnyAsc(func(k KT, v VT) bool {
   214  		m[k] = v
   215  		return false
   216  	})
   217  	return m
   218  }
   219  
   220  func (t *RBTree[KT, VT]) getNode(k KT) (node *RBTreeNode[KT, VT], ok bool) {
   221  	node = t.root
   222  	for node != nil {
   223  		cpr := t.compare(k, node.Key)
   224  		switch {
   225  		case cpr == 0:
   226  			return node, true
   227  		case cpr < 0:
   228  			node = node.left
   229  		case cpr > 0:
   230  			node = node.right
   231  		}
   232  	}
   233  	return node, false
   234  }
   235  
   236  func (t *RBTree[KT, VT]) leftNode() *RBTreeNode[KT, VT] {
   237  	p := (*RBTreeNode[KT, VT])(nil)
   238  	n := t.root
   239  	for n != nil {
   240  		p = n
   241  		n = n.left
   242  	}
   243  	return p
   244  }
   245  
   246  // rightNode returns the right-most (max) node or nil if tree is empty.
   247  func (t *RBTree[KT, VT]) rightNode() *RBTreeNode[KT, VT] {
   248  	p := (*RBTreeNode[KT, VT])(nil)
   249  	n := t.root
   250  	for n != nil {
   251  		p = n
   252  		n = n.right
   253  	}
   254  	return p
   255  }
   256  
   257  func (t *RBTree[KT, VT]) insert1(node *RBTreeNode[KT, VT]) {
   258  	if node.parent == nil {
   259  		node.color = black
   260  	} else {
   261  		t.insert2(node)
   262  	}
   263  }
   264  
   265  func (t *RBTree[KT, VT]) insert2(node *RBTreeNode[KT, VT]) {
   266  	if !node.parent.isRed() {
   267  		return
   268  	}
   269  	t.insert3(node)
   270  }
   271  
   272  func (t *RBTree[KT, VT]) insert3(node *RBTreeNode[KT, VT]) {
   273  	uncle := node.uncle()
   274  	if uncle.isRed() {
   275  		node.parent.color = black
   276  		uncle.color = black
   277  		node.grandparent().color = red
   278  		t.insert1(node.grandparent())
   279  	} else {
   280  		t.insert4(node)
   281  	}
   282  }
   283  
// insert4 is reached from insert3 when node's parent is red and its uncle is
// black or nil.
// If node is an "inner" grandchild (right child of a left child, or left
// child of a right child), the parent is rotated so that node moves up and
// the former parent becomes an outer grandchild. `node` is then re-pointed
// at that former parent, and insert5 finishes with a rotation at the
// grandparent.
func (t *RBTree[KT, VT]) insert4(node *RBTreeNode[KT, VT]) {
	grandparent := node.grandparent()
	if node == node.parent.right && node.parent == grandparent.left {
		t.rotateLeft(node.parent)
		node = node.left // the former parent, now node's left child
	} else if node == node.parent.left && node.parent == grandparent.right {
		t.rotateRight(node.parent)
		node = node.right // the former parent, now node's right child
	}
	t.insert5(node)
}
   295  
// insert5 paints node's parent black and its grandparent red, then rotates
// the grandparent away from node's side. That is a right rotation when node
// and its parent are both left children, and a left rotation when both are
// right children.
func (t *RBTree[KT, VT]) insert5(node *RBTreeNode[KT, VT]) {
	node.parent.color = black
	grandparent := node.grandparent()
	grandparent.color = red
	if node == node.parent.left && node.parent == grandparent.left {
		t.rotateRight(grandparent)
	} else if node == node.parent.right && node.parent == grandparent.right {
		t.rotateLeft(grandparent)
	}
}
   306  
   307  func (t *RBTree[KT, VT]) deleteCase1(node *RBTreeNode[KT, VT]) {
   308  	if node.parent == nil {
   309  		return
   310  	}
   311  	t.deleteCase2(node)
   312  }
   313  
// deleteCase2: if node's sibling is red, the parent is painted red and the
// sibling black. The parent is then rotated toward node, so the former
// sibling moves above the parent. deleteCase3 always follows.
func (t *RBTree[KT, VT]) deleteCase2(node *RBTreeNode[KT, VT]) {
	sibling := node.sibling()
	if sibling.isRed() {
		node.parent.color = red
		sibling.color = black
		if node == node.parent.left {
			t.rotateLeft(node.parent)
		} else {
			t.rotateRight(node.parent)
		}
	}
	t.deleteCase3(node)
}
   327  
   328  func (t *RBTree[KT, VT]) deleteCase3(node *RBTreeNode[KT, VT]) {
   329  	sibling := node.sibling()
   330  	if node.parent.isRed() ||
   331  		sibling.isRed() ||
   332  		sibling.left.isRed() ||
   333  		sibling.right.isRed() {
   334  		t.deleteCase4(node)
   335  	} else {
   336  		sibling.color = red
   337  		t.deleteCase1(node.parent)
   338  	}
   339  }
   340  
   341  func (t *RBTree[KT, VT]) deleteCase4(node *RBTreeNode[KT, VT]) {
   342  	sibling := node.sibling()
   343  	if !node.parent.isRed() ||
   344  		sibling.isRed() ||
   345  		sibling.left.isRed() ||
   346  		sibling.right.isRed() {
   347  		t.deleteCase5(node)
   348  	} else {
   349  		sibling.color = red
   350  		node.parent.color = black
   351  	}
   352  }
   353  
// deleteCase5: when the sibling is black, its child nearest to node is red
// and its far child is black, the sibling is painted red and the near child
// black. The sibling is then rotated away from node, so the red node ends up
// on the far side of node's new sibling. deleteCase6 always follows.
func (t *RBTree[KT, VT]) deleteCase5(node *RBTreeNode[KT, VT]) {
	sibling := node.sibling()
	if node == node.parent.left &&
		!sibling.isRed() &&
		sibling.left.isRed() &&
		!sibling.right.isRed() {
		sibling.color = red
		sibling.left.color = black
		t.rotateRight(sibling)
	} else if node == node.parent.right &&
		!sibling.isRed() &&
		sibling.right.isRed() &&
		!sibling.left.isRed() {
		sibling.color = red
		sibling.right.color = black
		t.rotateLeft(sibling)
	}
	t.deleteCase6(node)
}
   373  
// deleteCase6 is the final deletion step. The sibling takes the parent's
// color and the parent is painted black.
// If node is a left child and the sibling's right child is red, that child
// is painted black and the parent is rotated left. Otherwise, if the
// sibling's left child is red, it is painted black and the parent is rotated
// right.
// NOTE(review): the second branch does not check which side node is on; it
// appears to rely on deleteCase5 having already moved a red child to the far
// side of the sibling.
func (t *RBTree[KT, VT]) deleteCase6(node *RBTreeNode[KT, VT]) {
	sibling := node.sibling()
	sibling.color = node.parent.isRed()
	node.parent.color = black
	if node == node.parent.left && sibling.right.isRed() {
		sibling.right.color = black
		t.rotateLeft(node.parent)
	} else if sibling.left.isRed() {
		sibling.left.color = black
		t.rotateRight(node.parent)
	}
}
   386  
// rotateLeft performs a left rotation at node. node's right child takes
// node's place under node's former parent (or becomes the root), node becomes
// that child's left child, and the child's former left subtree becomes
// node's right subtree. node.right must be non-nil; it is dereferenced.
func (t *RBTree[KT, VT]) rotateLeft(node *RBTreeNode[KT, VT]) {
	right := node.right
	t.updateNode(node, right)
	node.right = right.left
	if right.left != nil {
		right.left.parent = node
	}
	right.left = node
	node.parent = right
}
   397  
// rotateRight performs a right rotation at node. node's left child takes
// node's place under node's former parent (or becomes the root), node becomes
// that child's right child, and the child's former right subtree becomes
// node's left subtree. node.left must be non-nil; it is dereferenced.
func (t *RBTree[KT, VT]) rotateRight(node *RBTreeNode[KT, VT]) {
	left := node.left
	t.updateNode(node, left)
	node.left = left.right
	if left.right != nil {
		left.right.parent = node
	}
	left.right = node
	node.parent = left
}
   408  
   409  func (t *RBTree[KT, VT]) updateNode(old *RBTreeNode[KT, VT], new *RBTreeNode[KT, VT]) {
   410  	if old.parent == nil {
   411  		t.root = new
   412  	} else {
   413  		if old == old.parent.left {
   414  			old.parent.left = new
   415  		} else {
   416  			old.parent.right = new
   417  		}
   418  	}
   419  	if new != nil {
   420  		new.parent = old.parent
   421  	}
   422  }
   423  
   424  func (t *RBTree[KT, VT]) anyAsc(node *RBTreeNode[KT, VT], f func(KT, VT) bool) bool {
   425  loop:
   426  	if node == nil {
   427  		return false
   428  	}
   429  	if f(node.Key, node.Value) {
   430  		return true
   431  	}
   432  	if node.right != nil {
   433  		node = node.right
   434  		for node.left != nil {
   435  			node = node.left
   436  		}
   437  		goto loop
   438  	}
   439  	if node.parent != nil {
   440  		old := node
   441  		for node.parent != nil {
   442  			node = node.parent
   443  			if t.compare(old.Key, node.Key) <= 0 {
   444  				goto loop
   445  			}
   446  		}
   447  	}
   448  	return false
   449  }
   450  
   451  func (t *RBTree[KT, VT]) anyDesc(node *RBTreeNode[KT, VT], f func(KT, VT) bool) bool {
   452  loop:
   453  	if node == nil {
   454  		return false
   455  	}
   456  	if f(node.Key, node.Value) {
   457  		return true
   458  	}
   459  	if node.left != nil {
   460  		node = node.left
   461  		for node.right != nil {
   462  			node = node.right
   463  		}
   464  		goto loop
   465  	}
   466  	if node.parent != nil {
   467  		old := node
   468  		for node.parent != nil {
   469  			node = node.parent
   470  			if t.compare(old.Key, node.Key) >= 0 {
   471  				goto loop
   472  			}
   473  		}
   474  	}
   475  	return false
   476  }