github.com/mysteriumnetwork/node@v0.0.0-20240516044423-365054f76801/mobile/mysterium/wireguard_connection_setup_test.go (about) 1 /* 2 * Copyright (C) 2020 The "MysteriumNetwork/node" Authors. 3 * 4 * This program is free software: you can redistribute it and/or modify 5 * it under the terms of the GNU General Public License as published by 6 * the Free Software Foundation, either version 3 of the License, or 7 * (at your option) any later version. 8 * 9 * This program is distributed in the hope that it will be useful, 10 * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 * GNU General Public License for more details. 13 * 14 * You should have received a copy of the GNU General Public License 15 * along with this program. If not, see <http://www.gnu.org/licenses/>. 16 */ 17 18 package mysterium 19 20 import ( 21 "context" 22 "encoding/json" 23 "errors" 24 "net" 25 "testing" 26 "time" 27 28 "github.com/stretchr/testify/assert" 29 30 "github.com/mysteriumnetwork/node/core/connection" 31 "github.com/mysteriumnetwork/node/core/connection/connectionstate" 32 "github.com/mysteriumnetwork/node/core/ip" 33 wg "github.com/mysteriumnetwork/node/services/wireguard" 34 "github.com/mysteriumnetwork/node/services/wireguard/wgcfg" 35 ) 36 37 func TestConnectionStartStop(t *testing.T) { 38 conn := newConn(t) 39 40 // Start connection. 41 sessionConfig, _ := json.Marshal(newServiceConfig()) 42 err := conn.Start(context.Background(), connection.ConnectOptions{ 43 Params: connection.ConnectParams{DNS: "1.2.3.4"}, 44 SessionConfig: sessionConfig, 45 }) 46 47 assert.NoError(t, err) 48 assert.Equal(t, connectionstate.Connecting, <-conn.State()) 49 assert.Equal(t, connectionstate.Connected, <-conn.State()) 50 stats, err := conn.Statistics() 51 assert.NoError(t, err) 52 assert.EqualValues(t, 10, stats.BytesSent) 53 assert.EqualValues(t, 11, stats.BytesReceived) 54 55 // Stop connection. 56 go func() { 57 conn.Stop() 58 }() 59 assert.NoError(t, err) 60 } 61 62 func TestConnectionStopAfterHandshakeError(t *testing.T) { 63 conn := newConn(t) 64 handshakeTimeoutErr := errors.New("handshake timeout") 65 conn.handshakeWaiter = &mockHandshakeWaiter{err: handshakeTimeoutErr} 66 sessionConfig, _ := json.Marshal(newServiceConfig()) 67 68 err := conn.Start(context.Background(), connection.ConnectOptions{SessionConfig: sessionConfig}) 69 assert.Error(t, handshakeTimeoutErr, err) 70 assert.Equal(t, connectionstate.Connecting, <-conn.State()) 71 assert.Equal(t, connectionstate.Disconnecting, <-conn.State()) 72 assert.Equal(t, connectionstate.NotConnected, <-conn.State()) 73 } 74 75 func TestConnectionStopOnceAfterHandshakeErrorAndStopCall(t *testing.T) { 76 conn := newConn(t) 77 handshakeTimeoutErr := errors.New("handshake timeout") 78 conn.handshakeWaiter = &mockHandshakeWaiter{err: handshakeTimeoutErr} 79 sessionConfig, _ := json.Marshal(newServiceConfig()) 80 81 err := conn.Start(context.Background(), connection.ConnectOptions{SessionConfig: sessionConfig}) 82 83 stopCh := make(chan struct{}) 84 go func() { 85 conn.Stop() 86 stopCh <- struct{}{} 87 }() 88 <-stopCh 89 90 assert.Error(t, handshakeTimeoutErr, err) 91 assert.Equal(t, connectionstate.Connecting, <-conn.State()) 92 assert.Equal(t, connectionstate.Disconnecting, <-conn.State()) 93 assert.Equal(t, connectionstate.NotConnected, <-conn.State()) 94 } 95 96 func newConn(t *testing.T) *wireguardConnection { 97 opts := wireGuardOptions{ 98 statsUpdateInterval: 1 * time.Millisecond, 99 } 100 conn, err := NewWireGuardConnection(opts, &mockWireGuardDevice{}, ip.NewResolverMock("172.44.1.12"), &mockHandshakeWaiter{}) 101 assert.NoError(t, err) 102 return conn.(*wireguardConnection) 103 } 104 105 func newServiceConfig() wg.ServiceConfig { 106 endpoint, _ := net.ResolveUDPAddr("udp4", "127.0.0.1:51001") 107 return wg.ServiceConfig{ 108 LocalPort: 51000, 109 RemotePort: 51001, 110 Provider: struct { 111 PublicKey string 112 Endpoint net.UDPAddr 113 }{ 114 PublicKey: "wg1", 115 Endpoint: *endpoint, 116 }, 117 Consumer: struct { 118 IPAddress net.IPNet 119 DNSIPs string 120 }{ 121 IPAddress: net.IPNet{ 122 IP: net.IPv4(127, 0, 0, 1), 123 Mask: net.IPv4Mask(255, 255, 255, 128), 124 }, 125 DNSIPs: "128.0.0.1", 126 }, 127 } 128 } 129 130 type mockWireGuardDevice struct{} 131 132 func (m mockWireGuardDevice) Start(_ string, _ wg.ServiceConfig, _ *net.UDPConn, _ connection.DNSOption) error { 133 return nil 134 } 135 136 func (m mockWireGuardDevice) Stop() { 137 } 138 139 func (m mockWireGuardDevice) Stats() (wgcfg.Stats, error) { 140 return wgcfg.Stats{BytesSent: 10, BytesReceived: 11}, nil 141 } 142 143 type mockHandshakeWaiter struct { 144 err error 145 } 146 147 func (m *mockHandshakeWaiter) Wait(ctx context.Context, statsFetch func() (wgcfg.Stats, error), timeout time.Duration, stop <-chan struct{}) error { 148 return m.err 149 }