github.com/mysteriumnetwork/node@v0.0.0-20240516044423-365054f76801/nat/traversal/pinger.go (about)

     1  /*
     2   * Copyright (C) 2019 The "MysteriumNetwork/node" Authors.
     3   *
     4   * This program is free software: you can redistribute it and/or modify
     5   * it under the terms of the GNU General Public License as published by
     6   * the Free Software Foundation, either version 3 of the License, or
     7   * (at your option) any later version.
     8   *
     9   * This program is distributed in the hope that it will be useful,
    10   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    11   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    12   * GNU General Public License for more details.
    13   *
    14   * You should have received a copy of the GNU General Public License
    15   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    16   */
    17  
    18  package traversal
    19  
    20  import (
    21  	"context"
    22  	"errors"
    23  	"fmt"
    24  	"net"
    25  	"sort"
    26  	"strings"
    27  	"sync"
    28  	"time"
    29  
    30  	"github.com/rs/zerolog/log"
    31  	"golang.org/x/net/ipv4"
    32  
    33  	"github.com/mysteriumnetwork/node/core/port"
    34  	"github.com/mysteriumnetwork/node/eventbus"
    35  	"github.com/mysteriumnetwork/node/nat/event"
    36  	"github.com/mysteriumnetwork/node/router"
    37  )
    38  
    39  // StageName represents hole-punching stage of NAT traversal
    40  const StageName = "hole_punching"
    41  
    42  const (
    43  	bufferLen = 64
    44  
    45  	maxTTL            = 128
    46  	msgOK             = "OK"
    47  	msgOKACK          = "OK_ACK"
    48  	msgPing           = "continuously pinging to "
    49  	sendRetryInterval = 5 * time.Millisecond
    50  	sendRetries       = 10
    51  )
    52  
    53  // ErrTooFew indicates there were too few successful ping
    54  // responses to build requested number of connections
    55  var ErrTooFew = errors.New("too few connections were built")
    56  
    57  // NATPinger is responsible for pinging nat holes
    58  type NATPinger interface {
    59  	PingProviderPeer(ctx context.Context, localIP, remoteIP string, localPorts, remotePorts []int, initialTTL int, n int) (conns []*net.UDPConn, err error)
    60  	PingConsumerPeer(ctx context.Context, id string, remoteIP string, localPorts, remotePorts []int, initialTTL int, n int) (conns []*net.UDPConn, err error)
    61  }
    62  
    63  // PingConfig represents NAT pinger config.
    64  type PingConfig struct {
    65  	Interval            time.Duration
    66  	Timeout             time.Duration
    67  	SendConnACKInterval time.Duration
    68  }
    69  
    70  // DefaultPingConfig returns default NAT pinger config.
    71  func DefaultPingConfig() *PingConfig {
    72  	return &PingConfig{
    73  		Interval:            5 * time.Millisecond,
    74  		Timeout:             10 * time.Second,
    75  		SendConnACKInterval: 100 * time.Millisecond,
    76  	}
    77  }
    78  
    79  // Pinger represents NAT pinger structure
    80  type Pinger struct {
    81  	pingConfig     *PingConfig
    82  	eventPublisher eventbus.Publisher
    83  }
    84  
    85  // PortSupplier provides port needed to run a service on
    86  type PortSupplier interface {
    87  	Acquire() (port.Port, error)
    88  }
    89  
    90  // NewPinger returns Pinger instance
    91  func NewPinger(pingConfig *PingConfig, publisher eventbus.Publisher) NATPinger {
    92  	return &Pinger{
    93  		pingConfig:     pingConfig,
    94  		eventPublisher: publisher,
    95  	}
    96  }
    97  
    98  func drainPingResponses(responses <-chan pingResponse) {
    99  	for response := range responses {
   100  		log.Warn().Err(response.err).Msgf("Sanitizing ping response on %#v", response)
   101  		if response.conn != nil {
   102  			response.conn.Close()
   103  			log.Warn().Msgf("Collected dangling socket: %s", response.conn.LocalAddr().String())
   104  		}
   105  	}
   106  }
   107  
   108  func cleanupConnections(responses []pingResponse) {
   109  	for _, response := range responses {
   110  		response.conn.Close()
   111  	}
   112  }
   113  
   114  // PingConsumerPeer pings remote peer with a defined configuration
   115  // and notifies peer which connections will be used.
   116  // It returns n connections if possible or error.
   117  func (p *Pinger) PingConsumerPeer(ctx context.Context, id string, remoteIP string, localPorts, remotePorts []int, initialTTL int, n int) ([]*net.UDPConn, error) {
   118  	ctx, cancel := context.WithTimeout(ctx, p.pingConfig.Timeout)
   119  	defer cancel()
   120  
   121  	log.Info().Msg("NAT pinging to remote peer")
   122  
   123  	stop := make(chan struct{})
   124  	defer close(stop)
   125  
   126  	ch, err := p.multiPingN(ctx, "", remoteIP, localPorts, remotePorts, initialTTL, n)
   127  	if err != nil {
   128  		log.Err(err).Msg("Failed to ping remote peer")
   129  		return nil, err
   130  	}
   131  
   132  	pingsCh := make(chan pingResponse, n)
   133  	go func() {
   134  		var wg sync.WaitGroup
   135  		for res := range ch {
   136  			if res.err != nil {
   137  				if !errors.Is(res.err, context.Canceled) {
   138  					log.Warn().Err(res.err).Msg("One of the pings has error")
   139  				}
   140  				continue
   141  			}
   142  
   143  			if err := ipv4.NewConn(res.conn).SetTTL(maxTTL); err != nil {
   144  				res.conn.Close()
   145  				log.Warn().Err(res.err).Msg("Failed to set connection TTL")
   146  				continue
   147  			}
   148  
   149  			p.sendMsg(res.conn, msgOK) // Notify peer that we are using this connection.
   150  
   151  			wg.Add(1)
   152  			go func(ping pingResponse) {
   153  				defer wg.Done()
   154  				if err := p.sendConnACK(ctx, ping.conn); err != nil {
   155  					if !errors.Is(err, context.Canceled) {
   156  						log.Warn().Err(err).Msg("Failed to send conn ACK to consumer")
   157  					}
   158  					ping.conn.Close()
   159  				} else {
   160  					pingsCh <- ping
   161  				}
   162  			}(res)
   163  		}
   164  		wg.Wait()
   165  		close(pingsCh)
   166  	}()
   167  
   168  	var pings []pingResponse
   169  	for ping := range pingsCh {
   170  		pings = append(pings, ping)
   171  		if len(pings) == n {
   172  			p.eventPublisher.Publish(event.AppTopicTraversal, event.BuildSuccessfulEvent(id, StageName))
   173  			go drainPingResponses(pingsCh)
   174  			return sortedConns(pings), nil
   175  		}
   176  	}
   177  
   178  	p.eventPublisher.Publish(event.AppTopicTraversal, event.BuildFailureEvent(id, StageName, ErrTooFew))
   179  	cleanupConnections(pings)
   180  	return nil, ErrTooFew
   181  }
   182  
   183  // PingProviderPeer pings remote peer with a defined configuration
   184  // and waits for peer to send ack with connection selected ids.
   185  // It returns n connections if possible or error.
   186  func (p *Pinger) PingProviderPeer(ctx context.Context, localIP, remoteIP string, localPorts, remotePorts []int, initialTTL int, n int) ([]*net.UDPConn, error) {
   187  	ctx, cancel := context.WithTimeout(ctx, p.pingConfig.Timeout)
   188  	defer cancel()
   189  
   190  	log.Info().Msg("NAT pinging to remote peer")
   191  
   192  	ch, err := p.multiPingN(ctx, localIP, remoteIP, localPorts, remotePorts, initialTTL, n)
   193  	if err != nil {
   194  		log.Err(err).Msg("Failed to ping remote peer")
   195  		return nil, err
   196  	}
   197  
   198  	pingsCh := make(chan pingResponse, len(localPorts))
   199  	go func() {
   200  		var wg sync.WaitGroup
   201  		for res := range ch {
   202  			if res.err != nil {
   203  				if !errors.Is(res.err, context.Canceled) {
   204  					log.Warn().Err(res.err).Msg("One of the pings has error")
   205  				}
   206  				continue
   207  			}
   208  
   209  			if err := ipv4.NewConn(res.conn).SetTTL(maxTTL); err != nil {
   210  				res.conn.Close()
   211  				log.Warn().Err(res.err).Msg("Failed to set connection TTL")
   212  				continue
   213  			}
   214  
   215  			p.sendMsg(res.conn, msgOK) // Notify peer that we are using this connection.
   216  
   217  			// Wait for peer to notify that it uses this connection too.
   218  			wg.Add(1)
   219  			go func(ping pingResponse) {
   220  				defer wg.Done()
   221  				if err := p.waitMsg(ctx, ping.conn, msgOK); err != nil {
   222  					if !errors.Is(err, context.Canceled) {
   223  						log.Err(err).Msg("Failed to wait for conn OK from provider")
   224  					}
   225  					ping.conn.Close()
   226  					return
   227  				}
   228  				pingsCh <- ping
   229  			}(res)
   230  		}
   231  		wg.Wait()
   232  		close(pingsCh)
   233  	}()
   234  
   235  	var pings []pingResponse
   236  	for ping := range pingsCh {
   237  		pings = append(pings, ping)
   238  		p.sendMsg(ping.conn, msgOKACK)
   239  		if len(pings) == n {
   240  			go drainPingResponses(pingsCh)
   241  			return sortedConns(pings), nil
   242  		}
   243  	}
   244  
   245  	cleanupConnections(pings)
   246  	return nil, ErrTooFew
   247  }
   248  
   249  // sendConnACK notifies peer that we are using this connection
   250  // and waits for ack or returns timeout err.
   251  func (p *Pinger) sendConnACK(ctx context.Context, conn *net.UDPConn) error {
   252  	ackWaitErr := make(chan error)
   253  	go func() {
   254  		ackWaitErr <- p.waitMsg(ctx, conn, msgOKACK)
   255  	}()
   256  
   257  	for {
   258  		select {
   259  		case err := <-ackWaitErr:
   260  			return err
   261  		case <-time.After(p.pingConfig.SendConnACKInterval):
   262  			p.sendMsg(conn, msgOK)
   263  		}
   264  	}
   265  }
   266  
   267  func sortedConns(pings []pingResponse) []*net.UDPConn {
   268  	sort.Slice(pings, func(i, j int) bool {
   269  		return pings[i].id < pings[j].id
   270  	})
   271  	var conns []*net.UDPConn
   272  	for _, p := range pings {
   273  		conns = append(conns, p.conn)
   274  	}
   275  	return conns
   276  }
   277  
   278  // waitMsg waits until conn receives given message or timeouts.
   279  func (p *Pinger) waitMsg(ctx context.Context, conn *net.UDPConn, msg string) error {
   280  	var (
   281  		n   int
   282  		err error
   283  	)
   284  	buf := make([]byte, 1024)
   285  	// just reasonable upper boundary for receive errors to not enter infinite
   286  	// loop on closed socket, but still skim errors of closed port etc
   287  	// +1 in denominator is to avoid division by zero
   288  	recvErrLimit := 2 * int(p.pingConfig.Timeout/(p.pingConfig.Interval+1))
   289  	for errCount := 0; errCount < recvErrLimit; {
   290  		n, err = readFromConnWithContext(ctx, conn, buf)
   291  		if ctx.Err() != nil {
   292  			return ctx.Err()
   293  		}
   294  		// process returned data unconditionally as io.Reader dictates to
   295  		v := string(buf[:n])
   296  		if v == msg {
   297  			return nil
   298  		}
   299  		if err != nil {
   300  			errCount++
   301  			log.Error().Err(err).Msgf("got error in waitMsg, trying to recover. %d attempts left",
   302  				recvErrLimit-errCount)
   303  			continue
   304  		}
   305  	}
   306  	return fmt.Errorf("too many recv errors, last one: %w", err)
   307  }
   308  
   309  func (p *Pinger) sendMsg(conn *net.UDPConn, msg string) {
   310  	for i := 0; i < sendRetries; i++ {
   311  		_, err := conn.Write([]byte(msg))
   312  		if err != nil {
   313  			log.Error().Err(err).Msg("pinger message send failed")
   314  			time.Sleep(p.pingConfig.Interval)
   315  		} else {
   316  			return
   317  		}
   318  	}
   319  }
   320  
   321  func (p *Pinger) ping(ctx context.Context, conn *net.UDPConn, remoteAddr *net.UDPAddr, ttl int) error {
   322  	err := ipv4.NewConn(conn).SetTTL(ttl)
   323  	if err != nil {
   324  		return fmt.Errorf("pinger setting ttl failed: %w", err)
   325  	}
   326  
   327  	for {
   328  		select {
   329  		case <-ctx.Done():
   330  			return nil
   331  		case <-time.After(p.pingConfig.Interval):
   332  			_, err := conn.WriteToUDP([]byte(msgPing+remoteAddr.String()), remoteAddr)
   333  			if ctx.Err() != nil {
   334  				return nil
   335  			}
   336  			if err != nil {
   337  				return fmt.Errorf("pinging request failed: %w", err)
   338  			}
   339  		}
   340  	}
   341  }
   342  
   343  func readFromConnWithContext(ctx context.Context, conn net.Conn, buf []byte) (n int, err error) {
   344  	readDone := make(chan struct{})
   345  	go func() {
   346  		n, err = conn.Read(buf)
   347  		close(readDone)
   348  	}()
   349  
   350  	select {
   351  	case <-ctx.Done():
   352  		conn.SetReadDeadline(time.Unix(0, 0))
   353  		<-readDone
   354  		conn.SetReadDeadline(time.Time{})
   355  		return 0, ctx.Err()
   356  	case <-readDone:
   357  		return
   358  	}
   359  }
   360  
   361  func readFromUDPWithContext(ctx context.Context, conn *net.UDPConn, buf []byte) (n int, from *net.UDPAddr, err error) {
   362  	readDone := make(chan struct{})
   363  	go func() {
   364  		n, from, err = conn.ReadFromUDP(buf)
   365  		close(readDone)
   366  	}()
   367  
   368  	select {
   369  	case <-ctx.Done():
   370  		conn.SetReadDeadline(time.Unix(0, 0))
   371  		<-readDone
   372  		conn.SetReadDeadline(time.Time{})
   373  		return 0, nil, ctx.Err()
   374  	case <-readDone:
   375  		return
   376  	}
   377  }
   378  
   379  func (p *Pinger) pingReceiver(ctx context.Context, conn *net.UDPConn) (*net.UDPAddr, error) {
   380  	buf := make([]byte, bufferLen)
   381  
   382  	for {
   383  		n, raddr, err := readFromUDPWithContext(ctx, conn, buf)
   384  		if ctx.Err() != nil {
   385  			return nil, ctx.Err()
   386  		}
   387  
   388  		if err != nil || n == 0 {
   389  			log.Debug().Err(err).Msgf("Failed to read remote peer: %s - attempting to continue", raddr)
   390  			continue
   391  		}
   392  
   393  		msg := string(buf[:n])
   394  		log.Debug().Msgf("Remote peer data received, len: %d", n)
   395  
   396  		if msg == msgOK || strings.HasPrefix(msg, msgPing) {
   397  			return raddr, nil
   398  		}
   399  
   400  		log.Debug().Err(err).Msgf("Unexpected message: %s - attempting to continue", msg)
   401  	}
   402  }
   403  
   404  // Valid returns that this pinger is a valid pinger
   405  func (p *Pinger) Valid() bool {
   406  	return true
   407  }
   408  
   409  type pingResponse struct {
   410  	conn *net.UDPConn
   411  	err  error
   412  	id   int
   413  }
   414  
   415  func (p *Pinger) multiPingN(ctx context.Context, localIP, remoteIP string, localPorts, remotePorts []int, initialTTL int, n int) (<-chan pingResponse, error) {
   416  	if len(localPorts) != len(remotePorts) {
   417  		return nil, errors.New("number of local and remote ports does not match")
   418  	}
   419  
   420  	var wg sync.WaitGroup
   421  	ch := make(chan pingResponse, len(localPorts))
   422  	ttl := initialTTL
   423  	resetTTL := initialTTL + (len(localPorts) / n)
   424  
   425  	for i := range localPorts {
   426  		wg.Add(1)
   427  
   428  		go func(i, ttl int) {
   429  			defer wg.Done()
   430  			conn, err := p.singlePing(ctx, localIP, remoteIP, localPorts[i], remotePorts[i], ttl)
   431  			ch <- pingResponse{conn: conn, err: err, id: i}
   432  		}(i, ttl)
   433  
   434  		// TTL increase is only needed for provider side which starts with low TTL value.
   435  		if ttl < maxTTL {
   436  			ttl++
   437  		}
   438  		if ttl == resetTTL {
   439  			ttl = initialTTL
   440  		}
   441  	}
   442  
   443  	go func() { wg.Wait(); close(ch) }()
   444  
   445  	return ch, nil
   446  }
   447  
   448  func (p *Pinger) singlePing(ctx context.Context, localIP, remoteIP string, localPort, remotePort, ttl int) (*net.UDPConn, error) {
   449  	conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.ParseIP(localIP), Port: localPort})
   450  	if err != nil {
   451  		return nil, fmt.Errorf("failed to get connection: %w", err)
   452  	}
   453  
   454  	defer conn.Close()
   455  
   456  	if err := router.ProtectUDPConn(conn); err != nil {
   457  		return nil, fmt.Errorf("failed to protect udp connection: %w", err)
   458  	}
   459  
   460  	log.Debug().Msgf("Local socket: %s", conn.LocalAddr())
   461  
   462  	remoteAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", remoteIP, remotePort))
   463  	if err != nil {
   464  		return nil, fmt.Errorf("failed to resolve remote address: %w", err)
   465  	}
   466  
   467  	ctx1, cl := context.WithCancel(ctx)
   468  	go func() {
   469  		err := p.ping(ctx1, conn, remoteAddr, ttl)
   470  		if err != nil {
   471  			log.Warn().Err(err).Msg("Error while pinging")
   472  		}
   473  	}()
   474  
   475  	laddr := conn.LocalAddr().(*net.UDPAddr)
   476  	raddr, err := p.pingReceiver(ctx, conn)
   477  	cl()
   478  	if err != nil {
   479  		return nil, fmt.Errorf("ping receiver error: %w", err)
   480  	}
   481  	// need to dial same connection further
   482  	conn.Close()
   483  
   484  	newConn, err := net.DialUDP("udp4", laddr, raddr)
   485  	if err != nil {
   486  		return nil, err
   487  	}
   488  
   489  	if err := router.ProtectUDPConn(newConn); err != nil {
   490  		newConn.Close()
   491  		return nil, fmt.Errorf("failed to protect udp connection: %w", err)
   492  	}
   493  
   494  	return newConn, nil
   495  }