github.com/mysteriumnetwork/node@v0.0.0-20240516044423-365054f76801/services/wireguard/connection/handshake.go (about)

     1  /*
     2   * Copyright (C) 2020 The "MysteriumNetwork/node" Authors.
     3   *
     4   * This program is free software: you can redistribute it and/or modify
     5   * it under the terms of the GNU General Public License as published by
     6   * the Free Software Foundation, either version 3 of the License, or
     7   * (at your option) any later version.
     8   *
     9   * This program is distributed in the hope that it will be useful,
    10   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    11   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    12   * GNU General Public License for more details.
    13   *
    14   * You should have received a copy of the GNU General Public License
    15   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    16   */
    17  
    18  package connection
    19  
    20  import (
    21  	"context"
    22  	"errors"
    23  	"fmt"
    24  	"net"
    25  	"time"
    26  
    27  	"github.com/mysteriumnetwork/node/services/wireguard/wgcfg"
    28  )
    29  
    30  // HandshakeWaiter waits for handshake.
    31  type HandshakeWaiter interface {
    32  	// Wait waits until WireGuard does initial handshake.
    33  	Wait(ctx context.Context, statsFetch func() (wgcfg.Stats, error), timeout time.Duration, stop <-chan struct{}) error
    34  }
    35  
    36  // NewHandshakeWaiter returns handshake waiter instance.
    37  func NewHandshakeWaiter() HandshakeWaiter {
    38  	return &handshakeWaiter{}
    39  }
    40  
    41  type handshakeWaiter struct{}
    42  
    43  func (h *handshakeWaiter) Wait(ctx context.Context, statsFetch func() (wgcfg.Stats, error), timeout time.Duration, stop <-chan struct{}) error {
    44  	// We need to send any packet to initialize handshake process.
    45  	handshakePingConn, err := net.DialTimeout("tcp", "8.8.8.8:53", 100*time.Millisecond)
    46  	if err == nil {
    47  		defer handshakePingConn.Close()
    48  	}
    49  	timeoutCh := time.After(timeout)
    50  	for {
    51  		select {
    52  		case <-time.After(100 * time.Millisecond):
    53  			stats, err := statsFetch()
    54  			if err != nil {
    55  				return fmt.Errorf("failed to fetch stats: %w", err)
    56  			}
    57  			if !stats.LastHandshake.IsZero() {
    58  				return nil
    59  			}
    60  		case <-timeoutCh:
    61  			return errors.New("failed to receive initial handshake")
    62  		case <-stop:
    63  			return errors.New("stop received")
    64  		case <-ctx.Done():
    65  			return ctx.Err()
    66  		}
    67  	}
    68  }