github.com/neatlab/neatio@v1.7.3-0.20220425043230-d903e92fcc75/network/p2p/nat/natpmp.go (about)

     1  package nat
     2  
     3  import (
     4  	"fmt"
     5  	"net"
     6  	"strings"
     7  	"time"
     8  
     9  	natpmp "github.com/jackpal/go-nat-pmp"
    10  )
    11  
    12  type pmp struct {
    13  	gw net.IP
    14  	c  *natpmp.Client
    15  }
    16  
    17  func (n *pmp) String() string {
    18  	return fmt.Sprintf("NAT-PMP(%v)", n.gw)
    19  }
    20  
    21  func (n *pmp) ExternalIP() (net.IP, error) {
    22  	response, err := n.c.GetExternalAddress()
    23  	if err != nil {
    24  		return nil, err
    25  	}
    26  	return response.ExternalIPAddress[:], nil
    27  }
    28  
    29  func (n *pmp) AddMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error {
    30  	if lifetime <= 0 {
    31  		return fmt.Errorf("lifetime must not be <= 0")
    32  	}
    33  
    34  	_, err := n.c.AddPortMapping(strings.ToLower(protocol), intport, extport, int(lifetime/time.Second))
    35  	return err
    36  }
    37  
    38  func (n *pmp) DeleteMapping(protocol string, extport, intport int) (err error) {
    39  
    40  	_, err = n.c.AddPortMapping(strings.ToLower(protocol), intport, 0, 0)
    41  	return err
    42  }
    43  
    44  func discoverPMP() Interface {
    45  
    46  	gws := potentialGateways()
    47  	found := make(chan *pmp, len(gws))
    48  	for i := range gws {
    49  		gw := gws[i]
    50  		go func() {
    51  			c := natpmp.NewClient(gw)
    52  			if _, err := c.GetExternalAddress(); err != nil {
    53  				found <- nil
    54  			} else {
    55  				found <- &pmp{gw, c}
    56  			}
    57  		}()
    58  	}
    59  
    60  	timeout := time.NewTimer(1 * time.Second)
    61  	defer timeout.Stop()
    62  	for range gws {
    63  		select {
    64  		case c := <-found:
    65  			if c != nil {
    66  				return c
    67  			}
    68  		case <-timeout.C:
    69  			return nil
    70  		}
    71  	}
    72  	return nil
    73  }
    74  
    75  var (
    76  	_, lan10, _  = net.ParseCIDR("10.0.0.0/8")
    77  	_, lan176, _ = net.ParseCIDR("172.16.0.0/12")
    78  	_, lan192, _ = net.ParseCIDR("192.168.0.0/16")
    79  )
    80  
    81  func potentialGateways() (gws []net.IP) {
    82  	ifaces, err := net.Interfaces()
    83  	if err != nil {
    84  		return nil
    85  	}
    86  	for _, iface := range ifaces {
    87  		ifaddrs, err := iface.Addrs()
    88  		if err != nil {
    89  			return gws
    90  		}
    91  		for _, addr := range ifaddrs {
    92  			switch x := addr.(type) {
    93  			case *net.IPNet:
    94  				if lan10.Contains(x.IP) || lan176.Contains(x.IP) || lan192.Contains(x.IP) {
    95  					ip := x.IP.Mask(x.Mask).To4()
    96  					if ip != nil {
    97  						ip[3] = ip[3] | 0x01
    98  						gws = append(gws, ip)
    99  					}
   100  				}
   101  			}
   102  		}
   103  	}
   104  	return gws
   105  }