github.com/mysteriumnetwork/node@v0.0.0-20240516044423-365054f76801/nat/behavior/discovery.go (about)

     1  /*
     2   * Copyright (C) 2021 The "MysteriumNetwork/node" Authors.
     3   *
     4   * This program is free software: you can redistribute it and/or modify
     5   * it under the terms of the GNU General Public License as published by
     6   * the Free Software Foundation, either version 3 of the License, or
     7   * (at your option) any later version.
     8   *
     9   * This program is distributed in the hope that it will be useful,
    10   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    11   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    12   * GNU General Public License for more details.
    13   *
    14   * You should have received a copy of the GNU General Public License
    15   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    16   */
    17  
    18  package behavior
    19  
    20  import (
    21  	"context"
    22  	"errors"
    23  	"fmt"
    24  	"net"
    25  	"sync"
    26  	"time"
    27  
    28  	"github.com/mysteriumnetwork/node/nat"
    29  	"github.com/pion/stun"
    30  )
    31  
    32  // Enum of DiscoverNATMapping return values
    33  const (
    34  	MappingNone                 = "none"
    35  	MappingIndependent          = "independent"
    36  	MappingAddressDependent     = "address"
    37  	MappingAddressPortDependent = "addressport"
    38  )
    39  
    40  // Enum of DiscoverNATFiltering return values
    41  const (
    42  	FilteringIndependent = "independent"
    43  	FilteringAddress     = "address"
    44  	FilteringAddressPort = "addressport"
    45  )
    46  
    47  // DefaultTimeout for each single STUN request.
    48  const DefaultTimeout = 3 * time.Second
    49  
    50  // STUN protocol compatibility errors
    51  var (
    52  	ErrResponseMessage = errors.New("error reading from response message channel")
    53  	ErrNoXorAddress    = errors.New("no XOR-MAPPED-ADDRESS in message")
    54  	ErrNoOtherAddress  = errors.New("no OTHER-ADDRESS in message")
    55  )
    56  
    57  // CHANGE-REQUEST value constants from RFC 5780 Section 7.2
    58  var (
    59  	changeRequestAddressPort = []byte{0x00, 0x00, 0x00, 0x06}
    60  	changeRequestPort        = []byte{0x00, 0x00, 0x00, 0x02}
    61  )
    62  
    63  type stunServerConn struct {
    64  	conn        *net.UDPConn
    65  	LocalAddr   net.Addr
    66  	RemoteAddr  *net.UDPAddr
    67  	OtherAddr   *net.UDPAddr
    68  	messageChan chan *stun.Message
    69  	stopChan    chan struct{}
    70  	stopOnce    sync.Once
    71  }
    72  
    73  func (c *stunServerConn) Close() error {
    74  	c.stopOnce.Do(func() {
    75  		close(c.stopChan)
    76  	})
    77  	return c.conn.Close()
    78  }
    79  
    80  // DiscoverNATBehavior returns either one of NATType* constants describing
    81  // NAT behavior in practical sense for P2P connections or error
    82  func DiscoverNATBehavior(ctx context.Context, address string, timeout time.Duration) (nat.NATType, error) {
    83  	if timeout == 0 {
    84  		timeout = DefaultTimeout
    85  	}
    86  
    87  	mapping, err := DiscoverNATMapping(ctx, address, timeout)
    88  	if err != nil {
    89  		return "", err
    90  	}
    91  	switch mapping {
    92  	case MappingAddressDependent, MappingAddressPortDependent:
    93  		return nat.NATTypeSymmetric, nil
    94  	case MappingNone:
    95  		return nat.NATTypeNone, nil
    96  	}
    97  
    98  	filtering, err := DiscoverNATFiltering(ctx, address, timeout)
    99  	if err != nil {
   100  		return "", err
   101  	}
   102  	switch filtering {
   103  	case FilteringIndependent:
   104  		return nat.NATTypeFullCone, nil
   105  	case FilteringAddress:
   106  		return nat.NATTypeRestrictedCone, nil
   107  	default:
   108  		return nat.NATTypePortRestrictedCone, nil
   109  	}
   110  }
   111  
   112  // DiscoverNATMapping returns either one of Mapping* constants describing
   113  // NAT mapping behavior or error
   114  func DiscoverNATMapping(ctx context.Context, address string, timeout time.Duration) (string, error) {
   115  	mapTestConn, err := connect(address)
   116  	if err != nil {
   117  		return "", fmt.Errorf("STUN connection init failed: %w", err)
   118  	}
   119  	defer mapTestConn.Close()
   120  
   121  	// Test I: Regular binding request
   122  	request := stun.MustBuild(stun.TransactionID, stun.BindingRequest)
   123  
   124  	ctx1, cl := context.WithTimeout(ctx, timeout)
   125  	defer cl()
   126  	resp, err := mapTestConn.roundTrip(ctx1, request, mapTestConn.RemoteAddr)
   127  	if err != nil {
   128  		return "", fmt.Errorf("mapping test I RT failed: %w", err)
   129  	}
   130  
   131  	// Parse response message for XOR-MAPPED-ADDRESS and make sure OTHER-ADDRESS valid
   132  	resps1 := parse(resp)
   133  	if resps1.xorAddr == nil {
   134  		return "", ErrNoXorAddress
   135  	}
   136  	if resps1.otherAddr == nil {
   137  		return "", ErrNoOtherAddress
   138  	}
   139  	addr, err := net.ResolveUDPAddr("udp4", resps1.otherAddr.String())
   140  	if err != nil {
   141  		return "", fmt.Errorf("other-address resolve failed: %w", err)
   142  	}
   143  	mapTestConn.OtherAddr = addr
   144  
   145  	// Assert mapping behavior
   146  	// TODO: it doesn't actually work because we bind at wildcard address
   147  	// so condition is always false
   148  	myIP, _, err := net.SplitHostPort(resps1.xorAddr.String())
   149  	if err != nil {
   150  		return "", fmt.Errorf("can't parse mirrored address: %w", err)
   151  	}
   152  	outboundIP, err := getOutboundIP(mapTestConn.RemoteAddr.String())
   153  	if err != nil {
   154  		return "", fmt.Errorf("can't get outbound address: %w", err)
   155  	}
   156  	if myIP == outboundIP.String() {
   157  		return MappingNone, nil
   158  	}
   159  
   160  	// Test II: Send binding request to the other address but primary port
   161  	ctx1, cl = context.WithTimeout(ctx, timeout)
   162  	defer cl()
   163  	oaddr := *mapTestConn.OtherAddr
   164  	oaddr.Port = mapTestConn.RemoteAddr.Port
   165  	resp, err = mapTestConn.roundTrip(ctx1, request, &oaddr)
   166  	if err != nil {
   167  		return "", fmt.Errorf("mapping test II RT failed: %w", err)
   168  	}
   169  
   170  	// Assert mapping behavior
   171  	resps2 := parse(resp)
   172  	if resps2.xorAddr.String() == resps1.xorAddr.String() {
   173  		return MappingIndependent, nil
   174  	}
   175  
   176  	// Test III: Send binding request to the other address and port
   177  	ctx1, cl = context.WithTimeout(ctx, timeout)
   178  	defer cl()
   179  	resp, err = mapTestConn.roundTrip(ctx1, request, mapTestConn.OtherAddr)
   180  	if err != nil {
   181  		return "", fmt.Errorf("mapping test III RT failed: %w", err)
   182  	}
   183  
   184  	// Assert mapping behavior
   185  	resps3 := parse(resp)
   186  	if resps3.xorAddr.String() == resps2.xorAddr.String() {
   187  		return MappingAddressDependent, nil
   188  	}
   189  
   190  	return MappingAddressPortDependent, nil
   191  }
   192  
   193  // DiscoverNATFiltering returns either one of FILTERING_* constants describing
   194  // NAT filtering behavior or error
   195  func DiscoverNATFiltering(ctx context.Context, address string, timeout time.Duration) (string, error) {
   196  	mapTestConn, err := connect(address)
   197  	if err != nil {
   198  		return "", fmt.Errorf("STUN connection init failed: %w", err)
   199  	}
   200  	defer mapTestConn.Close()
   201  
   202  	// Test I: Regular binding request
   203  	request := stun.MustBuild(stun.TransactionID, stun.BindingRequest)
   204  
   205  	ctx1, cl := context.WithTimeout(ctx, timeout)
   206  	defer cl()
   207  	resp, err := mapTestConn.roundTrip(ctx1, request, mapTestConn.RemoteAddr)
   208  	if err != nil {
   209  		return "", fmt.Errorf("filtering test I RT failed: %w", err)
   210  	}
   211  	resps := parse(resp)
   212  	if resps.xorAddr == nil {
   213  		return "", fmt.Errorf("filtering test I got bad response: %w", ErrNoXorAddress)
   214  	}
   215  
   216  	if resps.otherAddr == nil {
   217  		return "", fmt.Errorf("filtering test I got bad response: %w", ErrNoOtherAddress)
   218  	}
   219  	addr, err := net.ResolveUDPAddr("udp4", resps.otherAddr.String())
   220  	if err != nil {
   221  		return "", fmt.Errorf("other-address resolve failed: %w", err)
   222  	}
   223  	mapTestConn.OtherAddr = addr
   224  
   225  	// Test II: Request to change both IP and port
   226  	request = stun.MustBuild(stun.TransactionID, stun.BindingRequest)
   227  	request.Add(stun.AttrChangeRequest, changeRequestAddressPort)
   228  
   229  	ctx1, cl = context.WithTimeout(ctx, timeout)
   230  	defer cl()
   231  	_, err = mapTestConn.roundTrip(ctx1, request, mapTestConn.RemoteAddr)
   232  	switch {
   233  	case err == nil:
   234  		return FilteringIndependent, nil
   235  	case err == ctx1.Err() && ctx.Err() == nil:
   236  		// Nothing, just no response. Proceed to next test.
   237  	default:
   238  		return "", fmt.Errorf("filtering test II failed: %w", err)
   239  	}
   240  
   241  	// Test III: Request to change port only
   242  	request = stun.MustBuild(stun.TransactionID, stun.BindingRequest)
   243  	request.Add(stun.AttrChangeRequest, changeRequestPort)
   244  
   245  	ctx1, cl = context.WithTimeout(ctx, timeout)
   246  	defer cl()
   247  	_, err = mapTestConn.roundTrip(ctx1, request, mapTestConn.RemoteAddr)
   248  	switch {
   249  	case err == nil:
   250  		return FilteringAddress, nil
   251  	case err == ctx1.Err() && ctx.Err() == nil:
   252  		return FilteringAddressPort, nil
   253  	default:
   254  		return "", fmt.Errorf("filtering test III failed: %w", err)
   255  	}
   256  }
   257  
   258  // Parse a STUN message
   259  func parse(msg *stun.Message) (ret struct {
   260  	xorAddr    *stun.XORMappedAddress
   261  	otherAddr  *stun.OtherAddress
   262  	mappedAddr *stun.MappedAddress
   263  	software   *stun.Software
   264  }) {
   265  	ret.mappedAddr = &stun.MappedAddress{}
   266  	ret.xorAddr = &stun.XORMappedAddress{}
   267  	ret.otherAddr = &stun.OtherAddress{}
   268  	ret.software = &stun.Software{}
   269  	if ret.xorAddr.GetFrom(msg) != nil {
   270  		ret.xorAddr = nil
   271  	}
   272  	if ret.otherAddr.GetFrom(msg) != nil {
   273  		ret.otherAddr = nil
   274  	}
   275  	if ret.mappedAddr.GetFrom(msg) != nil {
   276  		ret.mappedAddr = nil
   277  	}
   278  	if ret.software.GetFrom(msg) != nil {
   279  		ret.software = nil
   280  	}
   281  	return ret
   282  }
   283  
   284  // Given an address string, returns a stunServerConn
   285  func connect(address string) (*stunServerConn, error) {
   286  	addr, err := net.ResolveUDPAddr("udp4", address)
   287  	if err != nil {
   288  		return nil, err
   289  	}
   290  
   291  	c, err := net.ListenUDP("udp4", nil)
   292  	if err != nil {
   293  		return nil, err
   294  	}
   295  	serverConn := &stunServerConn{
   296  		conn:        c,
   297  		LocalAddr:   c.LocalAddr(),
   298  		RemoteAddr:  addr,
   299  		messageChan: make(chan *stun.Message),
   300  		stopChan:    make(chan struct{}),
   301  	}
   302  	serverConn.listen()
   303  	return serverConn, nil
   304  }
   305  
   306  // Send request and wait for response or timeout
   307  func (c *stunServerConn) roundTrip(ctx context.Context, msg *stun.Message, addr net.Addr) (*stun.Message, error) {
   308  	err := msg.NewTransactionID()
   309  	if err != nil {
   310  		return nil, err
   311  	}
   312  	_, err = c.conn.WriteTo(msg.Raw, addr)
   313  	if err != nil {
   314  		return nil, err
   315  	}
   316  
   317  	// Wait for response or timeout
   318  	for {
   319  		select {
   320  		case m, ok := <-c.messageChan:
   321  			if !ok {
   322  				return nil, ErrResponseMessage
   323  			}
   324  			if m.TransactionID == msg.TransactionID {
   325  				return m, nil
   326  			}
   327  		case <-ctx.Done():
   328  			return nil, ctx.Err()
   329  		}
   330  	}
   331  }
   332  
   333  func (c *stunServerConn) listen() {
   334  	go func() {
   335  		defer close(c.messageChan)
   336  		for {
   337  			buf := make([]byte, 1024)
   338  
   339  			n, _, err := c.conn.ReadFromUDP(buf)
   340  			if err != nil {
   341  				if n == 0 {
   342  					return
   343  				}
   344  				continue
   345  			}
   346  			buf = buf[:n]
   347  
   348  			m := new(stun.Message)
   349  			m.Raw = buf
   350  			err = m.Decode()
   351  			if err != nil {
   352  				continue
   353  			}
   354  
   355  			select {
   356  			case c.messageChan <- m:
   357  			case <-c.stopChan:
   358  				return
   359  			}
   360  		}
   361  	}()
   362  	return
   363  }
   364  
   365  func getOutboundIP(remoteAddr string) (net.IP, error) {
   366  	dialer := net.Dialer{}
   367  
   368  	conn, err := dialer.Dial("udp4", remoteAddr)
   369  	if err != nil {
   370  		return nil, fmt.Errorf("failed to determine outbound IP: %w", err)
   371  	}
   372  	defer conn.Close()
   373  
   374  	return conn.LocalAddr().(*net.UDPAddr).IP, nil
   375  }