github.com/microsoft/moc@v0.17.1/pkg/net/net.go (about)

     1  // Copyright (c) Microsoft Corporation.
     2  // Licensed under the Apache v2.0 license.
     3  package net
     4  
     5  import (
     6  	"fmt"
     7  	"math/big"
     8  	"net"
     9  )
    10  
    11  const (
    12  	LOOPBACK_ADDRESS = "127.0.0.1"
    13  )
    14  
    15  func GetIPAddress() (string, error) {
    16  	conn, err := net.Dial("udp", "8.8.8.8:80")
    17  	if err != nil {
    18  		return "", err
    19  	}
    20  	defer conn.Close()
    21  
    22  	return conn.LocalAddr().(*net.UDPAddr).IP.String(), nil
    23  }
    24  
    25  func StringToNetIPAddress(ipString string) net.IP {
    26  	return net.ParseIP(ipString)
    27  }
    28  
    29  func ParseMAC(macString string) (net.HardwareAddr, error) {
    30  	var macInt big.Int
    31  
    32  	// Hyper-V uses non-standard MAC address formats (with no colons and no dashes)
    33  	_, success := macInt.SetString(macString, 16)
    34  	if success {
    35  		macBytes := macInt.Bytes()
    36  		for i := len(macBytes); i < 6; i++ {
    37  			macBytes = append([]byte{0}, macBytes...)
    38  		}
    39  		hardwareAddr := net.HardwareAddr(macBytes)
    40  		return hardwareAddr, nil
    41  	}
    42  
    43  	hardwareAddr, err := net.ParseMAC(macString)
    44  	if err != nil {
    45  		return nil, err
    46  	}
    47  
    48  	return hardwareAddr, nil
    49  }
    50  
    51  func Or(ip, ip1 net.IP) net.IP {
    52  	b := make([]byte, len(ip))
    53  	for i := 0; i < len(ip); i++ {
    54  		b[i] = ip[i] | ip1[i]
    55  	}
    56  	return b
    57  }
    58  
    59  func Not(ip net.IP) net.IP {
    60  	b := make([]byte, len(ip))
    61  	for i := 0; i < len(ip); i++ {
    62  		b[i] = ^ip[i]
    63  	}
    64  	return b
    65  }
    66  
    67  func Increment(ip net.IP) net.IP {
    68  	newip := make([]byte, len(ip))
    69  	copy(newip, ip)
    70  	for i := len(ip) - 1; i >= 0; i-- {
    71  		newip[i] = ip[i] + 1
    72  		if newip[i] > 0 {
    73  			break
    74  		}
    75  	}
    76  	return newip
    77  }
    78  
    79  func Decrement(ip net.IP) net.IP {
    80  	newip := make([]byte, len(ip))
    81  	copy(newip, ip)
    82  	for i := len(ip) - 1; i >= 0; i-- {
    83  		newip[i] = ip[i] - 1
    84  		if newip[i] < 255 {
    85  			break
    86  		}
    87  	}
    88  	return newip
    89  }
    90  
    91  func GetCIDR(startip, endip net.IP) (*net.IPNet, error) {
    92  
    93  	if len(startip) != len(endip) {
    94  		return nil, fmt.Errorf("Can not compute CIDR for %s and %s.  Start and end range have different sizes (%d %d)", startip, endip, len(startip), len(endip))
    95  	}
    96  
    97  	var prefixlen uint = 0
    98  	exit := false
    99  	for i := 0; i < len(startip) && !exit; i++ {
   100  		for j := 0; j < 8 && !exit; j++ {
   101  			mask := byte(1 << (7 - j))
   102  			if (startip[i] & mask) == (endip[i] & mask) {
   103  				prefixlen++
   104  			} else {
   105  				exit = true
   106  			}
   107  		}
   108  	}
   109  	mask := net.CIDRMask(int(prefixlen), len(startip)*8)
   110  
   111  	//Find the start of the CIDR we need to allocate
   112  	rangeStartIP := startip.Mask(mask)
   113  	//fmt.Printf("the range to allocate for %s - %s is: %s\\%d\n", sip, eip, rangeStartIP, prefixlen)
   114  
   115  	return &net.IPNet{
   116  		IP:   rangeStartIP,
   117  		Mask: mask,
   118  	}, nil
   119  }
   120  
   121  func GetBroadcastAddress(cidr net.IPNet) net.IP {
   122  	broadcastip := Or(cidr.IP, Not(net.IP(cidr.Mask)))
   123  	if len(broadcastip) == net.IPv6len {
   124  		return broadcastip
   125  	}
   126  
   127  	// IPv4 (10.0.0.255) addresses are typically represented as IPv4 mapped IPv6 address (0:0:0:0:0:FFFF:10.0.0.255) in the net.IP structure.
   128  	// However, the net.IPNet structure stores the IPv4 address in a net.IPv4Len array.
   129  	// So, we convert the ipv4 address to a ipv4 mapped ipv6 address to be consistent with net.ParseIP()
   130  	// By converting to a ipv4 mappend ipv6 address callers can use this function in a more natural manner like
   131  	// GetBroadcastAddress(cidr) == net.ParseIP(10.0.0.255)
   132  	broadcastip = net.ParseIP(broadcastip.String())
   133  	return broadcastip
   134  }
   135  
   136  func PrefixesOverlap(cidr1 net.IPNet, cidr2 net.IPNet) bool {
   137  	if cidr1.Contains(cidr2.IP) || cidr2.Contains(cidr1.IP) {
   138  		return true
   139  	}
   140  	return false
   141  }
   142  
   143  func GetNetworkInterface() (string, error) {
   144  	// get primary public IP address
   145  	primaryPublicIP, err := GetIPAddress()
   146  	if err != nil {
   147  		return "", err
   148  	}
   149  
   150  	networkInterfaces, err := net.Interfaces()
   151  	if err != nil {
   152  		return "", err
   153  	}
   154  
   155  	for _, networkInterface := range networkInterfaces {
   156  		// return this interface iff it contains the
   157  		// primary public IP address
   158  
   159  		// skip down interface
   160  		if networkInterface.Flags&net.FlagUp == 0 {
   161  			continue
   162  		}
   163  		// skip loopback
   164  		if networkInterface.Flags&net.FlagLoopback != 0 {
   165  			continue
   166  		}
   167  		// list of unicast interface addresses for specific interface
   168  		addresses, err := networkInterface.Addrs()
   169  		if err != nil {
   170  			return "", err
   171  		}
   172  		// network end point address
   173  		for _, address := range addresses {
   174  			var ip net.IP
   175  			switch typedAddress := address.(type) {
   176  			case *net.IPNet:
   177  				ip = typedAddress.IP
   178  			case *net.IPAddr:
   179  				ip = typedAddress.IP
   180  			}
   181  			// skip loopback or wrong type
   182  			if ip == nil || ip.IsLoopback() {
   183  				continue
   184  			}
   185  
   186  			if ip.String() == primaryPublicIP {
   187  				// return this interface
   188  				return networkInterface.Name, nil
   189  			}
   190  		}
   191  	}
   192  
   193  	return "", fmt.Errorf("No network interfaces found")
   194  }
   195  
   196  func lessThan(left, right net.IP) bool {
   197  	var l, r big.Int
   198  	l.SetBytes(left)
   199  	r.SetBytes(right)
   200  	if l.Cmp(&r) == -1 {
   201  		return true
   202  	}
   203  	return false
   204  }
   205  
   206  func greaterThan(left, right net.IP) bool {
   207  	var l, r big.Int
   208  	l.SetBytes(left)
   209  	r.SetBytes(right)
   210  	if l.Cmp(&r) == 1 {
   211  		return true
   212  	}
   213  	return false
   214  }
   215  
   216  func RangesOverlap(range1start, range1end, range2start, range2end net.IP) bool {
   217  
   218  	if lessThan(range1start, range2start) && lessThan(range2start, range1end) {
   219  		return true
   220  	}
   221  	if lessThan(range2start, range1start) && lessThan(range1start, range2end) {
   222  		return true
   223  	}
   224  	if range1start.Equal(range2start) || range1end.Equal(range2end) {
   225  		return true
   226  	}
   227  	if range1end.Equal(range2start) || range1start.Equal(range2end) {
   228  		return true
   229  	}
   230  	return false
   231  }
   232  
   233  func IsRangeInCIDR(start, end net.IP, cidr *net.IPNet) bool {
   234  	if cidr.Contains(start) && cidr.Contains(end) {
   235  		return true
   236  	}
   237  	return false
   238  }
   239  
   240  func RangeContains(start, end, ip net.IP) bool {
   241  	if x := start.To4(); x != nil {
   242  		start = x
   243  	}
   244  	if x := end.To4(); x != nil {
   245  		end = x
   246  	}
   247  	if x := ip.To4(); x != nil {
   248  		ip = x
   249  	}
   250  
   251  	if len(ip) != len(start) {
   252  		return false
   253  	}
   254  	if len(ip) != len(end) {
   255  		return false
   256  	}
   257  
   258  	if ip.Equal(start) || ip.Equal(end) {
   259  		return true
   260  	}
   261  	if greaterThan(ip, start) && lessThan(ip, end) {
   262  		return true
   263  	}
   264  	return false
   265  }
   266  
   267  //TODO: Create an IPRange class.