github.com/mysteriumnetwork/node@v0.0.0-20240516044423-365054f76801/mobile/mysterium/wireguard_connection_setup_test.go (about)

     1  /*
     2   * Copyright (C) 2020 The "MysteriumNetwork/node" Authors.
     3   *
     4   * This program is free software: you can redistribute it and/or modify
     5   * it under the terms of the GNU General Public License as published by
     6   * the Free Software Foundation, either version 3 of the License, or
     7   * (at your option) any later version.
     8   *
     9   * This program is distributed in the hope that it will be useful,
    10   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    11   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    12   * GNU General Public License for more details.
    13   *
    14   * You should have received a copy of the GNU General Public License
    15   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    16   */
    17  
    18  package mysterium
    19  
    20  import (
    21  	"context"
    22  	"encoding/json"
    23  	"errors"
    24  	"net"
    25  	"testing"
    26  	"time"
    27  
    28  	"github.com/stretchr/testify/assert"
    29  
    30  	"github.com/mysteriumnetwork/node/core/connection"
    31  	"github.com/mysteriumnetwork/node/core/connection/connectionstate"
    32  	"github.com/mysteriumnetwork/node/core/ip"
    33  	wg "github.com/mysteriumnetwork/node/services/wireguard"
    34  	"github.com/mysteriumnetwork/node/services/wireguard/wgcfg"
    35  )
    36  
    37  func TestConnectionStartStop(t *testing.T) {
    38  	conn := newConn(t)
    39  
    40  	// Start connection.
    41  	sessionConfig, _ := json.Marshal(newServiceConfig())
    42  	err := conn.Start(context.Background(), connection.ConnectOptions{
    43  		Params:        connection.ConnectParams{DNS: "1.2.3.4"},
    44  		SessionConfig: sessionConfig,
    45  	})
    46  
    47  	assert.NoError(t, err)
    48  	assert.Equal(t, connectionstate.Connecting, <-conn.State())
    49  	assert.Equal(t, connectionstate.Connected, <-conn.State())
    50  	stats, err := conn.Statistics()
    51  	assert.NoError(t, err)
    52  	assert.EqualValues(t, 10, stats.BytesSent)
    53  	assert.EqualValues(t, 11, stats.BytesReceived)
    54  
    55  	// Stop connection.
    56  	go func() {
    57  		conn.Stop()
    58  	}()
    59  	assert.NoError(t, err)
    60  }
    61  
    62  func TestConnectionStopAfterHandshakeError(t *testing.T) {
    63  	conn := newConn(t)
    64  	handshakeTimeoutErr := errors.New("handshake timeout")
    65  	conn.handshakeWaiter = &mockHandshakeWaiter{err: handshakeTimeoutErr}
    66  	sessionConfig, _ := json.Marshal(newServiceConfig())
    67  
    68  	err := conn.Start(context.Background(), connection.ConnectOptions{SessionConfig: sessionConfig})
    69  	assert.Error(t, handshakeTimeoutErr, err)
    70  	assert.Equal(t, connectionstate.Connecting, <-conn.State())
    71  	assert.Equal(t, connectionstate.Disconnecting, <-conn.State())
    72  	assert.Equal(t, connectionstate.NotConnected, <-conn.State())
    73  }
    74  
    75  func TestConnectionStopOnceAfterHandshakeErrorAndStopCall(t *testing.T) {
    76  	conn := newConn(t)
    77  	handshakeTimeoutErr := errors.New("handshake timeout")
    78  	conn.handshakeWaiter = &mockHandshakeWaiter{err: handshakeTimeoutErr}
    79  	sessionConfig, _ := json.Marshal(newServiceConfig())
    80  
    81  	err := conn.Start(context.Background(), connection.ConnectOptions{SessionConfig: sessionConfig})
    82  
    83  	stopCh := make(chan struct{})
    84  	go func() {
    85  		conn.Stop()
    86  		stopCh <- struct{}{}
    87  	}()
    88  	<-stopCh
    89  
    90  	assert.Error(t, handshakeTimeoutErr, err)
    91  	assert.Equal(t, connectionstate.Connecting, <-conn.State())
    92  	assert.Equal(t, connectionstate.Disconnecting, <-conn.State())
    93  	assert.Equal(t, connectionstate.NotConnected, <-conn.State())
    94  }
    95  
    96  func newConn(t *testing.T) *wireguardConnection {
    97  	opts := wireGuardOptions{
    98  		statsUpdateInterval: 1 * time.Millisecond,
    99  	}
   100  	conn, err := NewWireGuardConnection(opts, &mockWireGuardDevice{}, ip.NewResolverMock("172.44.1.12"), &mockHandshakeWaiter{})
   101  	assert.NoError(t, err)
   102  	return conn.(*wireguardConnection)
   103  }
   104  
   105  func newServiceConfig() wg.ServiceConfig {
   106  	endpoint, _ := net.ResolveUDPAddr("udp4", "127.0.0.1:51001")
   107  	return wg.ServiceConfig{
   108  		LocalPort:  51000,
   109  		RemotePort: 51001,
   110  		Provider: struct {
   111  			PublicKey string
   112  			Endpoint  net.UDPAddr
   113  		}{
   114  			PublicKey: "wg1",
   115  			Endpoint:  *endpoint,
   116  		},
   117  		Consumer: struct {
   118  			IPAddress net.IPNet
   119  			DNSIPs    string
   120  		}{
   121  			IPAddress: net.IPNet{
   122  				IP:   net.IPv4(127, 0, 0, 1),
   123  				Mask: net.IPv4Mask(255, 255, 255, 128),
   124  			},
   125  			DNSIPs: "128.0.0.1",
   126  		},
   127  	}
   128  }
   129  
   130  type mockWireGuardDevice struct{}
   131  
   132  func (m mockWireGuardDevice) Start(_ string, _ wg.ServiceConfig, _ *net.UDPConn, _ connection.DNSOption) error {
   133  	return nil
   134  }
   135  
   136  func (m mockWireGuardDevice) Stop() {
   137  }
   138  
   139  func (m mockWireGuardDevice) Stats() (wgcfg.Stats, error) {
   140  	return wgcfg.Stats{BytesSent: 10, BytesReceived: 11}, nil
   141  }
   142  
   143  type mockHandshakeWaiter struct {
   144  	err error
   145  }
   146  
   147  func (m *mockHandshakeWaiter) Wait(ctx context.Context, statsFetch func() (wgcfg.Stats, error), timeout time.Duration, stop <-chan struct{}) error {
   148  	return m.err
   149  }