github.com/songzhibin97/go-baseutils@v0.0.2-0.20240302024150-487d8ce9c082/structure/trees/avltree/avltree.go (about)

     1  // Package avltree implements an AVL balanced binary tree.
     2  //
     3  // Structure is not thread safe.
     4  //
     5  // References: https://en.wikipedia.org/wiki/AVL_tree
     6  package avltree
     7  
     8  import (
     9  	"encoding/json"
    10  	"fmt"
    11  	"strings"
    12  
    13  	"github.com/songzhibin97/go-baseutils/base/bcomparator"
    14  	"github.com/songzhibin97/go-baseutils/structure/trees"
    15  )
    16  
    17  // Assert Tree implementation
    18  var _ trees.Tree[any] = new(Tree[any, any])
    19  
// Tree holds elements of the AVL tree.
//
// Construct it with NewWith, NewWithIntComparator or NewWithStringComparator
// so that Comparator is non-nil; searches and inserts on a non-empty tree
// call it to order keys.
type Tree[K any, V any] struct {
	Root       *Node[K, V]               // Root node; nil after Clear and for a fresh tree
	Comparator bcomparator.Comparator[K] // Key comparator: a negative result sends a search to Children[0], positive to Children[1]
	size       int                       // Total number of keys in the tree, maintained by put/remove/Clear
	zeroV      V                         // Zero value of V, returned by Get when the key is missing
}
    27  
