github.com/slackhq/nebula@v1.9.0/cidr/tree6.go (about)

     1  package cidr
     2  
     3  import (
     4  	"net"
     5  
     6  	"github.com/slackhq/nebula/iputil"
     7  )
     8  
     9  const startbit6 = uint64(1 << 63)
    10  
    11  type Tree6[T any] struct {
    12  	root4 *Node[T]
    13  	root6 *Node[T]
    14  }
    15  
    16  func NewTree6[T any]() *Tree6[T] {
    17  	tree := new(Tree6[T])
    18  	tree.root4 = &Node[T]{}
    19  	tree.root6 = &Node[T]{}
    20  	return tree
    21  }
    22  
    23  func (tree *Tree6[T]) AddCIDR(cidr *net.IPNet, val T) {
    24  	var node, next *Node[T]
    25  
    26  	cidrIP, ipv4 := isIPV4(cidr.IP)
    27  	if ipv4 {
    28  		node = tree.root4
    29  		next = tree.root4
    30  
    31  	} else {
    32  		node = tree.root6
    33  		next = tree.root6
    34  	}
    35  
    36  	for i := 0; i < len(cidrIP); i += 4 {
    37  		ip := iputil.Ip2VpnIp(cidrIP[i : i+4])
    38  		mask := iputil.Ip2VpnIp(cidr.Mask[i : i+4])
    39  		bit := startbit
    40  
    41  		// Find our last ancestor in the tree
    42  		for bit&mask != 0 {
    43  			if ip&bit != 0 {
    44  				next = node.right
    45  			} else {
    46  				next = node.left
    47  			}
    48  
    49  			if next == nil {
    50  				break
    51  			}
    52  
    53  			bit = bit >> 1
    54  			node = next
    55  		}
    56  
    57  		// Build up the rest of the tree we don't already have
    58  		for bit&mask != 0 {
    59  			next = &Node[T]{}
    60  			next.parent = node
    61  
    62  			if ip&bit != 0 {
    63  				node.right = next
    64  			} else {
    65  				node.left = next
    66  			}
    67  
    68  			bit >>= 1
    69  			node = next
    70  		}
    71  	}
    72  
    73  	// Final node marks our cidr, set the value
    74  	node.value = val
    75  	node.hasValue = true
    76  }
    77  
    78  // Finds the most specific match
    79  func (tree *Tree6[T]) MostSpecificContains(ip net.IP) (ok bool, value T) {
    80  	var node *Node[T]
    81  
    82  	wholeIP, ipv4 := isIPV4(ip)
    83  	if ipv4 {
    84  		node = tree.root4
    85  	} else {
    86  		node = tree.root6
    87  	}
    88  
    89  	for i := 0; i < len(wholeIP); i += 4 {
    90  		ip := iputil.Ip2VpnIp(wholeIP[i : i+4])
    91  		bit := startbit
    92  
    93  		for node != nil {
    94  			if node.hasValue {
    95  				value = node.value
    96  				ok = true
    97  			}
    98  
    99  			if bit == 0 {
   100  				break
   101  			}
   102  
   103  			if ip&bit != 0 {
   104  				node = node.right
   105  			} else {
   106  				node = node.left
   107  			}
   108  
   109  			bit >>= 1
   110  		}
   111  	}
   112  
   113  	return ok, value
   114  }
   115  
   116  func (tree *Tree6[T]) MostSpecificContainsIpV4(ip iputil.VpnIp) (ok bool, value T) {
   117  	bit := startbit
   118  	node := tree.root4
   119  
   120  	for node != nil {
   121  		if node.hasValue {
   122  			value = node.value
   123  			ok = true
   124  		}
   125  
   126  		if ip&bit != 0 {
   127  			node = node.right
   128  		} else {
   129  			node = node.left
   130  		}
   131  
   132  		bit >>= 1
   133  	}
   134  
   135  	return ok, value
   136  }
   137  
   138  func (tree *Tree6[T]) MostSpecificContainsIpV6(hi, lo uint64) (ok bool, value T) {
   139  	ip := hi
   140  	node := tree.root6
   141  
   142  	for i := 0; i < 2; i++ {
   143  		bit := startbit6
   144  
   145  		for node != nil {
   146  			if node.hasValue {
   147  				value = node.value
   148  				ok = true
   149  			}
   150  
   151  			if bit == 0 {
   152  				break
   153  			}
   154  
   155  			if ip&bit != 0 {
   156  				node = node.right
   157  			} else {
   158  				node = node.left
   159  			}
   160  
   161  			bit >>= 1
   162  		}
   163  
   164  		ip = lo
   165  	}
   166  
   167  	return ok, value
   168  }
   169  
   170  func isIPV4(ip net.IP) (net.IP, bool) {
   171  	if len(ip) == net.IPv4len {
   172  		return ip, true
   173  	}
   174  
   175  	if len(ip) == net.IPv6len && isZeros(ip[0:10]) && ip[10] == 0xff && ip[11] == 0xff {
   176  		return ip[12:16], true
   177  	}
   178  
   179  	return ip, false
   180  }
   181  
   182  func isZeros(p net.IP) bool {
   183  	for i := 0; i < len(p); i++ {
   184  		if p[i] != 0 {
   185  			return false
   186  		}
   187  	}
   188  	return true
   189  }