github.com/chain5j/chain5j-pkg@v1.0.7/collection/maps/treemap/v2/treemap.go (about)

     1  // Package treemap provides a generic key-sorted map.
     2  // It uses red-black tree under the hood.
     3  // Iterators are designed after C++.
     4  //
     5  // Example:
     6  //
     7  //     package main
     8  //
     9  //     import (
    10  //         "fmt"
    11  //
    12  //         "github.com/igrmk/treemap/v2"
    13  //     )
    14  //
    15  //     func main() {
    16  //         tree := treemap.New[int, string]()
    17  //         tree.Set(1, "World")
    18  //         tree.Set(0, "Hello")
    19  //         for it := tree.Iterator(); it.Valid(); it.Next() {
    20  //             fmt.Println(it.Key(), it.Value())
    21  //         }
    22  //     }
    23  //
    24  //     // Output:
    25  //     // 0 Hello
    26  //     // 1 World
    27  package treemap
    28  
    29  import (
    30  	"sync"
    31  
    32  	"golang.org/x/exp/constraints"
    33  )
    34  
    35  // TreeMap is the generic red-black tree based map
    36  type TreeMap[Key, Value any] struct {
    37  	mu         sync.Mutex
    38  	endNode    *node[Key, Value]
    39  	beginNode  *node[Key, Value]
    40  	count      int
    41  	keyCompare func(a Key, b Key) bool
    42  }
    43  
    44  type node[Key, Value any] struct {
    45  	right   *node[Key, Value]
    46  	left    *node[Key, Value]
    47  	parent  *node[Key, Value]
    48  	isBlack bool
    49  	key     Key
    50  	value   Value
    51  }
    52  
    53  // New creates and returns new TreeMap.
    54  func New[Key constraints.Ordered, Value any]() *TreeMap[Key, Value] {
    55  	endNode := &node[Key, Value]{isBlack: true}
    56  	return &TreeMap[Key, Value]{beginNode: endNode, endNode: endNode, keyCompare: defaultKeyCompare[Key]}
    57  }
    58  
    59  // NewWithKeyCompare creates and returns new TreeMap with the specified key compare function.
    60  // Parameter keyCompare is a function returning a < b.
    61  func NewWithKeyCompare[Key, Value any](
    62  	keyCompare func(a, b Key) bool,
    63  ) *TreeMap[Key, Value] {
    64  	endNode := &node[Key, Value]{isBlack: true}
    65  	return &TreeMap[Key, Value]{beginNode: endNode, endNode: endNode, keyCompare: keyCompare}
    66  }
    67  
    68  // Len returns total count of elements in a map.
    69  // Complexity: O(1).
    70  func (t *TreeMap[Key, Value]) Len() int { return t.count }
    71  
    72  // Set sets the value and silently overrides previous value if it exists.
    73  // Complexity: O(log N).
    74  func (t *TreeMap[Key, Value]) Set(key Key, value Value) {
    75  	t.mu.Lock()
    76  	defer t.mu.Unlock()
    77  	parent := t.endNode
    78  	current := parent.left
    79  	less := true
    80  	for current != nil {
    81  		parent = current
    82  		switch {
    83  		case t.keyCompare(key, current.key):
    84  			current = current.left
    85  			less = true
    86  		case t.keyCompare(current.key, key):
    87  			current = current.right
    88  			less = false
    89  		default:
    90  			current.value = value
    91  			return
    92  		}
    93  	}
    94  	x := &node[Key, Value]{parent: parent, value: value, key: key}
    95  	if less {
    96  		parent.left = x
    97  	} else {
    98  		parent.right = x
    99  	}
   100  	if t.beginNode.left != nil {
   101  		t.beginNode = t.beginNode.left
   102  	}
   103  	t.insertFixup(x)
   104  	t.count++
   105  }
   106  
   107  // Del deletes the value.
   108  // Complexity: O(log N).
   109  func (t *TreeMap[Key, Value]) Del(key Key) {
   110  	t.mu.Lock()
   111  	defer t.mu.Unlock()
   112  	z := t.findNode(key)
   113  	if z == nil {
   114  		return
   115  	}
   116  	if t.beginNode == z {
   117  		if z.right != nil {
   118  			t.beginNode = z.right
   119  		} else {
   120  			t.beginNode = z.parent
   121  		}
   122  	}
   123  	t.count--
   124  	removeNode(t.endNode.left, z)
   125  }
   126  
   127  // Clear clears the map.
   128  // Complexity: O(1).
   129  func (t *TreeMap[Key, Value]) Clear() {
   130  	t.mu.Lock()
   131  	defer t.mu.Unlock()
   132  	t.count = 0
   133  	t.beginNode = t.endNode
   134  	t.endNode.left = nil
   135  }
   136  
   137  // Get retrieves a value from a map for specified key and reports if it exists.
   138  // Complexity: O(log N).
   139  func (t *TreeMap[Key, Value]) Get(id Key) (Value, bool) {
   140  	t.mu.Lock()
   141  	node := t.findNode(id)
   142  	t.mu.Unlock()
   143  	if node == nil {
   144  		node = t.endNode
   145  	}
   146  	return node.value, node != t.endNode
   147  }
   148  
   149  // Contains checks if key exists in a map.
   150  // Complexity: O(log N)
   151  func (t *TreeMap[Key, Value]) Contains(id Key) bool {
   152  	t.mu.Lock()
   153  	defer t.mu.Unlock()
   154  	return t.findNode(id) != nil
   155  }
   156  
   157  // Range returns a pair of iterators that you can use to go through all the keys in the range [from, to].
   158  // More specifically it returns iterators pointing to lower bound and upper bound.
   159  // Complexity: O(log N).
   160  func (t *TreeMap[Key, Value]) Range(from, to Key) (ForwardIterator[Key, Value], ForwardIterator[Key, Value]) {
   161  	return t.LowerBound(from), t.UpperBound(to)
   162  }
   163  
   164  // LowerBound returns an iterator pointing to the first element that is not less than the given key.
   165  // Complexity: O(log N).
   166  func (t *TreeMap[Key, Value]) LowerBound(key Key) ForwardIterator[Key, Value] {
   167  	t.mu.Lock()
   168  	defer t.mu.Unlock()
   169  	result := t.endNode
   170  	node := t.endNode.left
   171  	if node == nil {
   172  		return ForwardIterator[Key, Value]{tree: t, node: t.endNode}
   173  	}
   174  	for {
   175  		if t.keyCompare(node.key, key) {
   176  			if node.right != nil {
   177  				node = node.right
   178  			} else {
   179  				return ForwardIterator[Key, Value]{tree: t, node: result}
   180  			}
   181  		} else {
   182  			result = node
   183  			if node.left != nil {
   184  				node = node.left
   185  			} else {
   186  				return ForwardIterator[Key, Value]{tree: t, node: result}
   187  			}
   188  		}
   189  	}
   190  }
   191  
   192  // UpperBound returns an iterator pointing to the first element that is greater than the given key.
   193  // Complexity: O(log N).
   194  func (t *TreeMap[Key, Value]) UpperBound(key Key) ForwardIterator[Key, Value] {
   195  	t.mu.Lock()
   196  	defer t.mu.Unlock()
   197  	result := t.endNode
   198  	node := t.endNode.left
   199  	if node == nil {
   200  		return ForwardIterator[Key, Value]{tree: t, node: t.endNode}
   201  	}
   202  	for {
   203  		if !t.keyCompare(key, node.key) {
   204  			if node.right != nil {
   205  				node = node.right
   206  			} else {
   207  				return ForwardIterator[Key, Value]{tree: t, node: result}
   208  			}
   209  		} else {
   210  			result = node
   211  			if node.left != nil {
   212  				node = node.left
   213  			} else {
   214  				return ForwardIterator[Key, Value]{tree: t, node: result}
   215  			}
   216  		}
   217  	}
   218  }
   219  
   220  // Iterator returns an iterator for tree map.
   221  // It starts at the first element and goes to the one-past-the-end position.
   222  // You can iterate a map at O(N) complexity.
   223  // Method complexity: O(1)
   224  func (t *TreeMap[Key, Value]) Iterator() ForwardIterator[Key, Value] {
   225  	return ForwardIterator[Key, Value]{tree: t, node: t.beginNode}
   226  }
   227  
   228  // Reverse returns a reverse iterator for tree map.
   229  // It starts at the last element and goes to the one-before-the-start position.
   230  // You can iterate a map at O(N) complexity.
   231  // Method complexity: O(log N)
   232  func (t *TreeMap[Key, Value]) Reverse() ReverseIterator[Key, Value] {
   233  	t.mu.Lock()
   234  	defer t.mu.Unlock()
   235  	node := t.endNode.left
   236  	if node != nil {
   237  		node = mostRight(node)
   238  	}
   239  	return ReverseIterator[Key, Value]{tree: t, node: node}
   240  }
   241  
   242  func defaultKeyCompare[Key constraints.Ordered](
   243  	a, b Key,
   244  ) bool {
   245  	return a < b
   246  }
   247  
   248  func (t *TreeMap[Key, Value]) findNode(id Key) *node[Key, Value] {
   249  	current := t.endNode.left
   250  	for current != nil {
   251  		switch {
   252  		case t.keyCompare(id, current.key):
   253  			current = current.left
   254  		case t.keyCompare(current.key, id):
   255  			current = current.right
   256  		default:
   257  			return current
   258  		}
   259  	}
   260  	return nil
   261  }
   262  func (t *TreeMap[Key, Value]) KeyCompare() func(a Key, b Key) bool {
   263  	return t.keyCompare
   264  }
   265  
   266  func mostLeft[Key, Value any](
   267  	x *node[Key, Value],
   268  ) *node[Key, Value] {
   269  	for x.left != nil {
   270  		x = x.left
   271  	}
   272  	return x
   273  }
   274  
   275  func mostRight[Key, Value any](
   276  	x *node[Key, Value],
   277  ) *node[Key, Value] {
   278  	for x.right != nil {
   279  		x = x.right
   280  	}
   281  	return x
   282  }
   283  
   284  func successor[Key, Value any](
   285  	x *node[Key, Value],
   286  ) *node[Key, Value] {
   287  	if x.right != nil {
   288  		return mostLeft(x.right)
   289  	}
   290  	for x != x.parent.left {
   291  		x = x.parent
   292  	}
   293  	return x.parent
   294  }
   295  
   296  func predecessor[Key, Value any](
   297  	x *node[Key, Value],
   298  ) *node[Key, Value] {
   299  	if x.left != nil {
   300  		return mostRight(x.left)
   301  	}
   302  	for x.parent != nil && x != x.parent.right {
   303  		x = x.parent
   304  	}
   305  	return x.parent
   306  }
   307  
   308  func rotateLeft[Key, Value any](
   309  	x *node[Key, Value],
   310  ) {
   311  	y := x.right
   312  	x.right = y.left
   313  	if x.right != nil {
   314  		x.right.parent = x
   315  	}
   316  	y.parent = x.parent
   317  	if x == x.parent.left {
   318  		x.parent.left = y
   319  	} else {
   320  		x.parent.right = y
   321  	}
   322  	y.left = x
   323  	x.parent = y
   324  }
   325  
   326  func rotateRight[Key, Value any](
   327  	x *node[Key, Value],
   328  ) {
   329  	y := x.left
   330  	x.left = y.right
   331  	if x.left != nil {
   332  		x.left.parent = x
   333  	}
   334  	y.parent = x.parent
   335  	if x == x.parent.left {
   336  		x.parent.left = y
   337  	} else {
   338  		x.parent.right = y
   339  	}
   340  	y.right = x
   341  	x.parent = y
   342  }
   343  