// Node is a single element within the tree.
type Node[K, V any] struct {
	Key      K
	Value    V
	Parent   *Node[K, V]    // Parent node; nil for the root
	Children [2]*Node[K, V] // Children nodes: [0] holds smaller keys, [1] holds larger keys
	b        int8           // Balance factor: -1 when the left subtree is taller, +1 when the right one is, 0 when equal
}
    36  
    37  // NewWith instantiates an AVL tree with the custom comparator.
    38  func NewWith[K, V any](comparator bcomparator.Comparator[K]) *Tree[K, V] {
    39  	return &Tree[K, V]{Comparator: comparator}
    40  }
    41  
    42  // NewWithIntComparator instantiates an AVL tree with the IntComparator, i.e. keys are of type int.
    43  func NewWithIntComparator[V any]() *Tree[int, V] {
    44  	return &Tree[int, V]{Comparator: bcomparator.IntComparator()}
    45  }
    46  
    47  // NewWithStringComparator instantiates an AVL tree with the StringComparator, i.e. keys are of type string.
    48  func NewWithStringComparator[V any]() *Tree[string, V] {
    49  	return &Tree[string, V]{Comparator: bcomparator.StringComparator()}
    50  }
    51  
    52  // Put inserts node into the tree.
    53  // Key should adhere to the comparator's type assertion, otherwise method panics.
    54  func (tree *Tree[K, V]) Put(key K, value V) {
    55  	tree.put(key, value, nil, &tree.Root)
    56  }
    57  
    58  // Get searches the node in the tree by key and returns its value or nil if key is not found in tree.
    59  // Second return parameter is true if key was found, otherwise false.
    60  // Key should adhere to the comparator's type assertion, otherwise method panics.
    61  func (tree *Tree[K, V]) Get(key K) (value V, found bool) {
    62  	n := tree.GetNode(key)
    63  	if n != nil {
    64  		return n.Value, true
    65  	}
    66  	return tree.zeroV, false
    67  }
    68  
    69  // GetNode searches the node in the tree by key and returns its node or nil if key is not found in tree.
    70  // Key should adhere to the comparator's type assertion, otherwise method panics.
    71  func (tree *Tree[K, V]) GetNode(key K) *Node[K, V] {
    72  	n := tree.Root
    73  	for n != nil {
    74  		cmp := tree.Comparator(key, n.Key)
    75  		switch {
    76  		case cmp == 0:
    77  			return n
    78  		case cmp < 0:
    79  			n = n.Children[0]
    80  		case cmp > 0:
    81  			n = n.Children[1]
    82  		}
    83  	}
    84  	return n
    85  }
    86  
    87  // Remove remove the node from the tree by key.
    88  // Key should adhere to the comparator's type assertion, otherwise method panics.
    89  func (tree *Tree[K, V]) Remove(key K) {
    90  	tree.remove(key, &tree.Root)
    91  }
    92  
    93  // Empty returns true if tree does not contain any nodes.
    94  func (tree *Tree[K, V]) Empty() bool {
    95  	return tree.size == 0
    96  }
    97  
    98  // Size returns the number of elements stored in the tree.
    99  func (tree *Tree[K, V]) Size() int {
   100  	return tree.size
   101  }
   102  
   103  // Size returns the number of elements stored in the subtree.
   104  // Computed dynamically on each call, i.e. the subtree is traversed to count the number of the nodes.
   105  func (n *Node[K, V]) Size() int {
   106  	if n == nil {
   107  		return 0
   108  	}
   109  	size := 1
   110  	if n.Children[0] != nil {
   111  		size += n.Children[0].Size()
   112  	}
   113  	if n.Children[1] != nil {
   114  		size += n.Children[1].Size()
   115  	}
   116  	return size
   117  }
   118  
   119  // Keys returns all keys in-order
   120  func (tree *Tree[K, V]) Keys() []K {
   121  	keys := make([]K, tree.size)
   122  	it := tree.Iterator()
   123  	for i := 0; it.Next(); i++ {
   124  		keys[i] = it.Key()
   125  	}
   126  	return keys
   127  }
   128  
   129  // Values returns all values in-order based on the key.
   130  func (tree *Tree[K, V]) Values() []V {
   131  	values := make([]V, tree.size)
   132  	it := tree.Iterator()
   133  	for i := 0; it.Next(); i++ {
   134  		values[i] = it.Value()
   135  	}
   136  	return values
   137  }
   138  
   139  // Left returns the minimum element of the AVL tree
   140  // or nil if the tree is empty.
   141  func (tree *Tree[K, V]) Left() *Node[K, V] {
   142  	return tree.bottom(0)
   143  }
   144  
   145  // Right returns the maximum element of the AVL tree
   146  // or nil if the tree is empty.
   147  func (tree *Tree[K, V]) Right() *Node[K, V] {
   148  	return tree.bottom(1)
   149  }
   150  
   151  // Floor Finds floor node of the input key, return the floor node or nil if no floor is found.
   152  // Second return parameter is true if floor was found, otherwise false.
   153  //
   154  // Floor node is defined as the largest node that is smaller than or equal to the given node.
   155  // A floor node may not be found, either because the tree is empty, or because
   156  // all nodes in the tree is larger than the given node.
   157  //
   158  // Key should adhere to the comparator's type assertion, otherwise method panics.
   159  func (tree *Tree[K, V]) Floor(key K) (floor *Node[K, V], found bool) {
   160  	found = false
   161  	n := tree.Root
   162  	for n != nil {
   163  		c := tree.Comparator(key, n.Key)
   164  		switch {
   165  		case c == 0:
   166  			return n, true
   167  		case c < 0:
   168  			n = n.Children[0]
   169  		case c > 0:
   170  			floor, found = n, true
   171  			n = n.Children[1]
   172  		}
   173  	}
   174  	if found {
   175  		return
   176  	}
   177  	return nil, false
   178  }
   179  
   180  // Ceiling finds ceiling node of the input key, return the ceiling node or nil if no ceiling is found.
   181  // Second return parameter is true if ceiling was found, otherwise false.
   182  //
   183  // Ceiling node is defined as the smallest node that is larger than or equal to the given node.
   184  // A ceiling node may not be found, either because the tree is empty, or because
   185  // all nodes in the tree is smaller than the given node.
   186  //
   187  // Key should adhere to the comparator's type assertion, otherwise method panics.
   188  func (tree *Tree[K, V]) Ceiling(key K) (floor *Node[K, V], found bool) {
   189  	found = false
   190  	n := tree.Root
   191  	for n != nil {
   192  		c := tree.Comparator(key, n.Key)
   193  		switch {
   194  		case c == 0:
   195  			return n, true
   196  		case c < 0:
   197  			floor, found = n, true
   198  			n = n.Children[0]
   199  		case c > 0:
   200  			n = n.Children[1]
   201  		}
   202  	}
   203  	if found {
   204  		return
   205  	}
   206  	return nil, false
   207  }
   208  
   209  // Clear removes all nodes from the tree.
   210  func (tree *Tree[K, V]) Clear() {
   211  	tree.Root = nil
   212  	tree.size = 0
   213  }
   214  
   215  // String returns a string representation of container
   216  func (tree *Tree[K, V]) String() string {
   217  	b := strings.Builder{}
   218  	b.WriteString("AVLTree\n")
   219  	if !tree.Empty() {
   220  		output(tree.Root, "", true, &b)
   221  	}
   222  	return b.String()
   223  }
   224  
   225  func (n *Node[K, V]) String() string {
   226  	return fmt.Sprintf("%v", n.Key)
   227  }
   228  
   229  func (tree *Tree[K, V]) put(key K, value V, p *Node[K, V], qp **Node[K, V]) bool {
   230  	q := *qp
   231  	if q == nil {
   232  		tree.size++
   233  		*qp = &Node[K, V]{Key: key, Value: value, Parent: p}
   234  		return true
   235  	}
   236  
   237  	c := tree.Comparator(key, q.Key)
   238  	if c == 0 {
   239  		q.Key = key
   240  		q.Value = value
   241  		return false
   242  	}
   243  
   244  	if c < 0 {
   245  		c = -1
   246  	} else {
   247  		c = 1
   248  	}
   249  	a := (c + 1) / 2
   250  	var fix bool
   251  	fix = tree.put(key, value, q, &q.Children[a])
   252  	if fix {
   253  		return putFix(int8(c), qp)
   254  	}
   255  	return false
   256  }
   257  
   258  func (tree *Tree[K, V]) remove(key K, qp **Node[K, V]) bool {
   259  	q := *qp
   260  	if q == nil {
   261  		return false
   262  	}
   263  
   264  	c := tree.Comparator(key, q.Key)
   265  	if c == 0 {
   266  		tree.size--
   267  		if q.Children[1] == nil {
   268  			if q.Children[0] != nil {
   269  				q.Children[0].Parent = q.Parent
   270  			}
   271  			*qp = q.Children[0]
   272  			return true
   273  		}
   274  		fix := removeMin(&q.Children[1], &q.Key, &q.Value)
   275  		if fix {
   276  			return removeFix(-1, qp)
   277  		}
   278  		return false
   279  	}
   280  
   281  	if c < 0 {
   282  		c = -1
   283  	} else {
   284  		c = 1
   285  	}
   286  	a := (c + 1) / 2
   287  	fix := tree.remove(key, &q.Children[a])
   288  	if fix {
   289  		return removeFix(int8(-c), qp)
   290  	}
   291  	return false
   292  }
   293  
   294  func removeMin[K, V any](qp **Node[K, V], minKey *K, minVal *V) bool {
   295  	q := *qp
   296  	if q.Children[0] == nil {
   297  		*minKey = q.Key
   298  		*minVal = q.Value
   299  		if q.Children[1] != nil {
   300  			q.Children[1].Parent = q.Parent
   301  		}
   302  		*qp = q.Children[1]
   303  		return true
   304  	}
   305  	fix := removeMin(&q.Children[0], minKey, minVal)
   306  	if fix {
   307  		return removeFix(1, qp)
   308  	}
   309  	return false
   310  }
   311  
   312  func putFix[K, V any](c int8, t **Node[K, V]) bool {
   313  	s := *t
   314  	if s.b == 0 {
   315  		s.b = c
   316  		return true
   317  	}
   318  
   319  	if s.b == -c {
   320  		s.b = 0
   321  		return false
   322  	}
   323  
   324  	if s.Children[(c+1)/2].b == c {
   325  		s = singlerot(c, s)
   326  	} else {
   327  		s = doublerot(c, s)
   328  	}
   329  	*t = s
   330  	return false
   331  }
   332  
