go.charczuk.com@v0.0.0-20240327042549-bc490516bd1a/sdk/netutil/cidr.go (about)

     1  /*
     2  
     3  Copyright (c) 2023 - Present. Will Charczuk. All rights reserved.
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file at the root of the repository.
     5  
     6  */
     7  
     8  /*
     9  NOTE: this file is largely based on the work done here: https://github.com/apparentlymart/go-cidr
    10  
    11  The original pre-amble is as follows:
    12  
    13  Package cidr is a collection of assorted utilities for computing
    14  network and host addresses within network ranges.
    15  
    16  It expects a CIDR-type address structure where addresses are divided into
    17  some number of prefix bits representing the network and then the remaining
    18  suffix bits represent the host.
    19  
    20  For example, it can help to calculate addresses for sub-networks of a
    21  parent network, or to calculate host addresses within a particular prefix.
    22  
    23  At present this package is prioritizing simplicity of implementation and
    24  de-prioritizing speed and memory usage. Thus caution is advised before
    25  using this package in performance-critical applications or hot codepaths.
    26  Patches to improve the speed and memory usage may be accepted as long as
    27  they do not result in a significant increase in code complexity.
    28  
    29  */
    30  
    31  package netutil
    32  
    33  import (
    34  	"fmt"
    35  	"math/big"
    36  	"net"
    37  )
    38  
    39  // ParseCIDR wraps net.ParseCIDR and returns an augmented net.IPNet
    40  // that can do range "math" e.g. `.Next()`.
    41  func ParseCIDR(cidr string) (net.IP, *IPNet, error) {
    42  	ip, m, err := net.ParseCIDR(cidr)
    43  	if err != nil {
    44  		return nil, nil, err
    45  	}
    46  	if m == nil {
    47  		return nil, nil, nil
    48  	}
    49  	return ip, &IPNet{m}, nil
    50  }
    51  
    52  // IPNet is a wrapper for a `net.IPNet` with additional helper functions.
    53  type IPNet struct {
    54  	*net.IPNet
    55  }
    56  
    57  // AddressCount returns the number of distinct host addresses within the given
    58  // CIDR range.
    59  //
    60  // Since the result is a uint64, this function returns meaningful information
    61  // only for IPv4 ranges and IPv6 ranges with a prefix size of at least 65.
    62  func (i IPNet) AddressCount() uint64 {
    63  	prefixLen, bits := i.Mask.Size()
    64  	return 1 << (uint64(bits) - uint64(prefixLen))
    65  }
    66  
    67  // Next produces the next, non-overlapping mask for this mask.
    68  //
    69  // An example:
    70  //
    71  //	_, r0, _ := netutil.ParseCIDR("192.168.0.0/24")
    72  //	r1 := r1.Next()
    73  //	fmt.Println(r1.String()) // prints "192.168.1.0/24"
    74  //
    75  // It will return false if the ip range is not incrementable, that is
    76  // if the combination of address and size are the last combination representable.
    77  func (i IPNet) Next() (*IPNet, bool) {
    78  	_, currentLast := i.Range()
    79  	prefixLen, _ := i.Mask.Size()
    80  	mask := net.CIDRMask(prefixLen, 8*len(currentLast))
    81  	currentSubnet := &IPNet{&net.IPNet{IP: currentLast.Mask(mask), Mask: mask}}
    82  	_, last := currentSubnet.Range()
    83  	last = incrementAddress(last)
    84  	if last.Equal(net.IPv4zero) || last.Equal(net.IPv6zero) {
    85  		return nil, false
    86  	}
    87  	next := &IPNet{&net.IPNet{IP: last.Mask(mask), Mask: mask}}
    88  	return next, true
    89  }
    90  
    91  // Range returns the first and last IP address of the IPNet.
    92  func (i IPNet) Range() (net.IP, net.IP) {
    93  	firstIP := i.IP
    94  	prefixLen, bits := i.Mask.Size()
    95  	if prefixLen == bits {
    96  		lastIP := make([]byte, len(firstIP))
    97  		copy(lastIP, firstIP)
    98  		return firstIP, lastIP
    99  	}
   100  	firstIPInt, bits, _ := ipToBigInt(firstIP)
   101  	hostLen := uint(bits) - uint(prefixLen)
   102  	lastIPInt := big.NewInt(1)
   103  	lastIPInt.Lsh(lastIPInt, hostLen)
   104  	lastIPInt.Sub(lastIPInt, big.NewInt(1))
   105  	lastIPInt.Or(lastIPInt, firstIPInt)
   106  	return firstIP, bigIntToIP(lastIPInt, bits)
   107  }
   108  
   109  // Overlaps returns true if this subnet overlaps with a given subnet.
   110  func (i IPNet) Overlaps(other *IPNet) bool {
   111  	first, last := i.Range()
   112  	otherFirst, otherLast := other.Range()
   113  	if i.Contains(otherFirst) || i.Contains(otherLast) {
   114  		return true
   115  	}
   116  	if other.Contains(first) || other.Contains(last) {
   117  		return true
   118  	}
   119  	return false
   120  }
   121  
   122  // incrementAddress increases the IP by one this returns a new []byte for the IP
   123  func incrementAddress(ip net.IP) net.IP {
   124  	ip = maybeResizeV4(ip)
   125  	incIP := make([]byte, len(ip))
   126  	copy(incIP, ip)
   127  	for j := len(incIP) - 1; j >= 0; j-- {
   128  		incIP[j]++
   129  		if incIP[j] > 0 {
   130  			break
   131  		}
   132  	}
   133  	return incIP
   134  }
   135  
   136  func ipToBigInt(ip net.IP) (*big.Int, int, error) {
   137  	val := new(big.Int)
   138  	val.SetBytes([]byte(ip))
   139  	if len(ip) == net.IPv4len {
   140  		return val, 32, nil
   141  	} else if len(ip) == net.IPv6len {
   142  		return val, 128, nil
   143  	} else {
   144  		return nil, 0, fmt.Errorf("cidr; unsupported address length %d", len(ip))
   145  	}
   146  }
   147  
   148  func bigIntToIP(ipInt *big.Int, bits int) net.IP {
   149  	ipBytes := ipInt.Bytes()
   150  	ret := make([]byte, bits>>3)
   151  	for i := 1; i <= len(ipBytes); i++ {
   152  		ret[len(ret)-i] = ipBytes[len(ipBytes)-i]
   153  	}
   154  	return net.IP(ret)
   155  }
   156  
   157  func maybeResizeV4(ip net.IP) net.IP {
   158  	if v4 := ip.To4(); v4 != nil {
   159  		return v4
   160  	}
   161  	return ip
   162  }