go.mway.dev/x@v0.0.0-20240520034138-950aede9a3fb/container/tree/tree.go (about)

     1  // Copyright (c) 2024 Matt Way
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to
     5  // deal in the Software without restriction, including without limitation the
     6  // rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
     7  // sell copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
    18  // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
     19  // IN THE SOFTWARE.
    20  
    21  // Package tree provides tree structure related types and utilities.
    22  package tree
    23  
    24  import (
    25  	"errors"
    26  	"maps"
    27  	"slices"
    28  	"sort"
    29  
    30  	"golang.org/x/exp/constraints"
    31  	xmaps "golang.org/x/exp/maps"
    32  )
    33  
    34  // ErrSkipSubtree is a sentinel value that can be returned by a [NodeWalker] to
    35  // skip the remainder of a subtree. This error halts iteration of that subtree
    36  // at the first handling parent.
    37  var ErrSkipSubtree = errors.New("skip subtree")
    38  
    39  // A NodeWalker is a function that controls how nodes are walked.
    40  type NodeWalker[K constraints.Ordered, V comparable] func(node *BasicNode[K, V]) error
    41  
    42  // BasicNode is a basic, arbitrarily-ordered, non-balancing, key/value tree.
    43  type BasicNode[K constraints.Ordered, V comparable] struct {
    44  	key      K
    45  	value    V
    46  	parent   *BasicNode[K, V]
    47  	children map[K]*BasicNode[K, V]
    48  }
    49  
    50  // NewBasicNode creates a new [BasicNode].
    51  func NewBasicNode[K constraints.Ordered, V comparable](
    52  	key K,
    53  	value V,
    54  ) *BasicNode[K, V] {
    55  	return &BasicNode[K, V]{
    56  		key:   key,
    57  		value: value,
    58  	}
    59  }
    60  
    61  // Key returns the node's key.
    62  func (n *BasicNode[K, V]) Key() (key K) {
    63  	if n == nil {
    64  		return
    65  	}
    66  	return n.key
    67  }
    68  
    69  // Path returns all keys in the path from the root node to this node.
    70  func (n *BasicNode[K, V]) Path() []K {
    71  	keys := n.PathRev()
    72  	slices.Reverse(keys)
    73  	return keys
    74  }
    75  
    76  // PathRev returns all keys in the path from this node to the root.
    77  func (n *BasicNode[K, V]) PathRev() []K {
    78  	var (
    79  		keys = []K{n.key}
    80  		cur  = n
    81  	)
    82  
    83  	for cur.parent != nil {
    84  		cur = cur.parent
    85  		keys = append(keys, cur.key)
    86  	}
    87  
    88  	return keys
    89  }
    90  
    91  // Value returns the node's value.
    92  func (n *BasicNode[K, V]) Value() (value V) {
    93  	if n == nil {
    94  		return
    95  	}
    96  	return n.value
    97  }
    98  
    99  // SetValue sets the node's value.
   100  func (n *BasicNode[K, V]) SetValue(value V) {
   101  	if n == nil {
   102  		return
   103  	}
   104  	n.value = value
   105  }
   106  
   107  // Parent returns the node's parent node.
   108  func (n *BasicNode[K, V]) Parent() *BasicNode[K, V] {
   109  	if n == nil {
   110  		return nil
   111  	}
   112  	return n.parent
   113  }
   114  
   115  // Child returns the child of with the given key, if one exists. If no child
   116  // with the given key is found, nil is returned.
   117  func (n *BasicNode[K, V]) Child(key K) *BasicNode[K, V] {
   118  	if n == nil {
   119  		return nil
   120  	}
   121  
   122  	c, ok := n.children[key]
   123  	if !ok {
   124  		return nil
   125  	}
   126  	return c
   127  }
   128  
   129  // Children returns the node's children.
   130  func (n *BasicNode[K, V]) Children() map[K]*BasicNode[K, V] {
   131  	switch {
   132  	case n == nil:
   133  		return nil
   134  	case len(n.children) == 0:
   135  		return nil
   136  	default:
   137  		return maps.Clone(n.children)
   138  	}
   139  }
   140  
   141  // SetParent sets the node's parent to the given parent node.
   142  func (n *BasicNode[K, V]) SetParent(parent *BasicNode[K, V]) {
   143  	if n == nil {
   144  		return
   145  	}
   146  
   147  	if n.parent != nil {
   148  		delete(n.parent.children, n.key)
   149  	}
   150  
   151  	if parent.children == nil {
   152  		parent.children = make(map[K]*BasicNode[K, V])
   153  	}
   154  	parent.children[n.key] = n
   155  	n.parent = parent
   156  }
   157  
   158  // Add adds a new child with the given key and value to this node, returning
   159  // the new node.
   160  func (n *BasicNode[K, V]) Add(key K, value V) *BasicNode[K, V] {
   161  	node := &BasicNode[K, V]{
   162  		key:    key,
   163  		value:  value,
   164  		parent: n,
   165  	}
   166  
   167  	if n != nil {
   168  		if n.children == nil {
   169  			n.children = make(map[K]*BasicNode[K, V])
   170  		}
   171  		n.children[key] = node
   172  	}
   173  
   174  	return node
   175  }
   176  
   177  // Remove removes the child with the given key, if one exists.
   178  func (n *BasicNode[K, V]) Remove(
   179  	key K,
   180  ) (child *BasicNode[K, V], removed bool) {
   181  	if child, removed = n.children[key]; removed {
   182  		child.parent = nil
   183  		delete(n.children, key)
   184  	}
   185  	return
   186  }
   187  
   188  // Len returns the recursive length of the tree relative to the node.
   189  func (n *BasicNode[K, V]) Len() (total int) {
   190  	switch {
   191  	case n == nil:
   192  		return 0
   193  	case len(n.children) == 0:
   194  		return 1
   195  	default:
   196  		for _, child := range n.children {
   197  			total += child.Len()
   198  		}
   199  		total++ // for n itself
   200  		return
   201  	}
   202  }
   203  
   204  func handleWalkError(err error) (stop bool, unhandled error) {
   205  	switch {
   206  	case err == nil:
   207  		return false, nil
   208  	case errors.Is(err, ErrSkipSubtree):
   209  		return true, nil
   210  	default:
   211  		return true, err
   212  	}
   213  }
   214  
   215  // Walk walks through the tree depth-first.
   216  func (n *BasicNode[K, V]) Walk(fn NodeWalker[K, V]) error {
   217  	if n == nil {
   218  		return nil
   219  	}
   220  
   221  	stop, err := handleWalkError(fn(n))
   222  	if stop || len(n.children) == 0 {
   223  		return err
   224  	}
   225  
   226  	keys := xmaps.Keys(n.children)
   227  	sort.Slice(keys, func(i int, j int) bool {
   228  		return keys[i] < keys[j]
   229  	})
   230  
   231  	for _, key := range keys {
   232  		if stop, err = handleWalkError(n.children[key].Walk(fn)); stop {
   233  			return err
   234  		}
   235  	}
   236  
   237  	return nil
   238  }
   239  
   240  // WalkRev walks through the tree depth-first in reverse.
   241  func (n *BasicNode[K, V]) WalkRev(fn NodeWalker[K, V]) error {
   242  	if n == nil {
   243  		return nil
   244  	}
   245  
   246  	var (
   247  		stop bool
   248  		err  error
   249  	)
   250  
   251  	if len(n.children) > 0 {
   252  		keys := xmaps.Keys(n.children)
   253  		sort.Slice(keys, func(i int, j int) bool {
   254  			return keys[i] < keys[j]
   255  		})
   256  
   257  		for _, key := range keys {
   258  			if stop, err = handleWalkError(n.children[key].WalkRev(fn)); stop {
   259  				return err
   260  			}
   261  		}
   262  	}
   263  
   264  	_, err = handleWalkError(fn(n))
   265  	return err
   266  }