github.com/richardwilkes/toolbox@v1.121.0/xio/network/natpmp/natpmp.go (about)

     1  // Copyright (c) 2016-2024 by Richard A. Wilkes. All rights reserved.
     2  //
     3  // This Source Code Form is subject to the terms of the Mozilla Public
     4  // License, version 2.0. If a copy of the MPL was not distributed with
     5  // this file, You can obtain one at http://mozilla.org/MPL/2.0/.
     6  //
     7  // This Source Code Form is "Incompatible With Secondary Licenses", as
     8  // defined by the Mozilla Public License, version 2.0.
     9  
    10  // Package natpmp provides an implementation of NAT-PMP.
    11  // See https://tools.ietf.org/html/rfc6886
    12  package natpmp
    13  
    14  import (
    15  	"encoding/binary"
    16  	"errors"
    17  	"net"
    18  	"sync"
    19  	"time"
    20  
    21  	"github.com/jackpal/gateway"
    22  	"github.com/richardwilkes/toolbox/atexit"
    23  	"github.com/richardwilkes/toolbox/errs"
    24  )
    25  
    26  const (
    27  	protocolVersion = 0
    28  	expiration      = uint32(time.Hour / time.Second)
    29  	tcpFlag         = 0x10000
    30  )
    31  
    32  const (
    33  	opExternalAddress = iota
    34  	opMapUDP
    35  	opMapTCP
    36  )
    37  
    38  type mapping struct {
    39  	notifyChan chan any
    40  	renew      time.Time
    41  	external   int
    42  }
    43  
    44  var (
    45  	once     sync.Once
    46  	gw       net.IP
    47  	lock     sync.RWMutex
    48  	mappings = make(map[int]mapping)
    49  )
    50  
    51  // ExternalAddress returns the external address the internet sees you as having.
    52  func ExternalAddress() (net.IP, error) {
    53  	buffer := make([]byte, 2)
    54  	buffer[0] = protocolVersion
    55  	buffer[1] = opExternalAddress
    56  	response, err := call(buffer, 12)
    57  	if err != nil {
    58  		return nil, err
    59  	}
    60  	return response[8:12], nil
    61  }
    62  
    63  // MapTCP maps the specified TCP port for external access. It returns the port on the external address that can be used
    64  // to connect to the internal port. If you wish to be notified of changes to the external port mapping, provide a notify
    65  // channel. It will be sent an int containing the updated external port mapping when it changes or an error if a renewal
    66  // fails. The channel will only be sent to if it is ready.
    67  func MapTCP(port int, notifyChan chan any) (int, error) {
    68  	if err := checkPort(port); err != nil {
    69  		return 0, err
    70  	}
    71  	response, err := call(makeMapBuffer(opMapTCP, uint16(port)), 16)
    72  	if err != nil {
    73  		return 0, err
    74  	}
    75  	external := int(binary.BigEndian.Uint16(response[10:12]))
    76  	addMapping(port|tcpFlag, external, notifyChan)
    77  	return external, nil
    78  }
    79  
    80  // MapUDP maps the specified UDP port for external access. It returns the port on the external address that can be used
    81  // to connect to the internal port. If you wish to be notified of changes to the external port mapping, provide a notify
    82  // channel. It will be sent an int containing the updated external port mapping when it changes or an error if a renewal
    83  // fails. The channel will only be sent to if it is ready.
    84  func MapUDP(port int, notifyChan chan any) (int, error) {
    85  	if err := checkPort(port); err != nil {
    86  		return 0, err
    87  	}
    88  	response, err := call(makeMapBuffer(opMapUDP, uint16(port)), 16)
    89  	if err != nil {
    90  		return 0, err
    91  	}
    92  	external := int(binary.BigEndian.Uint16(response[10:12]))
    93  	addMapping(port, external, notifyChan)
    94  	return external, nil
    95  }
    96  
    97  // UnmapTCP unmaps a previously mapped internal TCP port.
    98  func UnmapTCP(port int) error {
    99  	if err := checkPort(port); err != nil {
   100  		return err
   101  	}
   102  	_, err := call(makeUnmapBuffer(opMapTCP, uint16(port)), 16)
   103  	if err != nil {
   104  		return err
   105  	}
   106  	removeMapping(port | tcpFlag)
   107  	return err
   108  }
   109  
   110  // UnmapUDP unmaps a previously mapped internal UDP port.
   111  func UnmapUDP(port int) error {
   112  	if err := checkPort(port); err != nil {
   113  		return err
   114  	}
   115  	_, err := call(makeUnmapBuffer(opMapUDP, uint16(port)), 16)
   116  	if err != nil {
   117  		return err
   118  	}
   119  	removeMapping(port)
   120  	return err
   121  }
   122  
   123  func checkPort(port int) error {
   124  	if port > 0 && port < 65536 {
   125  		return nil
   126  	}
   127  	return errs.Newf("port (%d) must be in the range 1-65535", port)
   128  }
   129  
   130  func makeMapBuffer(op byte, port uint16) []byte {
   131  	buffer := makeUnmapBuffer(op, port)
   132  	binary.BigEndian.PutUint16(buffer[6:8], port)
   133  	binary.BigEndian.PutUint32(buffer[8:12], expiration)
   134  	return buffer
   135  }
   136  
   137  func makeUnmapBuffer(op byte, port uint16) []byte {
   138  	buffer := make([]byte, 12)
   139  	buffer[0] = protocolVersion
   140  	buffer[1] = op
   141  	binary.BigEndian.PutUint16(buffer[4:6], port)
   142  	return buffer
   143  }
   144  
   145  func addMapping(internal, external int, notifyChan chan any) {
   146  	lock.Lock()
   147  	mappings[internal] = mapping{
   148  		external:   external,
   149  		renew:      time.Now().Add(50 * time.Minute),
   150  		notifyChan: notifyChan,
   151  	}
   152  	lock.Unlock()
   153  }
   154  
   155  func removeMapping(internal int) {
   156  	lock.Lock()
   157  	delete(mappings, internal)
   158  	lock.Unlock()
   159  }
   160  
   161  func call(msg []byte, resultSize int) ([]byte, error) {
   162  	once.Do(setupGateway)
   163  	if gw == nil {
   164  		return nil, errs.New("No gateway found")
   165  	}
   166  	conn, err := net.DialUDP("udp", nil, &net.UDPAddr{
   167  		IP:   gw,
   168  		Port: 5351,
   169  	})
   170  	if err != nil {
   171  		return nil, errs.Wrap(err)
   172  	}
   173  	defer func() { err = conn.Close() }()
   174  	timeout := time.Now().Add(30 * time.Second)
   175  	err = conn.SetDeadline(timeout)
   176  	if err != nil {
   177  		return nil, errs.Wrap(err)
   178  	}
   179  	result := make([]byte, resultSize)
   180  	for time.Now().Before(timeout) {
   181  		if _, err = conn.Write(msg); err != nil {
   182  			return nil, errs.Wrap(err)
   183  		}
   184  		var n int
   185  		var remote *net.UDPAddr
   186  		if n, remote, err = conn.ReadFromUDP(result); err != nil {
   187  			var nerr net.Error
   188  			if errors.As(err, &nerr) && nerr.Timeout() {
   189  				continue
   190  			}
   191  			return nil, errs.Wrap(err)
   192  		}
   193  		if !remote.IP.Equal(gw) {
   194  			continue
   195  		}
   196  		if n != resultSize {
   197  			return nil, errs.Newf("unexpected result size (received %d, expected %d)", n, resultSize)
   198  		}
   199  		if result[0] != 0 {
   200  			return nil, errs.Newf("unknown protocol version (%d)", result[0])
   201  		}
   202  		expectedOp := msg[1] | 0x80
   203  		if result[1] != expectedOp {
   204  			return nil, errs.Newf("unexpected opcode (received %d, expected %d)", result[1], expectedOp)
   205  		}
   206  		code := binary.BigEndian.Uint16(result[2:4])
   207  		switch code {
   208  		case 0:
   209  			return result, nil
   210  		case 1:
   211  			return nil, errs.New("unsupported version")
   212  		case 2:
   213  			return nil, errs.New("not authorized")
   214  		case 3:
   215  			return nil, errs.New("network failure")
   216  		case 4:
   217  			return nil, errs.New("out of resources")
   218  		case 5:
   219  			return nil, errs.New("unsupported opcode")
   220  		default:
   221  			return nil, errs.Newf("unknown result code %d", code)
   222  		}
   223  	}
   224  	return nil, errs.Newf("timed out trying to contact gateway")
   225  }
   226  
   227  func setupGateway() {
   228  	var err error
   229  	if gw, err = gateway.DiscoverGateway(); err == nil {
   230  		atexit.Register(cleanup)
   231  		go renewals()
   232  	}
   233  }
   234  
   235  func renewals() {
   236  	for {
   237  		time.Sleep(time.Minute)
   238  		now := time.Now()
   239  		lock.RLock()
   240  		renew := make(map[int]mapping, len(mappings))
   241  		for k, v := range mappings {
   242  			if !now.Before(v.renew) {
   243  				renew[k] = v
   244  			}
   245  		}
   246  		lock.RUnlock()
   247  		for port, v := range renew {
   248  			var external int
   249  			var err error
   250  			if port&tcpFlag != 0 {
   251  				external, err = MapTCP(port&^tcpFlag, v.notifyChan)
   252  			} else {
   253  				external, err = MapUDP(port, v.notifyChan)
   254  			}
   255  			if v.notifyChan != nil {
   256  				if err != nil {
   257  					var portType string
   258  					if port&tcpFlag != 0 {
   259  						portType = "TCP"
   260  					} else {
   261  						portType = "UDP"
   262  					}
   263  					select {
   264  					case v.notifyChan <- errs.NewWithCausef(err, "mapping renewal for %s port %d failed", portType, port):
   265  					default:
   266  					}
   267  				} else if v.external != external {
   268  					select {
   269  					case v.notifyChan <- external:
   270  					default:
   271  					}
   272  				}
   273  			}
   274  		}
   275  	}
   276  }
   277  
   278  func cleanup() {
   279  	lock.RLock()
   280  	ports := make([]int, len(mappings))
   281  	for port := range mappings {
   282  		ports = append(ports, port)
   283  	}
   284  	lock.RUnlock()
   285  	for _, port := range ports {
   286  		var err error
   287  		if port&tcpFlag != 0 {
   288  			err = UnmapTCP(port &^ tcpFlag)
   289  		} else {
   290  			err = UnmapUDP(port)
   291  		}
   292  		if err != nil {
   293  			errs.Log(err)
   294  		}
   295  	}
   296  }