github.com/mysteriumnetwork/node@v0.0.0-20240516044423-365054f76801/p2p/stun.go (about)

     1  /*
     2   * Copyright (C) 2021 The "MysteriumNetwork/node" Authors.
     3   *
     4   * This program is free software: you can redistribute it and/or modify
     5   * it under the terms of the GNU General Public License as published by
     6   * the Free Software Foundation, either version 3 of the License, or
     7   * (at your option) any later version.
     8   *
     9   * This program is distributed in the hope that it will be useful,
    10   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    11   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    12   * GNU General Public License for more details.
    13   *
    14   * You should have received a copy of the GNU General Public License
    15   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    16   */
    17  
    18  package p2p
    19  
    20  import (
    21  	"fmt"
    22  	"net"
    23  	"sync"
    24  	"time"
    25  
    26  	"github.com/pion/stun"
    27  	"github.com/rs/zerolog/log"
    28  
    29  	"github.com/mysteriumnetwork/node/config"
    30  	"github.com/mysteriumnetwork/node/eventbus"
    31  	"github.com/mysteriumnetwork/node/identity"
    32  	"github.com/mysteriumnetwork/node/requests/resolver"
    33  )
    34  
    35  // AppTopicSTUN represents the STUN detection topic.
    36  const AppTopicSTUN = "STUN detection"
    37  
    38  // STUNDetectionStatus represents information about detected NAT type using STUN servers.
    39  type STUNDetectionStatus struct {
    40  	Identity string
    41  	NATType  string
    42  }
    43  
    44  func stunPorts(identity identity.Identity, eventBus eventbus.Publisher, localPorts ...int) (remotePorts []int) {
    45  	serverList := config.GetStringSlice(config.FlagSTUNservers)
    46  	if len(serverList) == 0 {
    47  		return localPorts
    48  	}
    49  
    50  	m := make(map[int]int)
    51  
    52  	mu := sync.Mutex{}
    53  	wg := sync.WaitGroup{}
    54  	wg.Add(len(localPorts))
    55  
    56  	for _, p := range localPorts {
    57  		go func(p int) {
    58  			defer wg.Done()
    59  
    60  			resp := multiServerSTUN(serverList, p, 2)
    61  
    62  			mu.Lock()
    63  			defer mu.Unlock()
    64  
    65  			natType := "unknown"
    66  
    67  			for _, port := range resp {
    68  				if m[p] != 0 {
    69  					switch {
    70  					case m[p] == port && p == port:
    71  						natType = "full"
    72  					case m[p] == port:
    73  						natType = "restricted"
    74  					default:
    75  						natType = "fail"
    76  					}
    77  				}
    78  
    79  				if port != 0 {
    80  					m[p] = port
    81  				}
    82  			}
    83  
    84  			if natType == "fail" {
    85  				delete(m, p)
    86  			}
    87  
    88  			if eventBus != nil {
    89  				eventBus.Publish(AppTopicSTUN, STUNDetectionStatus{
    90  					Identity: identity.Address,
    91  					NATType:  natType,
    92  				})
    93  			}
    94  		}(p)
    95  	}
    96  
    97  	wg.Wait()
    98  
    99  	for _, p := range localPorts {
   100  		if port, ok := m[p]; ok {
   101  			remotePorts = append(remotePorts, port)
   102  		} else {
   103  			remotePorts = append(remotePorts, p)
   104  		}
   105  	}
   106  
   107  	return remotePorts
   108  }
   109  
   110  func multiServerSTUN(servers []string, p, limit int) (respPort []int) {
   111  	conn, err := net.ListenUDP("udp4", &net.UDPAddr{Port: p})
   112  	if err != nil {
   113  		log.Error().Err(err).Msg("failed to listen UDP address for STUN server")
   114  		return nil
   115  	}
   116  
   117  	if err := conn.SetDeadline(time.Now().Add(2 * time.Second)); err != nil {
   118  		log.Error().Err(err).Msg("failed to set connection deadline for STUN server")
   119  		return nil
   120  	}
   121  
   122  	defer conn.Close()
   123  
   124  	ch := make(chan int, len(servers))
   125  	wg := sync.WaitGroup{}
   126  	wg.Add(len(servers))
   127  
   128  	go func() {
   129  		wg.Wait()
   130  		close(ch)
   131  	}()
   132  
   133  	for _, server := range servers {
   134  		go func(server string) {
   135  			defer wg.Done()
   136  
   137  			port, err := stunPort(conn, server)
   138  			if err != nil {
   139  				log.Trace().Err(err).Msg("failed to get public UDP port from STUN server")
   140  				return
   141  			}
   142  
   143  			ch <- port
   144  		}(server)
   145  	}
   146  
   147  	for p := range ch {
   148  		respPort = append(respPort, p)
   149  		if len(respPort) == limit {
   150  			return respPort
   151  		}
   152  	}
   153  
   154  	return respPort
   155  }
   156  
   157  func stunPort(conn *net.UDPConn, server string) (remotePort int, err error) {
   158  	host, port, err := net.SplitHostPort(server)
   159  	if err != nil {
   160  		return 0, fmt.Errorf("failed to parse STUN server address: %w", err)
   161  	}
   162  
   163  	if addrs := resolver.FetchDNSFromCache(host); len(addrs) > 0 {
   164  		server = net.JoinHostPort(addrs[0], port)
   165  	}
   166  
   167  	serverAddr, err := net.ResolveUDPAddr("udp", server)
   168  	if err != nil {
   169  		return 0, fmt.Errorf("failed to resolve STUN server address: %w", err)
   170  	}
   171  
   172  	resolver.CacheDNSRecord(host, []string{serverAddr.IP.String()})
   173  
   174  	m := stun.MustBuild(stun.TransactionID, stun.BindingRequest)
   175  
   176  	if _, err = conn.WriteToUDP(m.Raw, serverAddr); err != nil {
   177  		return 0, fmt.Errorf("failed to send binding request to STUN server: %w", err)
   178  	}
   179  
   180  	msg := make([]byte, 1024)
   181  
   182  	n, _, err := conn.ReadFromUDP(msg)
   183  	if err != nil {
   184  		return 0, fmt.Errorf("failed to read message from STUN server: %w", err)
   185  	}
   186  
   187  	msg = msg[:n]
   188  
   189  	if !stun.IsMessage(msg) {
   190  		return 0, fmt.Errorf("not correct response from STUN server")
   191  	}
   192  
   193  	resp := &stun.Message{Raw: msg}
   194  
   195  	if err := resp.Decode(); err != nil {
   196  		return 0, fmt.Errorf("failed to decode STUN server message: %w", err)
   197  	}
   198  
   199  	var xorAddr stun.XORMappedAddress
   200  	if err := xorAddr.GetFrom(resp); err != nil {
   201  		return 0, fmt.Errorf("failed to decode STUN server message: %w", err)
   202  	}
   203  
   204  	return xorAddr.Port, nil
   205  }