github.com/neatlab/neatio@v1.7.3-0.20220425043230-d903e92fcc75/network/p2p/netutil/net.go (about)

     1  package netutil
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"fmt"
     7  	"net"
     8  	"sort"
     9  	"strings"
    10  )
    11  
    12  var lan4, lan6, special4, special6 Netlist
    13  
    14  func init() {
    15  	lan4.Add("0.0.0.0/8")
    16  	lan4.Add("10.0.0.0/8")
    17  	lan4.Add("172.16.0.0/12")
    18  	lan4.Add("192.168.0.0/16")
    19  	lan6.Add("fe80::/10")
    20  	lan6.Add("fc00::/7")
    21  	special4.Add("192.0.0.0/29")
    22  	special4.Add("192.0.0.9/32")
    23  	special4.Add("192.0.0.170/32")
    24  	special4.Add("192.0.0.171/32")
    25  	special4.Add("192.0.2.0/24")
    26  	special4.Add("192.31.196.0/24")
    27  	special4.Add("192.52.193.0/24")
    28  	special4.Add("192.88.99.0/24")
    29  	special4.Add("192.175.48.0/24")
    30  	special4.Add("198.18.0.0/15")
    31  	special4.Add("198.51.100.0/24")
    32  	special4.Add("203.0.113.0/24")
    33  	special4.Add("255.255.255.255/32")
    34  
    35  	special6.Add("100::/64")
    36  	special6.Add("2001::/32")
    37  	special6.Add("2001:1::1/128")
    38  	special6.Add("2001:2::/48")
    39  	special6.Add("2001:3::/32")
    40  	special6.Add("2001:4:112::/48")
    41  	special6.Add("2001:5::/32")
    42  	special6.Add("2001:10::/28")
    43  	special6.Add("2001:20::/28")
    44  	special6.Add("2001:db8::/32")
    45  	special6.Add("2002::/16")
    46  }
    47  
    48  type Netlist []net.IPNet
    49  
    50  func ParseNetlist(s string) (*Netlist, error) {
    51  	ws := strings.NewReplacer(" ", "", "\n", "", "\t", "")
    52  	masks := strings.Split(ws.Replace(s), ",")
    53  	l := make(Netlist, 0)
    54  	for _, mask := range masks {
    55  		if mask == "" {
    56  			continue
    57  		}
    58  		_, n, err := net.ParseCIDR(mask)
    59  		if err != nil {
    60  			return nil, err
    61  		}
    62  		l = append(l, *n)
    63  	}
    64  	return &l, nil
    65  }
    66  
    67  func (l Netlist) MarshalTOML() interface{} {
    68  	list := make([]string, 0, len(l))
    69  	for _, net := range l {
    70  		list = append(list, net.String())
    71  	}
    72  	return list
    73  }
    74  
    75  func (l *Netlist) UnmarshalTOML(fn func(interface{}) error) error {
    76  	var masks []string
    77  	if err := fn(&masks); err != nil {
    78  		return err
    79  	}
    80  	for _, mask := range masks {
    81  		_, n, err := net.ParseCIDR(mask)
    82  		if err != nil {
    83  			return err
    84  		}
    85  		*l = append(*l, *n)
    86  	}
    87  	return nil
    88  }
    89  
    90  func (l *Netlist) Add(cidr string) {
    91  	_, n, err := net.ParseCIDR(cidr)
    92  	if err != nil {
    93  		panic(err)
    94  	}
    95  	*l = append(*l, *n)
    96  }
    97  
    98  func (l *Netlist) Contains(ip net.IP) bool {
    99  	if l == nil {
   100  		return false
   101  	}
   102  	for _, net := range *l {
   103  		if net.Contains(ip) {
   104  			return true
   105  		}
   106  	}
   107  	return false
   108  }
   109  
   110  func IsLAN(ip net.IP) bool {
   111  	if ip.IsLoopback() {
   112  		return true
   113  	}
   114  	if v4 := ip.To4(); v4 != nil {
   115  		return lan4.Contains(v4)
   116  	}
   117  	return lan6.Contains(ip)
   118  }
   119  
   120  func IsSpecialNetwork(ip net.IP) bool {
   121  	if ip.IsMulticast() {
   122  		return true
   123  	}
   124  	if v4 := ip.To4(); v4 != nil {
   125  		return special4.Contains(v4)
   126  	}
   127  	return special6.Contains(ip)
   128  }
   129  
   130  var (
   131  	errInvalid     = errors.New("invalid IP")
   132  	errUnspecified = errors.New("zero address")
   133  	errSpecial     = errors.New("special network")
   134  	errLoopback    = errors.New("loopback address from non-loopback host")
   135  	errLAN         = errors.New("LAN address from WAN host")
   136  )
   137  
   138  func CheckRelayIP(sender, addr net.IP) error {
   139  	if len(addr) != net.IPv4len && len(addr) != net.IPv6len {
   140  		return errInvalid
   141  	}
   142  	if addr.IsUnspecified() {
   143  		return errUnspecified
   144  	}
   145  	if IsSpecialNetwork(addr) {
   146  		return errSpecial
   147  	}
   148  	if addr.IsLoopback() && !sender.IsLoopback() {
   149  		return errLoopback
   150  	}
   151  	if IsLAN(addr) && !IsLAN(sender) {
   152  		return errLAN
   153  	}
   154  	return nil
   155  }
   156  
   157  func SameNet(bits uint, ip, other net.IP) bool {
   158  	ip4, other4 := ip.To4(), other.To4()
   159  	switch {
   160  	case (ip4 == nil) != (other4 == nil):
   161  		return false
   162  	case ip4 != nil:
   163  		return sameNet(bits, ip4, other4)
   164  	default:
   165  		return sameNet(bits, ip.To16(), other.To16())
   166  	}
   167  }
   168  
   169  func sameNet(bits uint, ip, other net.IP) bool {
   170  	nb := int(bits / 8)
   171  	mask := ^byte(0xFF >> (bits % 8))
   172  	if mask != 0 && nb < len(ip) && ip[nb]&mask != other[nb]&mask {
   173  		return false
   174  	}
   175  	return nb <= len(ip) && bytes.Equal(ip[:nb], other[:nb])
   176  }
   177  
   178  type DistinctNetSet struct {
   179  	Subnet uint
   180  	Limit  uint
   181  
   182  	members map[string]uint
   183  	buf     net.IP
   184  }
   185  
   186  func (s *DistinctNetSet) Add(ip net.IP) bool {
   187  	key := s.key(ip)
   188  	n := s.members[string(key)]
   189  	if n < s.Limit {
   190  		s.members[string(key)] = n + 1
   191  		return true
   192  	}
   193  	return false
   194  }
   195  
   196  func (s *DistinctNetSet) Remove(ip net.IP) {
   197  	key := s.key(ip)
   198  	if n, ok := s.members[string(key)]; ok {
   199  		if n == 1 {
   200  			delete(s.members, string(key))
   201  		} else {
   202  			s.members[string(key)] = n - 1
   203  		}
   204  	}
   205  }
   206  
   207  func (s DistinctNetSet) Contains(ip net.IP) bool {
   208  	key := s.key(ip)
   209  	_, ok := s.members[string(key)]
   210  	return ok
   211  }
   212  
   213  func (s DistinctNetSet) Len() int {
   214  	n := uint(0)
   215  	for _, i := range s.members {
   216  		n += i
   217  	}
   218  	return int(n)
   219  }
   220  
   221  func (s *DistinctNetSet) key(ip net.IP) net.IP {
   222  	if s.members == nil {
   223  		s.members = make(map[string]uint)
   224  		s.buf = make(net.IP, 17)
   225  	}
   226  	typ := byte('6')
   227  	if ip4 := ip.To4(); ip4 != nil {
   228  		typ, ip = '4', ip4
   229  	}
   230  	bits := s.Subnet
   231  	if bits > uint(len(ip)*8) {
   232  		bits = uint(len(ip) * 8)
   233  	}
   234  	nb := int(bits / 8)
   235  	mask := ^byte(0xFF >> (bits % 8))
   236  	s.buf[0] = typ
   237  	buf := append(s.buf[:1], ip[:nb]...)
   238  	if nb < len(ip) && mask != 0 {
   239  		buf = append(buf, ip[nb]&mask)
   240  	}
   241  	return buf
   242  }
   243  
   244  func (s DistinctNetSet) String() string {
   245  	var buf bytes.Buffer
   246  	buf.WriteString("{")
   247  	keys := make([]string, 0, len(s.members))
   248  	for k := range s.members {
   249  		keys = append(keys, k)
   250  	}
   251  	sort.Strings(keys)
   252  	for i, k := range keys {
   253  		var ip net.IP
   254  		if k[0] == '4' {
   255  			ip = make(net.IP, 4)
   256  		} else {
   257  			ip = make(net.IP, 16)
   258  		}
   259  		copy(ip, k[1:])
   260  		fmt.Fprintf(&buf, "%vĂ—%d", ip, s.members[k])
   261  		if i != len(keys)-1 {
   262  			buf.WriteString(" ")
   263  		}
   264  	}
   265  	buf.WriteString("}")
   266  	return buf.String()
   267  }