github.com/projectdiscovery/nuclei/v2@v2.9.15/pkg/protocols/common/randomip/randomip.go (about)

     1  package randomip
     2  
     3  import (
     4  	"crypto/rand"
     5  	"net"
     6  
     7  	"github.com/pkg/errors"
     8  	iputil "github.com/projectdiscovery/utils/ip"
     9  	randutil "github.com/projectdiscovery/utils/rand"
    10  )
    11  
    12  const (
    13  	maxIterations = 255
    14  )
    15  
    16  func GetRandomIPWithCidr(cidrs ...string) (net.IP, error) {
    17  	if len(cidrs) == 0 {
    18  		return nil, errors.Errorf("must specify at least one cidr")
    19  	}
    20  
    21  	randIdx, err := randutil.IntN(len(cidrs))
    22  	if err != nil {
    23  		return nil, err
    24  	}
    25  
    26  	cidr := cidrs[randIdx]
    27  
    28  	if !iputil.IsCIDR(cidr) {
    29  		return nil, errors.Errorf("%s is not a valid cidr", cidr)
    30  	}
    31  
    32  	baseIp, ipnet, err := net.ParseCIDR(cidr)
    33  	if err != nil {
    34  		return nil, err
    35  	}
    36  
    37  	switch {
    38  	case 255 == ipnet.Mask[len(ipnet.Mask)-1]:
    39  		return baseIp, nil
    40  	case iputil.IsIPv4(baseIp.String()):
    41  		return getRandomIP(ipnet, 4), nil
    42  	case iputil.IsIPv6(baseIp.String()):
    43  		return getRandomIP(ipnet, 16), nil
    44  	default:
    45  		return nil, errors.New("invalid base ip")
    46  	}
    47  }
    48  
    49  func getRandomIP(ipnet *net.IPNet, size int) net.IP {
    50  	ip := ipnet.IP
    51  	var iteration int
    52  
    53  	for iteration < maxIterations {
    54  		iteration++
    55  		ones, _ := ipnet.Mask.Size()
    56  		quotient := ones / 8
    57  		remainder := ones % 8
    58  		var r []byte
    59  		switch size {
    60  		case 4, 16:
    61  			r = make([]byte, size)
    62  		default:
    63  			return ip
    64  		}
    65  
    66  		_, _ = rand.Read(r)
    67  
    68  		for i := 0; i <= quotient; i++ {
    69  			if i == quotient {
    70  				shifted := byte(r[i]) >> remainder
    71  				r[i] = ipnet.IP[i] + (^ipnet.IP[i] & shifted)
    72  			} else {
    73  				r[i] = ipnet.IP[i]
    74  			}
    75  		}
    76  
    77  		ip = r
    78  
    79  		if !ip.Equal(ipnet.IP) {
    80  			break
    81  		}
    82  	}
    83  
    84  	return ip
    85  }