// insertFixup restores the red-black properties after Set has linked the new
// node x into the tree as a leaf. It only recolors nodes and rotates around
// them; beginNode and count are left to the caller.
// NOTE(review): the case structure matches libc++'s __tree_balance_after_insert.
func (t *TreeMap[Key, Value]) insertFixup(x *node[Key, Value]) {
	// The real root hangs off the sentinel's left link. It is cached here; the
	// only rotations happen immediately before a break, so the cached value is
	// still accurate for every comparison made inside the loop.
	root := t.endNode.left
	// A newly inserted node is red, except that the root is always black.
	x.isBlack = x == root
	// Repair while x and its parent are both red. A red parent is never the
	// root, so x.parent.parent is a real node.
	for x != root && !x.parent.isBlack {
		if x.parent == x.parent.parent.left {
			// Parent is a left child; y is the uncle on the right.
			y := x.parent.parent.right
			if y != nil && !y.isBlack {
				// Red uncle: make parent and uncle black, make the grandparent
				// red (black if it is the root), then continue from the
				// grandparent.
				x = x.parent
				x.isBlack = true
				x = x.parent
				x.isBlack = x == root
				y.isBlack = true
			} else {
				// Black or missing uncle. If x is the inner (right) child,
				// rotate it into the outer position first.
				if x != x.parent.left {
					x = x.parent
					rotateLeft(x)
				}
				// Parent becomes black, grandparent red, and a rotation at the
				// grandparent finishes the repair.
				x = x.parent
				x.isBlack = true
				x = x.parent
				x.isBlack = false
				rotateRight(x)
				break
			}
		} else {
			// Mirror image: parent is a right child; y is the uncle on the left.
			y := x.parent.parent.left
			if y != nil && !y.isBlack {
				x = x.parent
				x.isBlack = true
				x = x.parent
				x.isBlack = x == root
				y.isBlack = true
			} else {
				if x == x.parent.left {
					x = x.parent
					rotateRight(x)
				}
				x = x.parent
				x.isBlack = true
				x = x.parent
				x.isBlack = false
				rotateLeft(x)
				break
			}
		}
	}
}
   391  
