github.com/richardwilkes/toolbox@v1.121.0/collection/redblack/node.go (about)

     1  // Copyright (c) 2016-2024 by Richard A. Wilkes. All rights reserved.
     2  //
     3  // This Source Code Form is subject to the terms of the Mozilla Public
     4  // License, version 2.0. If a copy of the MPL was not distributed with
     5  // this file, You can obtain one at http://mozilla.org/MPL/2.0/.
     6  //
     7  // This Source Code Form is "Incompatible With Secondary Licenses", as
     8  // defined by the Mozilla Public License, version 2.0.
     9  
    10  package redblack
    11  
    12  import (
    13  	"fmt"
    14  	"strings"
    15  )
    16  
    17  type node[K, V any] struct {
    18  	key    K
    19  	value  V
    20  	parent *node[K, V]
    21  	left   *node[K, V]
    22  	right  *node[K, V]
    23  	black  bool
    24  }
    25  
    26  func (n *node[K, V]) isBlack() bool {
    27  	return n == nil || n.black
    28  }
    29  
    30  func (n *node[K, V]) isRed() bool {
    31  	return n != nil && !n.black
    32  }
    33  
    34  func (n *node[K, V]) find(compareFunc func(a, b K) int, key K) *node[K, V] {
    35  	if n == nil {
    36  		return nil
    37  	}
    38  	result := compareFunc(key, n.key)
    39  	switch {
    40  	case result < 0:
    41  		return n.left.find(compareFunc, key)
    42  	case result > 0:
    43  		return n.right.find(compareFunc, key)
    44  	default:
    45  		// Always return the left-most one in the case of multiple matches
    46  		cur := n
    47  		for cur.left != nil && compareFunc(key, cur.left.key) == 0 {
    48  			cur = cur.left
    49  		}
    50  		return cur
    51  	}
    52  }
    53  
    54  func (n *node[K, V]) traverse(visitorFunc func(key K, value V) bool) bool {
    55  	if n == nil {
    56  		return true
    57  	}
    58  	if n.left.traverse(visitorFunc) {
    59  		if visitorFunc(n.key, n.value) {
    60  			return n.right.traverse(visitorFunc)
    61  		}
    62  	}
    63  	return false
    64  }
    65  
    66  func (n *node[K, V]) reverseTraverse(visitorFunc func(key K, value V) bool) bool {
    67  	if n == nil {
    68  		return true
    69  	}
    70  	if n.right.reverseTraverse(visitorFunc) {
    71  		if visitorFunc(n.key, n.value) {
    72  			return n.left.reverseTraverse(visitorFunc)
    73  		}
    74  	}
    75  	return false
    76  }
    77  
    78  func (n *node[K, V]) traverseEqualOrGreater(compareFunc func(a, b K) int, key K, visitorFunc func(key K, value V) bool) bool {
    79  	if n == nil {
    80  		return true
    81  	}
    82  	result := compareFunc(key, n.key)
    83  	if result < 0 {
    84  		if !n.left.traverseEqualOrGreater(compareFunc, key, visitorFunc) {
    85  			return false
    86  		}
    87  	}
    88  	if result <= 0 {
    89  		if !visitorFunc(n.key, n.value) {
    90  			return false
    91  		}
    92  	}
    93  	return n.right.traverseEqualOrGreater(compareFunc, key, visitorFunc)
    94  }
    95  
    96  func (n *node[K, V]) traverseEqualOrLess(compareFunc func(a, b K) int, key K, visitorFunc func(key K, value V) bool) bool {
    97  	if n == nil {
    98  		return true
    99  	}
   100  	result := compareFunc(key, n.key)
   101  	if result > 0 {
   102  		if !n.right.traverseEqualOrLess(compareFunc, key, visitorFunc) {
   103  			return false
   104  		}
   105  	}
   106  	if result >= 0 {
   107  		if !visitorFunc(n.key, n.value) {
   108  			return false
   109  		}
   110  	}
   111  	return n.left.traverseEqualOrLess(compareFunc, key, visitorFunc)
   112  }
   113  
   114  func (n *node[K, V]) dump(depth int, side string) {
   115  	if n == nil {
   116  		return
   117  	}
   118  	br := "r"
   119  	if n.black {
   120  		br = "b"
   121  	}
   122  	fmt.Printf("%s%s%s%v\n", strings.Repeat("  ", depth), br, side, n.key)
   123  	n.left.dump(depth+1, "L ")
   124  	n.right.dump(depth+1, "R ")
   125  }