github.com/hashicorp/vault/sdk@v0.11.0/helper/cidrutil/cidr.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package cidrutil
     5  
     6  import (
     7  	"fmt"
     8  	"net"
     9  	"strings"
    10  
    11  	"github.com/hashicorp/errwrap"
    12  	"github.com/hashicorp/go-secure-stdlib/strutil"
    13  	sockaddr "github.com/hashicorp/go-sockaddr"
    14  )
    15  
    16  func isIPAddr(cidr sockaddr.SockAddr) bool {
    17  	return (cidr.Type() & sockaddr.TypeIP) != 0
    18  }
    19  
    20  // RemoteAddrIsOk checks if the given remote address is either:
    21  //   - OK because there's no CIDR whitelist
    22  //   - OK because it's in the CIDR whitelist
    23  func RemoteAddrIsOk(remoteAddr string, boundCIDRs []*sockaddr.SockAddrMarshaler) bool {
    24  	if len(boundCIDRs) == 0 {
    25  		// There's no CIDR whitelist.
    26  		return true
    27  	}
    28  	remoteSockAddr, err := sockaddr.NewSockAddr(remoteAddr)
    29  	if err != nil {
    30  		// Can't tell, err on the side of less access.
    31  		return false
    32  	}
    33  	for _, cidr := range boundCIDRs {
    34  		if isIPAddr(cidr) && cidr.Contains(remoteSockAddr) {
    35  			// Whitelisted.
    36  			return true
    37  		}
    38  	}
    39  	// Not whitelisted.
    40  	return false
    41  }
    42  
    43  // IPBelongsToCIDR checks if the given IP is encompassed by the given CIDR block
    44  func IPBelongsToCIDR(ipAddr string, cidr string) (bool, error) {
    45  	if ipAddr == "" {
    46  		return false, fmt.Errorf("missing IP address")
    47  	}
    48  
    49  	ip := net.ParseIP(ipAddr)
    50  	if ip == nil {
    51  		return false, fmt.Errorf("invalid IP address")
    52  	}
    53  
    54  	_, ipnet, err := net.ParseCIDR(cidr)
    55  	if err != nil {
    56  		return false, err
    57  	}
    58  
    59  	if !ipnet.Contains(ip) {
    60  		return false, nil
    61  	}
    62  
    63  	return true, nil
    64  }
    65  
    66  // IPBelongsToCIDRBlocksSlice checks if the given IP is encompassed by any of the given
    67  // CIDR blocks
    68  func IPBelongsToCIDRBlocksSlice(ipAddr string, cidrs []string) (bool, error) {
    69  	if ipAddr == "" {
    70  		return false, fmt.Errorf("missing IP address")
    71  	}
    72  
    73  	if len(cidrs) == 0 {
    74  		return false, fmt.Errorf("missing CIDR blocks to be checked against")
    75  	}
    76  
    77  	if ip := net.ParseIP(ipAddr); ip == nil {
    78  		return false, fmt.Errorf("invalid IP address")
    79  	}
    80  
    81  	for _, cidr := range cidrs {
    82  		belongs, err := IPBelongsToCIDR(ipAddr, cidr)
    83  		if err != nil {
    84  			return false, err
    85  		}
    86  		if belongs {
    87  			return true, nil
    88  		}
    89  	}
    90  
    91  	return false, nil
    92  }
    93  
    94  // ValidateCIDRListString checks if the list of CIDR blocks are valid, given
    95  // that the input is a string composed by joining all the CIDR blocks using a
    96  // separator. The input is separated based on the given separator and validity
    97  // of each is checked.
    98  func ValidateCIDRListString(cidrList string, separator string) (bool, error) {
    99  	if cidrList == "" {
   100  		return false, fmt.Errorf("missing CIDR list that needs validation")
   101  	}
   102  	if separator == "" {
   103  		return false, fmt.Errorf("missing separator")
   104  	}
   105  
   106  	return ValidateCIDRListSlice(strutil.ParseDedupLowercaseAndSortStrings(cidrList, separator))
   107  }
   108  
   109  // ValidateCIDRListSlice checks if the given list of CIDR blocks are valid
   110  func ValidateCIDRListSlice(cidrBlocks []string) (bool, error) {
   111  	if len(cidrBlocks) == 0 {
   112  		return false, fmt.Errorf("missing CIDR blocks that needs validation")
   113  	}
   114  
   115  	for _, block := range cidrBlocks {
   116  		if _, _, err := net.ParseCIDR(strings.TrimSpace(block)); err != nil {
   117  			return false, err
   118  		}
   119  	}
   120  
   121  	return true, nil
   122  }
   123  
   124  // Subset checks if the IPs belonging to a given CIDR block is a subset of IPs
   125  // belonging to another CIDR block.
   126  func Subset(cidr1, cidr2 string) (bool, error) {
   127  	if cidr1 == "" {
   128  		return false, fmt.Errorf("missing CIDR to be checked against")
   129  	}
   130  
   131  	if cidr2 == "" {
   132  		return false, fmt.Errorf("missing CIDR that needs to be checked")
   133  	}
   134  
   135  	ip1, net1, err := net.ParseCIDR(cidr1)
   136  	if err != nil {
   137  		return false, errwrap.Wrapf("failed to parse the CIDR to be checked against: {{err}}", err)
   138  	}
   139  
   140  	zeroAddr := false
   141  	if ip := ip1.To4(); ip != nil && ip.Equal(net.IPv4zero) {
   142  		zeroAddr = true
   143  	}
   144  	if ip := ip1.To16(); ip != nil && ip.Equal(net.IPv6zero) {
   145  		zeroAddr = true
   146  	}
   147  
   148  	maskLen1, _ := net1.Mask.Size()
   149  	if !zeroAddr && maskLen1 == 0 {
   150  		return false, fmt.Errorf("CIDR to be checked against is not in its canonical form")
   151  	}
   152  
   153  	ip2, net2, err := net.ParseCIDR(cidr2)
   154  	if err != nil {
   155  		return false, errwrap.Wrapf("failed to parse the CIDR that needs to be checked: {{err}}", err)
   156  	}
   157  
   158  	zeroAddr = false
   159  	if ip := ip2.To4(); ip != nil && ip.Equal(net.IPv4zero) {
   160  		zeroAddr = true
   161  	}
   162  	if ip := ip2.To16(); ip != nil && ip.Equal(net.IPv6zero) {
   163  		zeroAddr = true
   164  	}
   165  
   166  	maskLen2, _ := net2.Mask.Size()
   167  	if !zeroAddr && maskLen2 == 0 {
   168  		return false, fmt.Errorf("CIDR that needs to be checked is not in its canonical form")
   169  	}
   170  
   171  	// If the mask length of the CIDR that needs to be checked is smaller
   172  	// then the mask length of the CIDR to be checked against, then the
   173  	// former will encompass more IPs than the latter, and hence can't be a
   174  	// subset of the latter.
   175  	if maskLen2 < maskLen1 {
   176  		return false, nil
   177  	}
   178  
   179  	belongs, err := IPBelongsToCIDR(net2.IP.String(), cidr1)
   180  	if err != nil {
   181  		return false, err
   182  	}
   183  
   184  	return belongs, nil
   185  }
   186  
   187  // SubsetBlocks checks if each CIDR block of a given set of CIDR blocks, is a
   188  // subset of at least one CIDR block belonging to another set of CIDR blocks.
   189  // First parameter is the set of CIDR blocks to check against and the second
   190  // parameter is the set of CIDR blocks that needs to be checked.
   191  func SubsetBlocks(cidrBlocks1, cidrBlocks2 []string) (bool, error) {
   192  	if len(cidrBlocks1) == 0 {
   193  		return false, fmt.Errorf("missing CIDR blocks to be checked against")
   194  	}
   195  
   196  	if len(cidrBlocks2) == 0 {
   197  		return false, fmt.Errorf("missing CIDR blocks that needs to be checked")
   198  	}
   199  
   200  	// Check if all the elements of cidrBlocks2 is a subset of at least one
   201  	// element of cidrBlocks1
   202  	for _, cidrBlock2 := range cidrBlocks2 {
   203  		isSubset := false
   204  		for _, cidrBlock1 := range cidrBlocks1 {
   205  			subset, err := Subset(cidrBlock1, cidrBlock2)
   206  			if err != nil {
   207  				return false, err
   208  			}
   209  			// If CIDR is a subset of any of the CIDR block, its
   210  			// good enough. Break out.
   211  			if subset {
   212  				isSubset = true
   213  				break
   214  			}
   215  		}
   216  		// CIDR block was not a subset of any of the CIDR blocks in the
   217  		// set of blocks to check against
   218  		if !isSubset {
   219  			return false, nil
   220  		}
   221  	}
   222  
   223  	return true, nil
   224  }