github.com/telepresenceio/telepresence/v2@v2.20.0-pro.6.0.20240517030216-236ea954e789/pkg/subnet/subnet.go (about)

     1  // Package subnet contains functions for finding available subnets
     2  package subnet
     3  
     4  import (
     5  	"bytes"
     6  	"fmt"
     7  	"net"
     8  	"sort"
     9  )
    10  
    11  // CoveringCIDRs returns the ip networks needed to cover the given IPs with as
    12  // big mask as possible for each subnet. The analysis starts by finding all
    13  // subnets using a 16-bit mask for IPv4 and a 64 bit mask for IPv6 addresses.
    14  // Once the subnets are established, the mask for each one will be increased
    15  // to the maximum value that still masks all IPs that it was created for.
    16  //
    17  // Note: A similar method exists in Telepresence 1, but this method was not
    18  // compared to it when written.
    19  func CoveringCIDRs(ips []net.IP) []*net.IPNet {
    20  	// IPv4 subnet key. Identifies a class B subnet
    21  	type ipv4SubnetKey [2]byte
    22  
    23  	// IPv6 subnet key. This is the 48-bit route and 16-bit subnet identifier. Identifies a 64 bit subnet.
    24  	type ipv6SubnetKey [8]byte
    25  
    26  	ipv6Subnets := make(map[ipv6SubnetKey]*[7]Bitfield256)
    27  
    28  	// Divide into subnets with ByteSets.
    29  
    30  	// IPv4 has 2 byte subnets and one Bitfield256 representing the third byte.
    31  	// (last byte is skipped because no split on subnet is made on that byte).
    32  	ipv4Subnets := make(map[ipv4SubnetKey]*Bitfield256)
    33  
    34  	// IPv6 has 8 byte subnets and seven ByteSets representing all but the last
    35  	// byte of the subnet relative 64-bit address (last byte is skipped because
    36  	// no split into subnets is made using that byte).
    37  	for _, ip := range ips {
    38  		if ip4 := ip.To4(); ip4 != nil {
    39  			bk := ipv4SubnetKey{ip4[0], ip4[1]}
    40  			var bytes *Bitfield256
    41  			if bytes = ipv4Subnets[bk]; bytes == nil {
    42  				bytes = &Bitfield256{}
    43  				ipv4Subnets[bk] = bytes
    44  			}
    45  			bytes.SetBit(ip4[2])
    46  		} else if ip16 := ip.To16(); ip16 != nil {
    47  			r := ipv6SubnetKey{}
    48  			copy(r[:], ip16)
    49  			byteSets, ok := ipv6Subnets[r]
    50  			if !ok {
    51  				byteSets = &[7]Bitfield256{}
    52  				ipv6Subnets[r] = byteSets
    53  			}
    54  			for i := range byteSets {
    55  				byteSets[i].SetBit(ip16[i+8])
    56  			}
    57  		}
    58  	}
    59  
    60  	subnets := make([]*net.IPNet, len(ipv4Subnets)+len(ipv6Subnets))
    61  	i := 0
    62  	for network, bytes := range ipv4Subnets {
    63  		ones, thirdByte := bytes.Mask()
    64  		subnets[i] = &net.IPNet{
    65  			IP:   net.IP{network[0], network[1], thirdByte, 0},
    66  			Mask: net.CIDRMask(16+ones, 32),
    67  		}
    68  		i++
    69  	}
    70  	for subnet, byteSets := range ipv6Subnets {
    71  		maskOnes := 64
    72  		ip := make(net.IP, 16)
    73  		copy(ip, subnet[:])
    74  		for bi, bytes := range byteSets {
    75  			ones, nByte := bytes.Mask()
    76  			maskOnes += ones
    77  			ip[8+bi] = nByte
    78  			if ones != 8 {
    79  				break
    80  			}
    81  		}
    82  		subnets[i] = &net.IPNet{
    83  			IP:   ip,
    84  			Mask: net.CIDRMask(maskOnes, 128),
    85  		}
    86  		i++
    87  	}
    88  	sort.Slice(subnets, func(i, j int) bool { return compareIPs(subnets[i].IP, subnets[j].IP) < 0 })
    89  	return subnets
    90  }
    91  
    92  // compareIPs is like bytes.Compare but will always consider IPv4 less than IPv6.
    93  func compareIPs(a, b net.IP) int {
    94  	dl := len(a) - len(b)
    95  	switch {
    96  	case dl == 0:
    97  		dl = bytes.Compare(a, b)
    98  	case dl < 0:
    99  		dl = -1
   100  	default:
   101  		dl = 1
   102  	}
   103  	return dl
   104  }
   105  
   106  // Unique will drop any subnet that is covered by another subnet from the
   107  // given slice and return the resulting slice. This function will alter
   108  // the given slice.
   109  func Unique(subnets []*net.IPNet) []*net.IPNet {
   110  	ln := len(subnets)
   111  	for i, isn := range subnets {
   112  		if i >= ln {
   113  			break
   114  		}
   115  		for r, rsn := range subnets {
   116  			if i == r {
   117  				continue
   118  			}
   119  			if Covers(rsn, isn) {
   120  				ln--
   121  				subnets[i] = subnets[ln]
   122  				break
   123  			}
   124  		}
   125  	}
   126  	return subnets[:ln]
   127  }
   128  
   129  // Partition returns two slices, the first containing the subnets for which the filter evaluates
   130  // to true, the second containing the rest.
   131  func Partition(subnets []*net.IPNet, filter func(int, *net.IPNet) bool) (matched, notMatched []*net.IPNet) {
   132  	for i, sn := range subnets {
   133  		if filter(i, sn) {
   134  			matched = append(matched, sn)
   135  		} else {
   136  			notMatched = append(notMatched, sn)
   137  		}
   138  	}
   139  	return
   140  }
   141  
   142  // Equal returns true if a and b have equal IP and masks.
   143  func Equal(a, b *net.IPNet) bool {
   144  	if a.IP.Equal(b.IP) {
   145  		ao, ab := a.Mask.Size()
   146  		bo, bb := b.Mask.Size()
   147  		return ao == bo && ab == bb
   148  	}
   149  	return false
   150  }
   151  
   152  // Covers answers the question if network range a contains the full network range b.
   153  func Covers(a, b *net.IPNet) bool {
   154  	return a.Contains(b.IP) && a.Contains(MaxIP(b))
   155  }
   156  
   157  // Overlaps answers the question if there is an overlap between network range a and b.
   158  func Overlaps(a, b *net.IPNet) bool {
   159  	return a.Contains(b.IP) || a.Contains(MaxIP(b)) || b.Contains(a.IP) || b.Contains(MaxIP(a))
   160  }
   161  
   162  func MaxIP(cidr *net.IPNet) net.IP {
   163  	// create max IP in range b using its mask
   164  	ones, _ := cidr.Mask.Size()
   165  	l := len(cidr.IP)
   166  	m := make(net.IP, l)
   167  	n := uint(ones)
   168  	for i := 0; i < l; i++ {
   169  		switch {
   170  		case n >= 8:
   171  			m[i] = cidr.IP[i]
   172  			n -= 8
   173  		case n > 0:
   174  			m[i] = cidr.IP[i] | byte(0xff>>n)
   175  			n = 0
   176  		default:
   177  			m[i] = 0xff
   178  		}
   179  	}
   180  	return m
   181  }
   182  
   183  // incIP attempts to increase the given ip. The increase starts at the penultimate byte. The increased IP is
   184  // returned unless it is equal or larger than the given end, in which case nil is returned.
   185  func incIP(ip, end net.IP) net.IP {
   186  	ipc := make(net.IP, len(ip))
   187  	for bi := len(ip) - 2; bi >= 0; bi-- {
   188  		if bv := ip[bi]; bv < 255 {
   189  			copy(ipc, ip)
   190  			ipc[bi] = bv + 1
   191  			// set bytes to the right of the increased byt to zero.
   192  			for xi := bi + 1; xi < len(ipc)-1; xi++ {
   193  				ipc[xi] = 0
   194  			}
   195  			if compareIPs(ipc, end) < 0 {
   196  				return ipc
   197  			}
   198  			break
   199  		}
   200  	}
   201  	return nil
   202  }
   203  
   204  // RandomIPv4Subnet finds a random free subnet using the given mask. A subnet is considered
   205  // free if it doesn't overlap with any of the subnets returned by the net.InterfaceAddrs
   206  // function or with any of the subnets provided in the avoid parameter.
   207  // The returned subnet will be a private IPv4 subnet in either class C, B, or A range, and the search
   208  // for a free subnet uses that order.
   209  // See https://en.wikipedia.org/wiki/Private_network for more info about private subnets.
   210  func RandomIPv4Subnet(mask net.IPMask, avoid []*net.IPNet) (*net.IPNet, error) {
   211  	as, err := net.InterfaceAddrs()
   212  	if err != nil {
   213  		return nil, err
   214  	}
   215  	cidrs := make([]*net.IPNet, 0, len(as)+len(avoid))
   216  	for _, a := range as {
   217  		if _, cidr, err := net.ParseCIDR(a.String()); err == nil {
   218  			cidrs = append(cidrs, cidr)
   219  		}
   220  	}
   221  	cidrs = append(cidrs, avoid...)
   222  
   223  	// IP address range pairs, from - to (to is non-inclusive)
   224  	ranges := []net.IP{
   225  		{192, 168, 0, 0}, {192, 169, 0, 0}, // Class C private range
   226  		{172, 16, 0, 0}, {172, 32, 0, 0}, // Class B private range
   227  		{10, 0, 0, 0}, {11, 0, 0, 0}, // Class A private range
   228  	}
   229  
   230  	for i := 0; i < len(ranges); i += 2 {
   231  		ip := ranges[i]
   232  
   233  		end := ranges[i+1]
   234  		for {
   235  			ip1 := make(net.IP, len(ip))
   236  			copy(ip1, ip)
   237  			ip1[len(ip)-1] = 1
   238  			sn := net.IPNet{
   239  				IP:   ip1,
   240  				Mask: mask,
   241  			}
   242  			inUse := false
   243  			for _, cidr := range cidrs {
   244  				if Overlaps(cidr, &sn) {
   245  					inUse = true
   246  					break
   247  				}
   248  			}
   249  			if !inUse {
   250  				return &sn, nil
   251  			}
   252  			if ip = incIP(ip, end); ip == nil {
   253  				break
   254  			}
   255  		}
   256  	}
   257  	return nil, fmt.Errorf("unable to find a free subnet")
   258  }
   259  
   260  func IsZeroMask(n *net.IPNet) bool {
   261  	for _, b := range n.Mask {
   262  		if b != 0 {
   263  			return false
   264  		}
   265  	}
   266  	return true
   267  }
   268  
   269  // IsHalfOfDefault route returns true if the given subnet covers half the address space with a /1 mask.
   270  func IsHalfOfDefault(n *net.IPNet) bool {
   271  	ones, _ := n.Mask.Size()
   272  	return ones == 1
   273  }