// removeFix updates the balance of *t after one of its subtrees became one
// level shorter. c is the side that is now relatively heavier: -1 when the
// right subtree shrank, +1 when the left subtree shrank. It may replace *t
// with a rotated subtree root, and returns true when the height of the
// subtree at *t decreased, so the caller must keep rebalancing upward.
func removeFix[K, V any](c int8, t **Node[K, V]) bool {
	s := *t
	if s.b == 0 {
		// Was balanced: now leans towards c, overall height unchanged.
		s.b = c
		return false
	}

	if s.b == -c {
		// Was leaning towards the side that shrank: now balanced, one level shorter.
		s.b = 0
		return true
	}

	// s.b == c: side a is now two levels taller and must be rotated.
	a := (c + 1) / 2
	if s.Children[a].b == 0 {
		// Heavy child is balanced: one rotation fixes it without changing the
		// height. The demoted node keeps its balance of c (rotate does not
		// touch b); the promoted node leans the other way.
		s = rotate(c, s)
		s.b = -c
		*t = s
		return false
	}

	if s.Children[a].b == c {
		s = singlerot(c, s)
	} else {
		s = doublerot(c, s)
	}
	*t = s
	// In these cases the rotation leaves the subtree one level shorter.
	return true
}
   361  
   362  func singlerot[K, V any](c int8, s *Node[K, V]) *Node[K, V] {
   363  	s.b = 0
   364  	s = rotate(c, s)
   365  	s.b = 0
   366  	return s
   367  }
   368  
