github.com/decred/dcrlnd@v0.7.6/nat/pmp.go (about)

     1  package nat
     2  
     3  import (
     4  	"fmt"
     5  	"net"
     6  	"sync"
     7  	"time"
     8  
     9  	"github.com/jackpal/gateway"
    10  	natpmp "github.com/jackpal/go-nat-pmp"
    11  )
    12  
    13  // Compile-time check to ensure PMP implements the Traversal interface.
    14  var _ Traversal = (*PMP)(nil)
    15  
    16  // PMP is a concrete implementation of the Traversal interface that uses the
    17  // NAT-PMP technique.
    18  type PMP struct {
    19  	client *natpmp.Client
    20  
    21  	forwardedPortsMtx sync.Mutex
    22  	forwardedPorts    map[uint16]struct{}
    23  }
    24  
    25  // DiscoverPMP attempts to scan the local network for a NAT-PMP enabled device
    26  // within the given timeout.
    27  func DiscoverPMP(timeout time.Duration) (*PMP, error) {
    28  	// Retrieve the gateway IP address of the local network.
    29  	gatewayIP, err := gateway.DiscoverGateway()
    30  	if err != nil {
    31  		return nil, err
    32  	}
    33  
    34  	pmp := &PMP{
    35  		client:         natpmp.NewClientWithTimeout(gatewayIP, timeout),
    36  		forwardedPorts: make(map[uint16]struct{}),
    37  	}
    38  
    39  	// We'll then attempt to retrieve the external IP address of this
    40  	// device to ensure it is not behind multiple NATs.
    41  	if _, err := pmp.ExternalIP(); err != nil {
    42  		return nil, err
    43  	}
    44  
    45  	return pmp, nil
    46  }
    47  
    48  // ExternalIP returns the external IP address of the NAT-PMP enabled device.
    49  func (p *PMP) ExternalIP() (net.IP, error) {
    50  	res, err := p.client.GetExternalAddress()
    51  	if err != nil {
    52  		return nil, err
    53  	}
    54  
    55  	ip := net.IP(res.ExternalIPAddress[:])
    56  	if isPrivateIP(ip) {
    57  		return nil, ErrMultipleNAT
    58  	}
    59  
    60  	return ip, nil
    61  }
    62  
    63  // AddPortMapping enables port forwarding for the given port.
    64  func (p *PMP) AddPortMapping(port uint16) error {
    65  	p.forwardedPortsMtx.Lock()
    66  	defer p.forwardedPortsMtx.Unlock()
    67  
    68  	_, err := p.client.AddPortMapping("tcp", int(port), int(port), 0)
    69  	if err != nil {
    70  		return err
    71  	}
    72  
    73  	p.forwardedPorts[port] = struct{}{}
    74  
    75  	return nil
    76  }
    77  
    78  // DeletePortMapping disables port forwarding for the given port.
    79  func (p *PMP) DeletePortMapping(port uint16) error {
    80  	p.forwardedPortsMtx.Lock()
    81  	defer p.forwardedPortsMtx.Unlock()
    82  
    83  	if _, exists := p.forwardedPorts[port]; !exists {
    84  		return fmt.Errorf("port %d is not being forwarded", port)
    85  	}
    86  
    87  	_, err := p.client.AddPortMapping("tcp", int(port), 0, 0)
    88  	if err != nil {
    89  		return err
    90  	}
    91  
    92  	delete(p.forwardedPorts, port)
    93  
    94  	return nil
    95  }
    96  
    97  // ForwardedPorts returns a list of ports currently being forwarded.
    98  func (p *PMP) ForwardedPorts() []uint16 {
    99  	p.forwardedPortsMtx.Lock()
   100  	defer p.forwardedPortsMtx.Unlock()
   101  
   102  	ports := make([]uint16, 0, len(p.forwardedPorts))
   103  	for port := range p.forwardedPorts {
   104  		ports = append(ports, port)
   105  	}
   106  
   107  	return ports
   108  }
   109  
   110  // Name returns the name of the specific NAT traversal technique used.
   111  func (p *PMP) Name() string {
   112  	return "NAT-PMP"
   113  }