github.com/bytom/bytom@v1.1.2-0.20221014091027-bbcba3df6075/p2p/netutil/net.go (about)

     1  // Package netutil contains extensions to the net package.
     2  package netutil
     3  
     4  import (
     5  	"bytes"
     6  	"errors"
     7  	"fmt"
     8  	"net"
     9  	"sort"
    10  	"strconv"
    11  	"strings"
    12  
    13  	log "github.com/sirupsen/logrus"
    14  )
    15  
    16  var lan4, lan6, special4, special6 Netlist
    17  
    18  var (
    19  	logModule = "netutil"
    20  
    21  	errInvalidIP   = errors.New("ip is invalid")
    22  	errInvalidPort = errors.New("port is invalid")
    23  )
    24  
    25  func init() {
    26  	// Lists from RFC 5735, RFC 5156,
    27  	// https://www.iana.org/assignments/iana-ipv4-special-registry/
    28  	lan4.Add("0.0.0.0/8")              // "This" network
    29  	lan4.Add("10.0.0.0/8")             // Private Use
    30  	lan4.Add("172.16.0.0/12")          // Private Use
    31  	lan4.Add("192.168.0.0/16")         // Private Use
    32  	lan6.Add("fe80::/10")              // Link-Local
    33  	lan6.Add("fc00::/7")               // Unique-Local
    34  	special4.Add("192.0.0.0/29")       // IPv4 Service Continuity
    35  	special4.Add("192.0.0.9/32")       // PCP Anycast
    36  	special4.Add("192.0.0.170/32")     // NAT64/DNS64 Discovery
    37  	special4.Add("192.0.0.171/32")     // NAT64/DNS64 Discovery
    38  	special4.Add("192.0.2.0/24")       // TEST-NET-1
    39  	special4.Add("192.31.196.0/24")    // AS112
    40  	special4.Add("192.52.193.0/24")    // AMT
    41  	special4.Add("192.88.99.0/24")     // 6to4 Relay Anycast
    42  	special4.Add("192.175.48.0/24")    // AS112
    43  	special4.Add("198.18.0.0/15")      // Device Benchmark Testing
    44  	special4.Add("198.51.100.0/24")    // TEST-NET-2
    45  	special4.Add("203.0.113.0/24")     // TEST-NET-3
    46  	special4.Add("255.255.255.255/32") // Limited Broadcast
    47  
    48  	// http://www.iana.org/assignments/iana-ipv6-special-registry/
    49  	special6.Add("100::/64")
    50  	special6.Add("2001::/32")
    51  	special6.Add("2001:1::1/128")
    52  	special6.Add("2001:2::/48")
    53  	special6.Add("2001:3::/32")
    54  	special6.Add("2001:4:112::/48")
    55  	special6.Add("2001:5::/32")
    56  	special6.Add("2001:10::/28")
    57  	special6.Add("2001:20::/28")
    58  	special6.Add("2001:db8::/32")
    59  	special6.Add("2002::/16")
    60  }
    61  
    62  // Netlist is a list of IP networks.
    63  type Netlist []net.IPNet
    64  
    65  // ParseNetlist parses a comma-separated list of CIDR masks.
    66  // Whitespace and extra commas are ignored.
    67  func ParseNetlist(s string) (*Netlist, error) {
    68  	ws := strings.NewReplacer(" ", "", "\n", "", "\t", "")
    69  	masks := strings.Split(ws.Replace(s), ",")
    70  	l := make(Netlist, 0)
    71  	for _, mask := range masks {
    72  		if mask == "" {
    73  			continue
    74  		}
    75  		_, n, err := net.ParseCIDR(mask)
    76  		if err != nil {
    77  			return nil, err
    78  		}
    79  		l = append(l, *n)
    80  	}
    81  	return &l, nil
    82  }
    83  
    84  // MarshalTOML implements toml.MarshalerRec.
    85  func (l Netlist) MarshalTOML() interface{} {
    86  	list := make([]string, 0, len(l))
    87  	for _, net := range l {
    88  		list = append(list, net.String())
    89  	}
    90  	return list
    91  }
    92  
    93  // UnmarshalTOML implements toml.UnmarshalerRec.
    94  func (l *Netlist) UnmarshalTOML(fn func(interface{}) error) error {
    95  	var masks []string
    96  	if err := fn(&masks); err != nil {
    97  		return err
    98  	}
    99  	for _, mask := range masks {
   100  		_, n, err := net.ParseCIDR(mask)
   101  		if err != nil {
   102  			return err
   103  		}
   104  		*l = append(*l, *n)
   105  	}
   106  	return nil
   107  }
   108  
   109  // Add parses a CIDR mask and appends it to the list. It panics for invalid masks and is
   110  // intended to be used for setting up static lists.
   111  func (l *Netlist) Add(cidr string) {
   112  	_, n, err := net.ParseCIDR(cidr)
   113  	if err != nil {
   114  		panic(err)
   115  	}
   116  	*l = append(*l, *n)
   117  }
   118  
   119  // Contains reports whether the given IP is contained in the list.
   120  func (l *Netlist) Contains(ip net.IP) bool {
   121  	if l == nil {
   122  		return false
   123  	}
   124  	for _, net := range *l {
   125  		if net.Contains(ip) {
   126  			return true
   127  		}
   128  	}
   129  	return false
   130  }
   131  
   132  // IsLAN reports whether an IP is a local network address.
   133  func IsLAN(ip net.IP) bool {
   134  	if ip.IsLoopback() {
   135  		return true
   136  	}
   137  	if v4 := ip.To4(); v4 != nil {
   138  		return lan4.Contains(v4)
   139  	}
   140  	return lan6.Contains(ip)
   141  }
   142  
   143  // IsSpecialNetwork reports whether an IP is located in a special-use network range
   144  // This includes broadcast, multicast and documentation addresses.
   145  func IsSpecialNetwork(ip net.IP) bool {
   146  	if ip.IsMulticast() {
   147  		return true
   148  	}
   149  	if v4 := ip.To4(); v4 != nil {
   150  		return special4.Contains(v4)
   151  	}
   152  	return special6.Contains(ip)
   153  }
   154  
   155  var (
   156  	errInvalid     = errors.New("invalid IP")
   157  	errUnspecified = errors.New("zero address")
   158  	errSpecial     = errors.New("special network")
   159  	errLoopback    = errors.New("loopback address from non-loopback host")
   160  	errLAN         = errors.New("LAN address from WAN host")
   161  )
   162  
   163  // CheckRelayIP reports whether an IP relayed from the given sender IP
   164  // is a valid connection target.
   165  //
   166  // There are four rules:
   167  //   - Special network addresses are never valid.
   168  //   - Loopback addresses are OK if relayed by a loopback host.
   169  //   - LAN addresses are OK if relayed by a LAN host.
   170  //   - All other addresses are always acceptable.
   171  func CheckRelayIP(sender, addr net.IP) error {
   172  	if len(addr) != net.IPv4len && len(addr) != net.IPv6len {
   173  		return errInvalid
   174  	}
   175  	if addr.IsUnspecified() {
   176  		return errUnspecified
   177  	}
   178  	if IsSpecialNetwork(addr) {
   179  		return errSpecial
   180  	}
   181  	if addr.IsLoopback() && !sender.IsLoopback() {
   182  		return errLoopback
   183  	}
   184  	if IsLAN(addr) && !IsLAN(sender) {
   185  		return errLAN
   186  	}
   187  	return nil
   188  }
   189  
   190  // SameNet reports whether two IP addresses have an equal prefix of the given bit length.
   191  func SameNet(bits uint, ip, other net.IP) bool {
   192  	ip4, other4 := ip.To4(), other.To4()
   193  	switch {
   194  	case (ip4 == nil) != (other4 == nil):
   195  		return false
   196  	case ip4 != nil:
   197  		return sameNet(bits, ip4, other4)
   198  	default:
   199  		return sameNet(bits, ip.To16(), other.To16())
   200  	}
   201  }
   202  
   203  func sameNet(bits uint, ip, other net.IP) bool {
   204  	nb := int(bits / 8)
   205  	mask := ^byte(0xFF >> (bits % 8))
   206  	if mask != 0 && nb < len(ip) && ip[nb]&mask != other[nb]&mask {
   207  		return false
   208  	}
   209  	return nb <= len(ip) && bytes.Equal(ip[:nb], other[:nb])
   210  }
   211  
   212  // DistinctNetSet tracks IPs, ensuring that at most N of them
   213  // fall into the same network range.
   214  type DistinctNetSet struct {
   215  	Subnet uint // number of common prefix bits
   216  	Limit  uint // maximum number of IPs in each subnet
   217  
   218  	members map[string]uint
   219  	buf     net.IP
   220  }
   221  
   222  // Add adds an IP address to the set. It returns false (and doesn't add the IP) if the
   223  // number of existing IPs in the defined range exceeds the limit.
   224  func (s *DistinctNetSet) Add(ip net.IP) bool {
   225  	key := s.key(ip)
   226  	n := s.members[string(key)]
   227  	if n < s.Limit {
   228  		s.members[string(key)] = n + 1
   229  		return true
   230  	}
   231  	return false
   232  }
   233  
   234  // Remove removes an IP from the set.
   235  func (s *DistinctNetSet) Remove(ip net.IP) {
   236  	key := s.key(ip)
   237  	if n, ok := s.members[string(key)]; ok {
   238  		if n == 1 {
   239  			delete(s.members, string(key))
   240  		} else {
   241  			s.members[string(key)] = n - 1
   242  		}
   243  	}
   244  }
   245  
   246  // Contains whether the given IP is contained in the set.
   247  func (s DistinctNetSet) Contains(ip net.IP) bool {
   248  	key := s.key(ip)
   249  	_, ok := s.members[string(key)]
   250  	return ok
   251  }
   252  
   253  // Len returns the number of tracked IPs.
   254  func (s DistinctNetSet) Len() int {
   255  	n := uint(0)
   256  	for _, i := range s.members {
   257  		n += i
   258  	}
   259  	return int(n)
   260  }
   261  
   262  // key encodes the map key for an address into a temporary buffer.
   263  //
   264  // The first byte of key is '4' or '6' to distinguish IPv4/IPv6 address types.
   265  // The remainder of the key is the IP, truncated to the number of bits.
   266  func (s *DistinctNetSet) key(ip net.IP) net.IP {
   267  	// Lazily initialize storage.
   268  	if s.members == nil {
   269  		s.members = make(map[string]uint)
   270  		s.buf = make(net.IP, 17)
   271  	}
   272  	// Canonicalize ip and bits.
   273  	typ := byte('6')
   274  	if ip4 := ip.To4(); ip4 != nil {
   275  		typ, ip = '4', ip4
   276  	}
   277  	bits := s.Subnet
   278  	if bits > uint(len(ip)*8) {
   279  		bits = uint(len(ip) * 8)
   280  	}
   281  	// Encode the prefix into s.buf.
   282  	nb := int(bits / 8)
   283  	mask := ^byte(0xFF >> (bits % 8))
   284  	s.buf[0] = typ
   285  	buf := append(s.buf[:1], ip[:nb]...)
   286  	if nb < len(ip) && mask != 0 {
   287  		buf = append(buf, ip[nb]&mask)
   288  	}
   289  	return buf
   290  }
   291  
   292  // String implements fmt.Stringer
   293  func (s DistinctNetSet) String() string {
   294  	var buf bytes.Buffer
   295  	buf.WriteString("{")
   296  	keys := make([]string, 0, len(s.members))
   297  	for k := range s.members {
   298  		keys = append(keys, k)
   299  	}
   300  	sort.Strings(keys)
   301  	for i, k := range keys {
   302  		var ip net.IP
   303  		if k[0] == '4' {
   304  			ip = make(net.IP, 4)
   305  		} else {
   306  			ip = make(net.IP, 16)
   307  		}
   308  		copy(ip, k[1:])
   309  		fmt.Fprintf(&buf, "%vĂ—%d", ip, s.members[k])
   310  		if i != len(keys)-1 {
   311  			buf.WriteString(" ")
   312  		}
   313  	}
   314  	buf.WriteString("}")
   315  	return buf.String()
   316  }
   317  
   318  func CheckAndSplitAddresses(addressesStr string) []string {
   319  	if addressesStr == "" {
   320  		return nil
   321  	}
   322  
   323  	var addresses []string
   324  	splits := strings.Split(addressesStr, ",")
   325  	for _, address := range splits {
   326  		ip, port, err := net.SplitHostPort(address)
   327  		if err != nil {
   328  			log.WithFields(log.Fields{"module": logModule, "err": err, "address": address}).Warn("net.SplitHostPort")
   329  			continue
   330  		}
   331  
   332  		if validIP := net.ParseIP(ip); validIP == nil {
   333  			log.WithFields(log.Fields{"module": logModule, "err": errInvalidIP, "ip": ip}).Warn("net.ParseIP")
   334  			continue
   335  		}
   336  
   337  		if _, err := strconv.ParseUint(port, 10, 16); err != nil {
   338  			log.WithFields(log.Fields{"module": logModule, "err": errInvalidPort, "port": port}).Warn("strconv parse port")
   339  			continue
   340  		}
   341  
   342  		addresses = append(addresses, address)
   343  	}
   344  	return addresses
   345  }