// doublerot performs a double rotation at s, which is too heavy on side
// a = (c+1)/2 while its child r on that side leans the opposite way. First r
// is rotated by -c, then s is rotated by c, so r's inner child p becomes the
// new subtree root, which is returned. Balance factors of s and r are
// recomputed from p's balance; p.b still holds its pre-rotation value because
// rotate does not modify b.
func doublerot[K, V any](c int8, s *Node[K, V]) *Node[K, V] {
	a := (c + 1) / 2
	r := s.Children[a]
	s.Children[a] = rotate(-c, s.Children[a])
	p := rotate(c, s)

	switch {
	default:
		// p was balanced: both s and r end balanced.
		s.b = 0
		r.b = 0
	case p.b == c:
		// p leaned towards c: its shorter child went to s, so s now leans -c.
		s.b = -c
		r.b = 0
	case p.b == -c:
		// p leaned towards -c: its shorter child went to r, so r now leans c.
		s.b = 0
		r.b = c
	}

	p.b = 0
	return p
}
   390  
   391  func rotate[K, V any](c int8, s *Node[K, V]) *Node[K, V] {
   392  	a := (c + 1) / 2
   393  	r := s.Children[a]
   394  	s.Children[a] = r.Children[a^1]
   395  	if s.Children[a] != nil {
   396  		s.Children[a].Parent = s
   397  	}
   398  	r.Children[a^1] = s
   399  	r.Parent = s.Parent
   400  	s.Parent = r
   401  	return r
   402  }
   403  
   404  func (tree *Tree[K, V]) bottom(d int) *Node[K, V] {
   405  	n := tree.Root
   406  	if n == nil {
   407  		return nil
   408  	}
   409  
   410  	for c := n.Children[d]; c != nil; c = n.Children[d] {
   411  		n = c
   412  	}
   413  	return n
   414  }
   415  
   416  // Prev returns the previous element in an inorder
   417  // walk of the AVL tree.
   418  func (n *Node[K, V]) Prev() *Node[K, V] {
   419  	return n.walk1(0)
   420  }
   421  
   422  // Next returns the next element in an inorder
   423  // walk of the AVL tree.
   424  func (n *Node[K, V]) Next() *Node[K, V] {
   425  	return n.walk1(1)
   426  }
   427  
   428  func (n *Node[K, V]) walk1(a int) *Node[K, V] {
   429  	if n == nil {
   430  		return nil
   431  	}
   432  
   433  	if n.Children[a] != nil {
   434  		n = n.Children[a]
   435  		for n.Children[a^1] != nil {
   436  			n = n.Children[a^1]
   437  		}
   438  		return n
   439  	}
   440  
   441  	p := n.Parent
   442  	for p != nil && p.Children[a] == n {
   443  		n = p
   444  		p = p.Parent
   445  	}
   446  	return p
   447  }
   448  
   449  func output[K, V any](node *Node[K, V], prefix string, isTail bool, builder *strings.Builder) {
   450  	if node.Children[1] != nil {
   451  		newPrefix := prefix
   452  		if isTail {
   453  			newPrefix += "│   "
   454  		} else {
   455  			newPrefix += "    "
   456  		}
   457  		output(node.Children[1], newPrefix, false, builder)
   458  	}
   459  	builder.WriteString(prefix)
   460  	if isTail {
   461  		builder.WriteString("└── ")
   462  	} else {
   463  		builder.WriteString("┌── ")
   464  	}
   465  	builder.WriteString(node.String() + "\n")
   466  	if node.Children[0] != nil {
   467  		newPrefix := prefix
   468  		if isTail {
   469  			newPrefix += "    "
   470  		} else {
   471  			newPrefix += "│   "
   472  		}
   473  		output(node.Children[0], newPrefix, true, builder)
   474  	}
   475  }
   476  
   477  // UnmarshalJSON @implements json.Unmarshaler
   478  func (tree *Tree[K, V]) UnmarshalJSON(data []byte) error {
   479  	elements := make(map[string]V)
   480  	err := json.Unmarshal(data, &elements)
   481  	if err == nil {
   482  		tree.Clear()
   483  		for key, value := range elements {
   484  			var nk K
   485  			err = tree.Comparator.Unmarshal([]byte(key), &nk)
   486  			if err != nil {
   487  				return err
   488  			}
   489  			tree.Put(nk, value)
   490  		}
   491  	}
   492  	return err
   493  }
   494  
   495  // MarshalJSON @implements json.Marshaler
   496  func (tree *Tree[K, V]) MarshalJSON() ([]byte, error) {
   497  	elements := make(map[string]V)
   498  	it := tree.Iterator()
   499  	for it.Next() {
   500  		k, err := tree.Comparator.Marshal(it.Key())
   501  		if err != nil {
   502  			return nil, err
   503  		}
   504  		elements[string(k)] = it.Value()
   505  	}
   506  	return json.Marshal(&elements)
   507  }