github.com/sagernet/quic-go@v0.43.1-beta.1/internal/utils/tree/tree.go (about)

     1  // Originated from https://github.com/ross-oreto/go-tree/blob/master/btree.go with the following changes:
     2  // 1. Genericized the code
     3  // 2. Added Match function for our frame sorter use case
     4  // 3. Fixed a bug in deleteNode where in some cases the deleted flag was not set to true
     5  
     6  /*
     7  Copyright (c) 2017 Ross Oreto
     8  
     9  Permission is hereby granted, free of charge, to any person obtaining a copy
    10  of this software and associated documentation files (the "Software"), to deal
    11  in the Software without restriction, including without limitation the rights
    12  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    13  copies of the Software, and to permit persons to whom the Software is
    14  furnished to do so, subject to the following conditions:
    15  
    16  The above copyright notice and this permission notice shall be included in all
    17  copies or substantial portions of the Software.
    18  
    19  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    20  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    21  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    22  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    23  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    24  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    25  SOFTWARE.
    26  */
    27  
    28  package tree
    29  
    30  import (
    31  	"fmt"
    32  )
    33  
    34  type Val[T any] interface {
    35  	Comp(val T) int8   // returns 1 if > val, -1 if < val, 0 if equals to val
    36  	Match(cond T) int8 // returns 1 if > cond, -1 if < cond, 0 if matches cond
    37  }
    38  
    39  // Btree represents an AVL tree
    40  type Btree[T Val[T]] struct {
    41  	root   *Node[T]
    42  	values []T
    43  	len    int
    44  }
    45  
    46  // Node represents a node in the tree with a value, left and right children, and a height/balance of the node.
    47  type Node[T Val[T]] struct {
    48  	Value       T
    49  	left, right *Node[T]
    50  	height      int8
    51  }
    52  
    53  // New returns a new btree
    54  func New[T Val[T]]() *Btree[T] { return new(Btree[T]).Init() }
    55  
    56  // Init initializes all values/clears the tree and returns the tree pointer
    57  func (t *Btree[T]) Init() *Btree[T] {
    58  	t.root = nil
    59  	t.values = nil
    60  	t.len = 0
    61  	return t
    62  }
    63  
    64  // String returns a string representation of the tree values
    65  func (t *Btree[T]) String() string {
    66  	return fmt.Sprint(t.Values())
    67  }
    68  
    69  // Empty returns true if the tree is empty
    70  func (t *Btree[T]) Empty() bool {
    71  	return t.root == nil
    72  }
    73  
    74  // NotEmpty returns true if the tree is not empty
    75  func (t *Btree[T]) NotEmpty() bool {
    76  	return t.root != nil
    77  }
    78  
    79  // Insert inserts a new value into the tree and returns the tree pointer
    80  func (t *Btree[T]) Insert(value T) *Btree[T] {
    81  	added := false
    82  	t.root = insert(t.root, value, &added)
    83  	if added {
    84  		t.len++
    85  	}
    86  	t.values = nil
    87  	return t
    88  }
    89  
    90  func insert[T Val[T]](n *Node[T], value T, added *bool) *Node[T] {
    91  	if n == nil {
    92  		*added = true
    93  		return (&Node[T]{Value: value}).Init()
    94  	}
    95  	c := value.Comp(n.Value)
    96  	if c > 0 {
    97  		n.right = insert(n.right, value, added)
    98  	} else if c < 0 {
    99  		n.left = insert(n.left, value, added)
   100  	} else {
   101  		n.Value = value
   102  		*added = false
   103  		return n
   104  	}
   105  
   106  	n.height = n.maxHeight() + 1
   107  	c = balance(n)
   108  
   109  	if c > 1 {
   110  		c = value.Comp(n.left.Value)
   111  		if c < 0 {
   112  			return n.rotateRight()
   113  		} else if c > 0 {
   114  			n.left = n.left.rotateLeft()
   115  			return n.rotateRight()
   116  		}
   117  	} else if c < -1 {
   118  		c = value.Comp(n.right.Value)
   119  		if c > 0 {
   120  			return n.rotateLeft()
   121  		} else if c < 0 {
   122  			n.right = n.right.rotateRight()
   123  			return n.rotateLeft()
   124  		}
   125  	}
   126  	return n
   127  }
   128  
   129  // InsertAll inserts all the values into the tree and returns the tree pointer
   130  func (t *Btree[T]) InsertAll(values []T) *Btree[T] {
   131  	for _, v := range values {
   132  		t.Insert(v)
   133  	}
   134  	return t
   135  }
   136  
   137  // Contains returns true if the tree contains the specified value
   138  func (t *Btree[T]) Contains(value T) bool {
   139  	return t.Get(value) != nil
   140  }
   141  
   142  // ContainsAny returns true if the tree contains any of the values
   143  func (t *Btree[T]) ContainsAny(values []T) bool {
   144  	for _, v := range values {
   145  		if t.Contains(v) {
   146  			return true
   147  		}
   148  	}
   149  	return false
   150  }
   151  
   152  // ContainsAll returns true if the tree contains all of the values
   153  func (t *Btree[T]) ContainsAll(values []T) bool {
   154  	for _, v := range values {
   155  		if !t.Contains(v) {
   156  			return false
   157  		}
   158  	}
   159  	return true
   160  }
   161  
   162  // Get returns the node value associated with the search value
   163  func (t *Btree[T]) Get(value T) *T {
   164  	var node *Node[T]
   165  	if t.root != nil {
   166  		node = t.root.get(value)
   167  	}
   168  	if node != nil {
   169  		return &node.Value
   170  	}
   171  	return nil
   172  }
   173  
   174  func (t *Btree[T]) Match(cond T) []T {
   175  	var matches []T
   176  	if t.root != nil {
   177  		t.root.match(cond, &matches)
   178  	}
   179  	return matches
   180  }
   181  
   182  // Len return the number of nodes in the tree
   183  func (t *Btree[T]) Len() int {
   184  	return t.len
   185  }
   186  
   187  // Head returns the first value in the tree
   188  func (t *Btree[T]) Head() *T {
   189  	if t.root == nil {
   190  		return nil
   191  	}
   192  	beginning := t.root
   193  	for beginning.left != nil {
   194  		beginning = beginning.left
   195  	}
   196  	if beginning == nil {
   197  		for beginning.right != nil {
   198  			beginning = beginning.right
   199  		}
   200  	}
   201  	if beginning != nil {
   202  		return &beginning.Value
   203  	}
   204  	return nil
   205  }
   206  
   207  // Tail returns the last value in the tree
   208  func (t *Btree[T]) Tail() *T {
   209  	if t.root == nil {
   210  		return nil
   211  	}
   212  	beginning := t.root
   213  	for beginning.right != nil {
   214  		beginning = beginning.right
   215  	}
   216  	if beginning == nil {
   217  		for beginning.left != nil {
   218  			beginning = beginning.left
   219  		}
   220  	}
   221  	if beginning != nil {
   222  		return &beginning.Value
   223  	}
   224  	return nil
   225  }
   226  
   227  // Values returns a slice of all the values in tree in order
   228  func (t *Btree[T]) Values() []T {
   229  	if t.values == nil {
   230  		t.values = make([]T, t.len)
   231  		t.Ascend(func(n *Node[T], i int) bool {
   232  			t.values[i] = n.Value
   233  			return true
   234  		})
   235  	}
   236  	return t.values
   237  }
   238  
   239  // Delete deletes the node from the tree associated with the search value
   240  func (t *Btree[T]) Delete(value T) *Btree[T] {
   241  	deleted := false
   242  	t.root = deleteNode(t.root, value, &deleted)
   243  	if deleted {
   244  		t.len--
   245  	}
   246  	t.values = nil
   247  	return t
   248  }
   249  
   250  // DeleteAll deletes the nodes from the tree associated with the search values
   251  func (t *Btree[T]) DeleteAll(values []T) *Btree[T] {
   252  	for _, v := range values {
   253  		t.Delete(v)
   254  	}
   255  	return t
   256  }
   257  
   258  func deleteNode[T Val[T]](n *Node[T], value T, deleted *bool) *Node[T] {
   259  	if n == nil {
   260  		return n
   261  	}
   262  
   263  	c := value.Comp(n.Value)
   264  
   265  	if c < 0 {
   266  		n.left = deleteNode(n.left, value, deleted)
   267  	} else if c > 0 {
   268  		n.right = deleteNode(n.right, value, deleted)
   269  	} else {
   270  		if n.left == nil {
   271  			t := n.right
   272  			n.Init()
   273  			*deleted = true
   274  			return t
   275  		} else if n.right == nil {
   276  			t := n.left
   277  			n.Init()
   278  			*deleted = true
   279  			return t
   280  		}
   281  		t := n.right.min()
   282  		n.Value = t.Value
   283  		n.right = deleteNode(n.right, t.Value, deleted)
   284  		*deleted = true
   285  	}
   286  
   287  	// re-balance
   288  	if n == nil {
   289  		return n
   290  	}
   291  	n.height = n.maxHeight() + 1
   292  	bal := balance(n)
   293  	if bal > 1 {
   294  		if balance(n.left) >= 0 {
   295  			return n.rotateRight()
   296  		}
   297  		n.left = n.left.rotateLeft()
   298  		return n.rotateRight()
   299  	} else if bal < -1 {
   300  		if balance(n.right) <= 0 {
   301  			return n.rotateLeft()
   302  		}
   303  		n.right = n.right.rotateRight()
   304  		return n.rotateLeft()
   305  	}
   306  
   307  	return n
   308  }
   309  
   310  // Pop deletes the last node from the tree and returns its value
   311  func (t *Btree[T]) Pop() *T {
   312  	value := t.Tail()
   313  	if value != nil {
   314  		t.Delete(*value)
   315  	}
   316  	return value
   317  }
   318  
   319  // Pull deletes the first node from the tree and returns its value
   320  func (t *Btree[T]) Pull() *T {
   321  	value := t.Head()
   322  	if value != nil {
   323  		t.Delete(*value)
   324  	}
   325  	return value
   326  }
   327  
   328  // NodeIterator expresses the iterator function used for traversals
   329  type NodeIterator[T Val[T]] func(n *Node[T], i int) bool
   330  
   331  // Ascend performs an ascending order traversal of the tree calling the iterator function on each node
   332  // the iterator will continue as long as the NodeIterator returns true
   333  func (t *Btree[T]) Ascend(iterator NodeIterator[T]) {
   334  	var i int
   335  	if t.root != nil {
   336  		t.root.iterate(iterator, &i, true)
   337  	}
   338  }
   339  
   340  // Descend performs a descending order traversal of the tree using the iterator
   341  // the iterator will continue as long as the NodeIterator returns true
   342  func (t *Btree[T]) Descend(iterator NodeIterator[T]) {
   343  	var i int
   344  	if t.root != nil {
   345  		t.root.rIterate(iterator, &i, true)
   346  	}
   347  }
   348  
   349  // Debug prints out useful debug information about the tree for debugging purposes
   350  func (t *Btree[T]) Debug() {
   351  	fmt.Println("----------------------------------------------------------------------------------------------")
   352  	if t.Empty() {
   353  		fmt.Println("tree is empty")
   354  	} else {
   355  		fmt.Println(t.Len(), "elements")
   356  	}
   357  
   358  	t.Ascend(func(n *Node[T], i int) bool {
   359  		if t.root.Value.Comp(n.Value) == 0 {
   360  			fmt.Print("ROOT ** ")
   361  		}
   362  		n.Debug()
   363  		return true
   364  	})
   365  	fmt.Println("----------------------------------------------------------------------------------------------")
   366  }
   367  
   368  // Init initializes the values of the node or clears the node and returns the node pointer
   369  func (n *Node[T]) Init() *Node[T] {
   370  	n.height = 1
   371  	n.left = nil
   372  	n.right = nil
   373  	return n
   374  }
   375  
   376  // String returns a string representing the node
   377  func (n *Node[T]) String() string {
   378  	return fmt.Sprint(n.Value)
   379  }
   380  
   381  // Debug prints out useful debug information about the tree node for debugging purposes
   382  func (n *Node[T]) Debug() {
   383  	var children string
   384  	if n.left == nil && n.right == nil {
   385  		children = "no children |"
   386  	} else if n.left != nil && n.right != nil {
   387  		children = fmt.Sprint("left child:", n.left.String(), " right child:", n.right.String())
   388  	} else if n.right != nil {
   389  		children = fmt.Sprint("right child:", n.right.String())
   390  	} else {
   391  		children = fmt.Sprint("left child:", n.left.String())
   392  	}
   393  
   394  	fmt.Println(n.String(), "|", "height", n.height, "|", "balance", balance(n), "|", children)
   395  }
   396  
   397  func height[T Val[T]](n *Node[T]) int8 {
   398  	if n != nil {
   399  		return n.height
   400  	}
   401  	return 0
   402  }
   403  
   404  func balance[T Val[T]](n *Node[T]) int8 {
   405  	if n == nil {
   406  		return 0
   407  	}
   408  	return height(n.left) - height(n.right)
   409  }
   410  
   411  func (n *Node[T]) get(val T) *Node[T] {
   412  	var node *Node[T]
   413  	c := val.Comp(n.Value)
   414  	if c < 0 {
   415  		if n.left != nil {
   416  			node = n.left.get(val)
   417  		}
   418  	} else if c > 0 {
   419  		if n.right != nil {
   420  			node = n.right.get(val)
   421  		}
   422  	} else {
   423  		node = n
   424  	}
   425  	return node
   426  }
   427  
   428  func (n *Node[T]) match(cond T, results *[]T) {
   429  	c := n.Value.Match(cond)
   430  	if c > 0 {
   431  		if n.left != nil {
   432  			n.left.match(cond, results)
   433  		}
   434  	} else if c < 0 {
   435  		if n.right != nil {
   436  			n.right.match(cond, results)
   437  		}
   438  	} else {
   439  		// other matching nodes could be on both sides
   440  		if n.left != nil {
   441  			n.left.match(cond, results)
   442  		}
   443  		*results = append(*results, n.Value)
   444  		if n.right != nil {
   445  			n.right.match(cond, results)
   446  		}
   447  	}
   448  }
   449  
   450  func (n *Node[T]) rotateRight() *Node[T] {
   451  	l := n.left
   452  	// Rotation
   453  	l.right, n.left = n, l.right
   454  
   455  	// update heights
   456  	n.height = n.maxHeight() + 1
   457  	l.height = l.maxHeight() + 1
   458  
   459  	return l
   460  }
   461  
   462  func (n *Node[T]) rotateLeft() *Node[T] {
   463  	r := n.right
   464  	// Rotation
   465  	r.left, n.right = n, r.left
   466  
   467  	// update heights
   468  	n.height = n.maxHeight() + 1
   469  	r.height = r.maxHeight() + 1
   470  
   471  	return r
   472  }
   473  
   474  func (n *Node[T]) iterate(iterator NodeIterator[T], i *int, cont bool) {
   475  	if n != nil && cont {
   476  		n.left.iterate(iterator, i, cont)
   477  		cont = iterator(n, *i)
   478  		*i++
   479  		n.right.iterate(iterator, i, cont)
   480  	}
   481  }
   482  
   483  func (n *Node[T]) rIterate(iterator NodeIterator[T], i *int, cont bool) {
   484  	if n != nil && cont {
   485  		n.right.iterate(iterator, i, cont)
   486  		cont = iterator(n, *i)
   487  		*i++
   488  		n.left.iterate(iterator, i, cont)
   489  	}
   490  }
   491  
   492  func (n *Node[T]) min() *Node[T] {
   493  	current := n
   494  	for current.left != nil {
   495  		current = current.left
   496  	}
   497  	return current
   498  }
   499  
   500  func (n *Node[T]) maxHeight() int8 {
   501  	rh := height(n.right)
   502  	lh := height(n.left)
   503  	if rh > lh {
   504  		return rh
   505  	}
   506  	return lh
   507  }