github.com/cilium/cilium@v1.16.2/pkg/cidr/cidr.go (about)

     1  // SPDX-License-Identifier: Apache-2.0
     2  // Copyright Authors of Cilium
     3  
     4  package cidr
     5  
     6  import (
     7  	"bytes"
     8  	"fmt"
     9  	"net"
    10  )
    11  
    12  // NewCIDR returns a new CIDR using a net.IPNet
    13  func NewCIDR(ipnet *net.IPNet) *CIDR {
    14  	if ipnet == nil {
    15  		return nil
    16  	}
    17  
    18  	return &CIDR{ipnet}
    19  }
    20  
    21  // CIDR is a network CIDR representation based on net.IPNet
    22  type CIDR struct {
    23  	*net.IPNet
    24  }
    25  
    26  // DeepEqual is an deepequal function, deeply comparing the receiver with other.
    27  // in must be non-nil.
    28  func (in *CIDR) DeepEqual(other *CIDR) bool {
    29  	if other == nil {
    30  		return false
    31  	}
    32  
    33  	if (in.IPNet == nil) != (other.IPNet == nil) {
    34  		return false
    35  	} else if in.IPNet != nil {
    36  		if !in.IPNet.IP.Equal(other.IPNet.IP) {
    37  			return false
    38  		}
    39  		inOnes, inBits := in.IPNet.Mask.Size()
    40  		otherOnes, otherBits := other.IPNet.Mask.Size()
    41  		return inOnes == otherOnes && inBits == otherBits
    42  	}
    43  
    44  	return true
    45  }
    46  
    47  // DeepCopy creates a deep copy of a CIDR
    48  func (n *CIDR) DeepCopy() *CIDR {
    49  	if n == nil {
    50  		return nil
    51  	}
    52  	out := new(CIDR)
    53  	n.DeepCopyInto(out)
    54  	return out
    55  }
    56  
    57  // DeepCopyInto is a deepcopy function, copying the receiver, writing into out. in must be non-nil.
    58  func (in *CIDR) DeepCopyInto(out *CIDR) {
    59  	*out = *in
    60  	if in.IPNet == nil {
    61  		return
    62  	}
    63  	out.IPNet = new(net.IPNet)
    64  	*out.IPNet = *in.IPNet
    65  	if in.IPNet.IP != nil {
    66  		in, out := &in.IPNet.IP, &out.IPNet.IP
    67  		*out = make(net.IP, len(*in))
    68  		copy(*out, *in)
    69  	}
    70  	if in.IPNet.Mask != nil {
    71  		in, out := &in.IPNet.Mask, &out.IPNet.Mask
    72  		*out = make(net.IPMask, len(*in))
    73  		copy(*out, *in)
    74  	}
    75  }
    76  
    77  // AvailableIPs returns the number of IPs available in a CIDR
    78  func (n *CIDR) AvailableIPs() int {
    79  	ones, bits := n.Mask.Size()
    80  	return 1 << (bits - ones)
    81  }
    82  
    83  // Equal returns true if the receiver's CIDR equals the other CIDR.
    84  func (n *CIDR) Equal(o *CIDR) bool {
    85  	if n == nil || o == nil {
    86  		return n == o
    87  	}
    88  	return Equal(n.IPNet, o.IPNet)
    89  }
    90  
    91  // Equal returns true if the n and o net.IPNet CIDRs are Equal.
    92  func Equal(n, o *net.IPNet) bool {
    93  	if n == nil || o == nil {
    94  		return n == o
    95  	}
    96  	if n == o {
    97  		return true
    98  	}
    99  	return n.IP.Equal(o.IP) &&
   100  		bytes.Equal(n.Mask, o.Mask)
   101  }
   102  
   103  // ZeroNet generates a zero net.IPNet object for the given address family
   104  func ZeroNet(family int) *net.IPNet {
   105  	switch family {
   106  	case FAMILY_V4:
   107  		return &net.IPNet{
   108  			IP:   net.IPv4zero,
   109  			Mask: net.CIDRMask(0, 8*net.IPv4len),
   110  		}
   111  	case FAMILY_V6:
   112  		return &net.IPNet{
   113  			IP:   net.IPv6zero,
   114  			Mask: net.CIDRMask(0, 8*net.IPv6len),
   115  		}
   116  	}
   117  	return nil
   118  }
   119  
   120  // ContainsAll returns true if 'ipNets1' contains all net.IPNet of 'ipNets2'
   121  func ContainsAll(ipNets1, ipNets2 []*net.IPNet) bool {
   122  	for _, n := range ipNets2 {
   123  		if !Contains(ipNets1, n) {
   124  			return false
   125  		}
   126  	}
   127  	return true
   128  }
   129  
   130  // Contains returns true if 'ipNets' contains ipNet.
   131  func Contains(ipNets []*net.IPNet, ipNet *net.IPNet) bool {
   132  	for _, n := range ipNets {
   133  		if Equal(n, ipNet) {
   134  			return true
   135  		}
   136  	}
   137  	return false
   138  }
   139  
   140  // RemoveAll removes all cidrs specified in 'toRemove' from 'ipNets'. ipNets
   141  // is clobbered (to ensure removed CIDRs can be garbage collected) and
   142  // must not be used after this function has been called.
   143  // Example usage:
   144  //
   145  //	cidrs = cidr.RemoveAll(cidrs, toRemove)
   146  func RemoveAll(ipNets, toRemove []*net.IPNet) []*net.IPNet {
   147  	newIPNets := ipNets[:0]
   148  	for _, n := range ipNets {
   149  		if !Contains(toRemove, n) {
   150  			newIPNets = append(newIPNets, n)
   151  		}
   152  	}
   153  	for i := len(newIPNets); i < len(ipNets); i++ {
   154  		ipNets[i] = nil // or the zero value of T
   155  	}
   156  	return newIPNets
   157  }
   158  
   159  // ParseCIDR parses the CIDR string using net.ParseCIDR
   160  func ParseCIDR(str string) (*CIDR, error) {
   161  	_, ipnet, err := net.ParseCIDR(str)
   162  	if err != nil {
   163  		return nil, err
   164  	}
   165  	return NewCIDR(ipnet), nil
   166  }
   167  
   168  // MustParseCIDR parses the CIDR string using net.ParseCIDR and panics if the
   169  // CIDR cannot be parsed
   170  func MustParseCIDR(str string) *CIDR {
   171  	c, err := ParseCIDR(str)
   172  	if err != nil {
   173  		panic(fmt.Sprintf("Unable to parse CIDR '%s': %s", str, err))
   174  	}
   175  	return c
   176  }