github.com/alexandrestein/gods@v1.0.1/trees/avltree/avltree.go (about)

     1  // Copyright (c) 2017, Benjamin Scher Purcell. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package avltree implements an AVL balanced binary tree.
     6  //
     7  // Structure is not thread safe.
     8  //
     9  // References: https://en.wikipedia.org/wiki/AVL_tree
    10  package avltree
    11  
    12  import (
    13  	"fmt"
    14  	"github.com/alexandrestein/gods/trees"
    15  	"github.com/alexandrestein/gods/utils"
    16  )
    17  
    18  func assertTreeImplementation() {
    19  	var _ trees.Tree = new(Tree)
    20  }
    21  
    22  // Tree holds elements of the AVL tree.
    23  type Tree struct {
    24  	Root       *Node            // Root node
    25  	Comparator utils.Comparator // Key comparator
    26  	size       int              // Total number of keys in the tree
    27  }
    28  
    // Node is a single element within the tree.
type Node struct {
	// Key and Value are the stored pair; keys are ordered by the tree's Comparator.
	Key, Value interface{}
	// Parent links toward the root; it is nil for the root node.
	Parent *Node
	// Children[0] leads to smaller keys, Children[1] to larger keys.
	Children [2]*Node
	// b is the AVL balance factor: -1 left-heavy, 0 balanced, +1 right-heavy.
	b int8
}
    37  
    38  // NewWith instantiates an AVL tree with the custom comparator.
    39  func NewWith(comparator utils.Comparator) *Tree {
    40  	return &Tree{Comparator: comparator}
    41  }
    42  
    43  // NewWithIntComparator instantiates an AVL tree with the IntComparator, i.e. keys are of type int.
    44  func NewWithIntComparator() *Tree {
    45  	return &Tree{Comparator: utils.IntComparator}
    46  }
    47  
    48  // NewWithStringComparator instantiates an AVL tree with the StringComparator, i.e. keys are of type string.
    49  func NewWithStringComparator() *Tree {
    50  	return &Tree{Comparator: utils.StringComparator}
    51  }
    52  
    53  // Put inserts node into the tree.
    54  // Key should adhere to the comparator's type assertion, otherwise method panics.
    55  func (t *Tree) Put(key interface{}, value interface{}) {
    56  	t.put(key, value, nil, &t.Root)
    57  }
    58  
    59  // Get searches the node in the tree by key and returns its value or nil if key is not found in tree.
    60  // Second return parameter is true if key was found, otherwise false.
    61  // Key should adhere to the comparator's type assertion, otherwise method panics.
    62  func (t *Tree) Get(key interface{}) (value interface{}, found bool) {
    63  	n := t.Root
    64  	for n != nil {
    65  		cmp := t.Comparator(key, n.Key)
    66  		switch {
    67  		case cmp == 0:
    68  			return n.Value, true
    69  		case cmp < 0:
    70  			n = n.Children[0]
    71  		case cmp > 0:
    72  			n = n.Children[1]
    73  		}
    74  	}
    75  	return nil, false
    76  }
    77  
    78  // Remove remove the node from the tree by key.
    79  // Key should adhere to the comparator's type assertion, otherwise method panics.
    80  func (t *Tree) Remove(key interface{}) {
    81  	t.remove(key, &t.Root)
    82  }
    83  
    84  // Empty returns true if tree does not contain any nodes.
    85  func (t *Tree) Empty() bool {
    86  	return t.size == 0
    87  }
    88  
    89  // Size returns the number of elements stored in the tree.
    90  func (t *Tree) Size() int {
    91  	return t.size
    92  }
    93  
    94  // Keys returns all keys in-order
    95  func (t *Tree) Keys() []interface{} {
    96  	keys := make([]interface{}, t.size)
    97  	it := t.Iterator()
    98  	for i := 0; it.Next(); i++ {
    99  		keys[i] = it.Key()
   100  	}
   101  	return keys
   102  }
   103  
   104  // Values returns all values in-order based on the key.
   105  func (t *Tree) Values() []interface{} {
   106  	values := make([]interface{}, t.size)
   107  	it := t.Iterator()
   108  	for i := 0; it.Next(); i++ {
   109  		values[i] = it.Value()
   110  	}
   111  	return values
   112  }
   113  
   114  // Left returns the minimum element of the AVL tree
   115  // or nil if the tree is empty.
   116  func (t *Tree) Left() *Node {
   117  	return t.bottom(0)
   118  }
   119  
   120  // Right returns the maximum element of the AVL tree
   121  // or nil if the tree is empty.
   122  func (t *Tree) Right() *Node {
   123  	return t.bottom(1)
   124  }
   125  
   126  // Floor Finds floor node of the input key, return the floor node or nil if no ceiling is found.
   127  // Second return parameter is true if floor was found, otherwise false.
   128  //
   129  // Floor node is defined as the largest node that is smaller than or equal to the given node.
   130  // A floor node may not be found, either because the tree is empty, or because
   131  // all nodes in the tree is larger than the given node.
   132  //
   133  // Key should adhere to the comparator's type assertion, otherwise method panics.
   134  func (t *Tree) Floor(key interface{}) (floor *Node, found bool) {
   135  	found = false
   136  	n := t.Root
   137  	for n != nil {
   138  		c := t.Comparator(key, n.Key)
   139  		switch {
   140  		case c == 0:
   141  			return n, true
   142  		case c < 0:
   143  			n = n.Children[0]
   144  		case c > 0:
   145  			floor, found = n, true
   146  			n = n.Children[1]
   147  		}
   148  	}
   149  	if found {
   150  		return
   151  	}
   152  	return nil, false
   153  }
   154  
   155  // Ceiling finds ceiling node of the input key, return the ceiling node or nil if no ceiling is found.
   156  // Second return parameter is true if ceiling was found, otherwise false.
   157  //
   158  // Ceiling node is defined as the smallest node that is larger than or equal to the given node.
   159  // A ceiling node may not be found, either because the tree is empty, or because
   160  // all nodes in the tree is smaller than the given node.
   161  //
   162  // Key should adhere to the comparator's type assertion, otherwise method panics.
   163  func (t *Tree) Ceiling(key interface{}) (floor *Node, found bool) {
   164  	found = false
   165  	n := t.Root
   166  	for n != nil {
   167  		c := t.Comparator(key, n.Key)
   168  		switch {
   169  		case c == 0:
   170  			return n, true
   171  		case c < 0:
   172  			floor, found = n, true
   173  			n = n.Children[0]
   174  		case c > 0:
   175  			n = n.Children[1]
   176  		}
   177  	}
   178  	if found {
   179  		return
   180  	}
   181  	return nil, false
   182  }
   183  
   184  // Clear removes all nodes from the tree.
   185  func (t *Tree) Clear() {
   186  	t.Root = nil
   187  	t.size = 0
   188  }
   189  
   190  // String returns a string representation of container
   191  func (t *Tree) String() string {
   192  	str := "AVLTree\n"
   193  	if !t.Empty() {
   194  		output(t.Root, "", true, &str)
   195  	}
   196  	return str
   197  }
   198  
   199  func (n *Node) String() string {
   200  	return fmt.Sprintf("%v", n.Key)
   201  }
   202  
   203  func (t *Tree) put(key interface{}, value interface{}, p *Node, qp **Node) bool {
   204  	q := *qp
   205  	if q == nil {
   206  		t.size++
   207  		*qp = &Node{Key: key, Value: value, Parent: p}
   208  		return true
   209  	}
   210  
   211  	c := t.Comparator(key, q.Key)
   212  	if c == 0 {
   213  		q.Key = key
   214  		q.Value = value
   215  		return false
   216  	}
   217  
   218  	if c < 0 {
   219  		c = -1
   220  	} else {
   221  		c = 1
   222  	}
   223  	a := (c + 1) / 2
   224  	var fix bool
   225  	fix = t.put(key, value, q, &q.Children[a])
   226  	if fix {
   227  		return putFix(int8(c), qp)
   228  	}
   229  	return false
   230  }
   231  
   232  func (t *Tree) remove(key interface{}, qp **Node) bool {
   233  	q := *qp
   234  	if q == nil {
   235  		return false
   236  	}
   237  
   238  	c := t.Comparator(key, q.Key)
   239  	if c == 0 {
   240  		t.size--
   241  		if q.Children[1] == nil {
   242  			if q.Children[0] != nil {
   243  				q.Children[0].Parent = q.Parent
   244  			}
   245  			*qp = q.Children[0]
   246  			return true
   247  		}
   248  		fix := removeMin(&q.Children[1], &q.Key, &q.Value)
   249  		if fix {
   250  			return removeFix(-1, qp)
   251  		}
   252  		return false
   253  	}
   254  
   255  	if c < 0 {
   256  		c = -1
   257  	} else {
   258  		c = 1
   259  	}
   260  	a := (c + 1) / 2
   261  	fix := t.remove(key, &q.Children[a])
   262  	if fix {
   263  		return removeFix(int8(-c), qp)
   264  	}
   265  	return false
   266  }
   267  
   268  func removeMin(qp **Node, minKey *interface{}, minVal *interface{}) bool {
   269  	q := *qp
   270  	if q.Children[0] == nil {
   271  		*minKey = q.Key
   272  		*minVal = q.Value
   273  		if q.Children[1] != nil {
   274  			q.Children[1].Parent = q.Parent
   275  		}
   276  		*qp = q.Children[1]
   277  		return true
   278  	}
   279  	fix := removeMin(&q.Children[0], minKey, minVal)
   280  	if fix {
   281  		return removeFix(1, qp)
   282  	}
   283  	return false
   284  }
   285  
   286  func putFix(c int8, t **Node) bool {
   287  	s := *t
   288  	if s.b == 0 {
   289  		s.b = c
   290  		return true
   291  	}
   292  
   293  	if s.b == -c {
   294  		s.b = 0
   295  		return false
   296  	}
   297  
   298  	if s.Children[(c+1)/2].b == c {
   299  		s = singlerot(c, s)
   300  	} else {
   301  		s = doublerot(c, s)
   302  	}
   303  	*t = s
   304  	return false
   305  }
   306  
   307  func removeFix(c int8, t **Node) bool {
   308  	s := *t
   309  	if s.b == 0 {
   310  		s.b = c
   311  		return false
   312  	}
   313  
   314  	if s.b == -c {
   315  		s.b = 0
   316  		return true
   317  	}
   318  
   319  	a := (c + 1) / 2
   320  	if s.Children[a].b == 0 {
   321  		s = rotate(c, s)
   322  		s.b = -c
   323  		*t = s
   324  		return false
   325  	}
   326  
   327  	if s.Children[a].b == c {
   328  		s = singlerot(c, s)
   329  	} else {
   330  		s = doublerot(c, s)
   331  	}
   332  	*t = s
   333  	return true
   334  }
   335  
   336  func singlerot(c int8, s *Node) *Node {
   337  	s.b = 0
   338  	s = rotate(c, s)
   339  	s.b = 0
   340  	return s
   341  }
   342  
   343  func doublerot(c int8, s *Node) *Node {
   344  	a := (c + 1) / 2
   345  	r := s.Children[a]
   346  	s.Children[a] = rotate(-c, s.Children[a])
   347  	p := rotate(c, s)
   348  
   349  	switch {
   350  	default:
   351  		s.b = 0
   352  		r.b = 0
   353  	case p.b == c:
   354  		s.b = -c
   355  		r.b = 0
   356  	case p.b == -c:
   357  		s.b = 0
   358  		r.b = c
   359  	}
   360  
   361  	p.b = 0
   362  	return p
   363  }
   364  
   365  func rotate(c int8, s *Node) *Node {
   366  	a := (c + 1) / 2
   367  	r := s.Children[a]
   368  	s.Children[a] = r.Children[a^1]
   369  	if s.Children[a] != nil {
   370  		s.Children[a].Parent = s
   371  	}
   372  	r.Children[a^1] = s
   373  	r.Parent = s.Parent
   374  	s.Parent = r
   375  	return r
   376  }
   377  
   378  func (t *Tree) bottom(d int) *Node {
   379  	n := t.Root
   380  	if n == nil {
   381  		return nil
   382  	}
   383  
   384  	for c := n.Children[d]; c != nil; c = n.Children[d] {
   385  		n = c
   386  	}
   387  	return n
   388  }
   389  
   390  // Prev returns the previous element in an inorder
   391  // walk of the AVL tree.
   392  func (n *Node) Prev() *Node {
   393  	return n.walk1(0)
   394  }
   395  
   396  // Next returns the next element in an inorder
   397  // walk of the AVL tree.
   398  func (n *Node) Next() *Node {
   399  	return n.walk1(1)
   400  }
   401  
   402  func (n *Node) walk1(a int) *Node {
   403  	if n == nil {
   404  		return nil
   405  	}
   406  
   407  	if n.Children[a] != nil {
   408  		n = n.Children[a]
   409  		for n.Children[a^1] != nil {
   410  			n = n.Children[a^1]
   411  		}
   412  		return n
   413  	}
   414  
   415  	p := n.Parent
   416  	for p != nil && p.Children[a] == n {
   417  		n = p
   418  		p = p.Parent
   419  	}
   420  	return p
   421  }
   422  
   423  func output(node *Node, prefix string, isTail bool, str *string) {
   424  	if node.Children[1] != nil {
   425  		newPrefix := prefix
   426  		if isTail {
   427  			newPrefix += "│   "
   428  		} else {
   429  			newPrefix += "    "
   430  		}
   431  		output(node.Children[1], newPrefix, false, str)
   432  	}
   433  	*str += prefix
   434  	if isTail {
   435  		*str += "└── "
   436  	} else {
   437  		*str += "┌── "
   438  	}
   439  	*str += node.String() + "\n"
   440  	if node.Children[0] != nil {
   441  		newPrefix := prefix
   442  		if isTail {
   443  			newPrefix += "    "
   444  		} else {
   445  			newPrefix += "│   "
   446  		}
   447  		output(node.Children[0], newPrefix, true, str)
   448  	}
   449  }