// removeNode unlinks z from the red-black tree whose root is root and
// rebalances it. root must be the real root (the sentinel's left child), so
// every tree node, including root, has a non-nil parent. Only links and colors
// are changed; the caller handles beginNode and count. z's own left, right and
// parent fields are not modified.
// NOTE(review): the structure matches libc++'s __tree_remove.
//
//nolint:gocyclo
//noinspection GoNilness
func removeNode[Key, Value any](
	root, z *node[Key, Value],
) {
	// y is the node physically unlinked from its position: z itself when z has
	// at most one child, otherwise z's in-order successor (which then has no
	// left child) and is later moved into z's place.
	var y *node[Key, Value]
	if z.left == nil || z.right == nil {
		y = z
	} else {
		y = successor(z)
	}
	// x is y's only child (possibly nil); it takes y's position.
	var x *node[Key, Value]
	if y.left != nil {
		x = y.left
	} else {
		x = y.right
	}
	// w becomes x's sibling in its new position; the rebalancing loop needs it.
	var w *node[Key, Value]
	if x != nil {
		x.parent = y.parent
	}
	if y == y.parent.left {
		y.parent.left = x
		if y != root {
			w = y.parent.right
		} else {
			root = x // w == nil
		}
	} else {
		y.parent.right = x
		w = y.parent.left
	}
	// Unlinking a red node cannot break the red-black properties; only a
	// removed black node needs repair.
	removedBlack := y.isBlack
	if y != z {
		// Move y into z's position, taking over z's links and color. z had two
		// children in this case, so z.left is non-nil.
		y.parent = z.parent
		if z == z.parent.left {
			y.parent.left = y
		} else {
			y.parent.right = y
		}
		y.left = z.left
		y.left.parent = y
		y.right = z.right
		if y.right != nil {
			y.right.parent = y
		}
		y.isBlack = z.isBlack
		if root == z {
			root = y
		}
	}
	// Nothing to repair when a red node went away or the tree is now empty.
	if removedBlack && root != nil {
		if x != nil {
			// Under the red-black invariants a lone child of a removed black
			// node is red; painting it black restores the black height.
			x.isBlack = true
		} else {
			// The path through the now-empty slot is one black node short.
			// Walk upwards using the sibling w until the deficit is absorbed.
			for {
				if w != w.parent.left {
					// w is a right child, so the short side is on the left.
					if !w.isBlack {
						// Red sibling: rotate it above the parent so that the
						// short side gets a black sibling.
						w.isBlack = true
						w.parent.isBlack = false
						rotateLeft(w.parent)
						// Keep the local root in sync if w was lifted above it.
						if root == w.left {
							root = w
						}
						// The old parent is now w.left; its new right child is
						// the new sibling.
						w = w.left.right
					}
					if (w.left == nil || w.left.isBlack) && (w.right == nil || w.right.isBlack) {
						// Both of w's children are black: make w red and push
						// the deficit up to the parent.
						w.isBlack = false
						x = w.parent
						if x == root || !x.isBlack {
							// A red parent (or the root) absorbs the deficit by
							// turning black.
							x.isBlack = true
							break
						}
						if x == x.parent.left {
							w = x.parent.right
						} else {
							w = x.parent.left
						}
					} else {
						if w.right == nil || w.right.isBlack {
							// Far child black, near child red: rotate w so the
							// red child ends up on the far side.
							w.left.isBlack = true
							w.isBlack = false
							rotateRight(w)
							w = w.parent
						}
						// Far child red: recolor and rotate at the parent, which
						// completes the repair.
						w.isBlack = w.parent.isBlack
						w.parent.isBlack = true
						w.right.isBlack = true
						rotateLeft(w.parent)
						break
					}
				} else {
					// Mirror image: w is a left child, the short side is on the
					// right.
					if !w.isBlack {
						w.isBlack = true
						w.parent.isBlack = false
						rotateRight(w.parent)
						if root == w.right {
							root = w
						}
						w = w.right.left
					}
					if (w.left == nil || w.left.isBlack) && (w.right == nil || w.right.isBlack) {
						w.isBlack = false
						x = w.parent
						if !x.isBlack || x == root {
							x.isBlack = true
							break
						}
						if x == x.parent.left {
							w = x.parent.right
						} else {
							w = x.parent.left
						}
					} else {
						if w.left == nil || w.left.isBlack {
							w.right.isBlack = true
							w.isBlack = false
							rotateLeft(w)
							w = w.parent
						}
						w.isBlack = w.parent.isBlack
						w.parent.isBlack = true
						w.left.isBlack = true
						rotateRight(w.parent)
						break
					}
				}
			}
		}
	}
}
   523  
   524  // ForwardIterator represents a position in a tree map.
   525  // It is designed to iterate a map in a forward order.
   526  // It can point to any position from the first element to the one-past-the-end element.
   527  type ForwardIterator[Key, Value any] struct {
   528  	tree *TreeMap[Key, Value]
   529  	node *node[Key, Value]
   530  }
   531  
   532  // Valid reports if the iterator position is valid.
   533  // In other words it returns true if an iterator is not at the one-past-the-end position.
   534  func (i ForwardIterator[Key, Value]) Valid() bool { return i.node != i.tree.endNode }
   535  
   536  // Next moves an iterator to the next element.
   537  // It panics if it goes out of bounds.
   538  func (i *ForwardIterator[Key, Value]) Next() {
   539  	if i.node == i.tree.endNode {
   540  		panic("out of bound iteration")
   541  	}
   542  	i.node = successor(i.node)
   543  }
   544  
   545  // Prev moves an iterator to the previous element.
   546  // It panics if it goes out of bounds.
   547  func (i *ForwardIterator[Key, Value]) Prev() {
   548  	i.node = predecessor(i.node)
   549  	if i.node == nil {
   550  		panic("out of bound iteration")
   551  	}
   552  }
   553  
   554  // Key returns a key at the iterator position
   555  func (i ForwardIterator[Key, Value]) Key() Key { return i.node.key }
   556  
   557  // Value returns a value at the iterator position
   558  func (i ForwardIterator[Key, Value]) Value() Value { return i.node.value }
   559  
   560  // ReverseIterator represents a position in a tree map.
   561  // It is designed to iterate a map in a reverse order.
   562  // It can point to any position from the one-before-the-start element to the last element.
   563  type ReverseIterator[Key, Value any] struct {
   564  	tree *TreeMap[Key, Value]
   565  	node *node[Key, Value]
   566  }
   567  
   568  // Valid reports if the iterator position is valid.
   569  // In other words it returns true if an iterator is not at the one-before-the-start position.
   570  func (i ReverseIterator[Key, Value]) Valid() bool { return i.node != nil }
   571  
   572  // Next moves an iterator to the next element in reverse order.
   573  // It panics if it goes out of bounds.
   574  func (i *ReverseIterator[Key, Value]) Next() {
   575  	if i.node == nil {
   576  		panic("out of bound iteration")
   577  	}
   578  	i.node = predecessor(i.node)
   579  }
   580  
   581  // Prev moves an iterator to the previous element in reverse order.
   582  // It panics if it goes out of bounds.
   583  func (i *ReverseIterator[Key, Value]) Prev() {
   584  	if i.node != nil {
   585  		i.node = successor(i.node)
   586  	} else {
   587  		i.node = i.tree.beginNode
   588  	}
   589  	if i.node == i.tree.endNode {
   590  		panic("out of bound iteration")
   591  	}
   592  }
   593  
   594  // Key returns a key at the iterator position
   595  func (i ReverseIterator[Key, Value]) Key() Key { return i.node.key }
   596  
   597  // Value returns a value at the iterator position
   598  func (i ReverseIterator[Key, Value]) Value() Value { return i.node.value }