github.com/slackhq/nebula@v1.9.0/cidr/tree4.go (about)

     1  package cidr
     2  
     3  import (
     4  	"net"
     5  
     6  	"github.com/slackhq/nebula/iputil"
     7  )
     8  
// Node is a single vertex of the binary tree behind Tree4. Each step down the
// tree corresponds to one address bit, examined from startbit downwards.
type Node[T any] struct {
	// left is the child followed when the examined address bit is 0.
	left     *Node[T]
	// right is the child followed when the examined address bit is 1.
	right    *Node[T]
	// parent is the node this one was attached to by AddCIDR; the root has none.
	parent   *Node[T]
	// hasValue reports whether AddCIDR has stored a value on this node.
	hasValue bool
	// value is the value stored by AddCIDR; only meaningful when hasValue is true.
	value    T
}
    16  
// entry pairs a CIDR handed to AddCIDR with the value stored for it. Tree4
// keeps a slice of these so that List can report what has been added.
type entry[T any] struct {
	// CIDR is the network exactly as it was passed to AddCIDR.
	CIDR  *net.IPNet
	// Value is the value that was passed to AddCIDR together with CIDR.
	Value T
}
    21  
// Tree4 maps CIDR networks to values of type T using a binary tree that is
// walked one address bit at a time. Use NewTree4 to create one; the methods
// dereference root without a nil check.
type Tree4[T any] struct {
	// root is the top of the tree; it represents the empty (zero length) prefix.
	root *Node[T]
	// list records every CIDR added through AddCIDR with its value, for List.
	list []entry[T]
}
    26  
const (
	// startbit is the first bit examined when walking the tree: 0x80000000,
	// the most significant bit of a 32-bit value. Each step shifts it right.
	startbit = iputil.VpnIp(0x80000000)
)
    30  
    31  func NewTree4[T any]() *Tree4[T] {
    32  	tree := new(Tree4[T])
    33  	tree.root = &Node[T]{}
    34  	tree.list = []entry[T]{}
    35  	return tree
    36  }
    37  
    38  func (tree *Tree4[T]) AddCIDR(cidr *net.IPNet, val T) {
    39  	bit := startbit
    40  	node := tree.root
    41  	next := tree.root
    42  
    43  	ip := iputil.Ip2VpnIp(cidr.IP)
    44  	mask := iputil.Ip2VpnIp(cidr.Mask)
    45  
    46  	// Find our last ancestor in the tree
    47  	for bit&mask != 0 {
    48  		if ip&bit != 0 {
    49  			next = node.right
    50  		} else {
    51  			next = node.left
    52  		}
    53  
    54  		if next == nil {
    55  			break
    56  		}
    57  
    58  		bit = bit >> 1
    59  		node = next
    60  	}
    61  
    62  	// We already have this range so update the value
    63  	if next != nil {
    64  		addCIDR := cidr.String()
    65  		for i, v := range tree.list {
    66  			if addCIDR == v.CIDR.String() {
    67  				tree.list = append(tree.list[:i], tree.list[i+1:]...)
    68  				break
    69  			}
    70  		}
    71  
    72  		tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val})
    73  		node.value = val
    74  		node.hasValue = true
    75  		return
    76  	}
    77  
    78  	// Build up the rest of the tree we don't already have
    79  	for bit&mask != 0 {
    80  		next = &Node[T]{}
    81  		next.parent = node
    82  
    83  		if ip&bit != 0 {
    84  			node.right = next
    85  		} else {
    86  			node.left = next
    87  		}
    88  
    89  		bit >>= 1
    90  		node = next
    91  	}
    92  
    93  	// Final node marks our cidr, set the value
    94  	node.value = val
    95  	node.hasValue = true
    96  	tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val})
    97  }
    98  
    99  // Contains finds the first match, which may be the least specific
   100  func (tree *Tree4[T]) Contains(ip iputil.VpnIp) (ok bool, value T) {
   101  	bit := startbit
   102  	node := tree.root
   103  
   104  	for node != nil {
   105  		if node.hasValue {
   106  			return true, node.value
   107  		}
   108  
   109  		if ip&bit != 0 {
   110  			node = node.right
   111  		} else {
   112  			node = node.left
   113  		}
   114  
   115  		bit >>= 1
   116  
   117  	}
   118  
   119  	return false, value
   120  }
   121  
   122  // MostSpecificContains finds the most specific match
   123  func (tree *Tree4[T]) MostSpecificContains(ip iputil.VpnIp) (ok bool, value T) {
   124  	bit := startbit
   125  	node := tree.root
   126  
   127  	for node != nil {
   128  		if node.hasValue {
   129  			value = node.value
   130  			ok = true
   131  		}
   132  
   133  		if ip&bit != 0 {
   134  			node = node.right
   135  		} else {
   136  			node = node.left
   137  		}
   138  
   139  		bit >>= 1
   140  	}
   141  
   142  	return ok, value
   143  }
   144  
// eachFunc is the callback type accepted by EachContains. It receives the
// value of a matching node and returns true to stop the search.
type eachFunc[T any] func(T) bool
   146  
   147  // EachContains will call a function, passing the value, for each entry until the function returns true or the search is complete
   148  // The final return value will be true if the provided function returned true
   149  func (tree *Tree4[T]) EachContains(ip iputil.VpnIp, each eachFunc[T]) bool {
   150  	bit := startbit
   151  	node := tree.root
   152  
   153  	for node != nil {
   154  		if node.hasValue {
   155  			// If the each func returns true then we can exit the loop
   156  			if each(node.value) {
   157  				return true
   158  			}
   159  		}
   160  
   161  		if ip&bit != 0 {
   162  			node = node.right
   163  		} else {
   164  			node = node.left
   165  		}
   166  
   167  		bit >>= 1
   168  	}
   169  
   170  	return false
   171  }
   172  
   173  // GetCIDR returns the entry added by the most recent matching AddCIDR call
   174  func (tree *Tree4[T]) GetCIDR(cidr *net.IPNet) (ok bool, value T) {
   175  	bit := startbit
   176  	node := tree.root
   177  
   178  	ip := iputil.Ip2VpnIp(cidr.IP)
   179  	mask := iputil.Ip2VpnIp(cidr.Mask)
   180  
   181  	// Find our last ancestor in the tree
   182  	for node != nil && bit&mask != 0 {
   183  		if ip&bit != 0 {
   184  			node = node.right
   185  		} else {
   186  			node = node.left
   187  		}
   188  
   189  		bit = bit >> 1
   190  	}
   191  
   192  	if bit&mask == 0 && node != nil {
   193  		value = node.value
   194  		ok = node.hasValue
   195  	}
   196  
   197  	return ok, value
   198  }
   199  
// List returns all CIDRs that were added through AddCIDR together with their
// current values. The returned slice is the tree's own internal slice, not a
// copy. Do not modify the contents!
func (tree *Tree4[T]) List() []entry[T] {
	return tree.list
}