github.com/cawidtu/notwireguard-go/device@v0.0.0-20230523131112-68e8e5ce9cdf/device_test.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package device 7 8 import ( 9 "bytes" 10 "encoding/hex" 11 "fmt" 12 "io" 13 "math/rand" 14 "net/netip" 15 "runtime" 16 "runtime/pprof" 17 "sync" 18 "sync/atomic" 19 "testing" 20 "time" 21 22 "github.com/cawidtu/notwireguard-go/conn" 23 "github.com/cawidtu/notwireguard-go/conn/bindtest" 24 "github.com/cawidtu/notwireguard-go/tun/tuntest" 25 ) 26 27 // uapiCfg returns a string that contains cfg formatted use with IpcSet. 28 // cfg is a series of alternating key/value strings. 29 // uapiCfg exists because editors and humans like to insert 30 // whitespace into configs, which can cause failures, some of which are silent. 31 // For example, a leading blank newline causes the remainder 32 // of the config to be silently ignored. 33 func uapiCfg(cfg ...string) string { 34 if len(cfg)%2 != 0 { 35 panic("odd number of args to uapiReader") 36 } 37 buf := new(bytes.Buffer) 38 for i, s := range cfg { 39 buf.WriteString(s) 40 sep := byte('\n') 41 if i%2 == 0 { 42 sep = '=' 43 } 44 buf.WriteByte(sep) 45 } 46 return buf.String() 47 } 48 49 // genConfigs generates a pair of configs that connect to each other. 50 // The configs use distinct, probably-usable ports. 51 func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) { 52 var key1, key2 NoisePrivateKey 53 _, err := rand.Read(key1[:]) 54 if err != nil { 55 tb.Errorf("unable to generate private key random bytes: %v", err) 56 } 57 _, err = rand.Read(key2[:]) 58 if err != nil { 59 tb.Errorf("unable to generate private key random bytes: %v", err) 60 } 61 pub1, pub2 := key1.publicKey(), key2.publicKey() 62 63 cfgs[0] = uapiCfg( 64 "private_key", hex.EncodeToString(key1[:]), 65 "listen_port", "0", 66 "replace_peers", "true", 67 "public_key", hex.EncodeToString(pub2[:]), 68 "protocol_version", "1", 69 "replace_allowed_ips", "true", 70 "allowed_ip", "1.0.0.2/32", 71 ) 72 endpointCfgs[0] = uapiCfg( 73 "public_key", hex.EncodeToString(pub2[:]), 74 "endpoint", "127.0.0.1:%d", 75 ) 76 cfgs[1] = uapiCfg( 77 "private_key", hex.EncodeToString(key2[:]), 78 "listen_port", "0", 79 "replace_peers", "true", 80 "public_key", hex.EncodeToString(pub1[:]), 81 "protocol_version", "1", 82 "replace_allowed_ips", "true", 83 "allowed_ip", "1.0.0.1/32", 84 ) 85 endpointCfgs[1] = uapiCfg( 86 "public_key", hex.EncodeToString(pub1[:]), 87 "endpoint", "127.0.0.1:%d", 88 ) 89 return 90 } 91 92 // A testPair is a pair of testPeers. 93 type testPair [2]testPeer 94 95 // A testPeer is a peer used for testing. 96 type testPeer struct { 97 tun *tuntest.ChannelTUN 98 dev *Device 99 ip netip.Addr 100 } 101 102 type SendDirection bool 103 104 const ( 105 Ping SendDirection = true 106 Pong SendDirection = false 107 ) 108 109 func (d SendDirection) String() string { 110 if d == Ping { 111 return "ping" 112 } 113 return "pong" 114 } 115 116 func (pair *testPair) Send(tb testing.TB, ping SendDirection, done chan struct{}) { 117 tb.Helper() 118 p0, p1 := pair[0], pair[1] 119 if !ping { 120 // pong is the new ping 121 p0, p1 = p1, p0 122 } 123 msg := tuntest.Ping(p0.ip, p1.ip) 124 p1.tun.Outbound <- msg 125 timer := time.NewTimer(5 * time.Second) 126 defer timer.Stop() 127 var err error 128 select { 129 case msgRecv := <-p0.tun.Inbound: 130 if !bytes.Equal(msg, msgRecv) { 131 err = fmt.Errorf("%s did not transit correctly", ping) 132 } 133 case <-timer.C: 134 err = fmt.Errorf("%s did not transit", ping) 135 case <-done: 136 } 137 if err != nil { 138 // The error may have occurred because the test is done. 139 select { 140 case <-done: 141 return 142 default: 143 } 144 // Real error. 145 tb.Error(err) 146 } 147 } 148 149 // genTestPair creates a testPair. 150 func genTestPair(tb testing.TB, realSocket bool) (pair testPair) { 151 cfg, endpointCfg := genConfigs(tb) 152 var binds [2]conn.Bind 153 if realSocket { 154 binds[0], binds[1] = conn.NewDefaultBind(), conn.NewDefaultBind() 155 } else { 156 binds = bindtest.NewChannelBinds() 157 } 158 // Bring up a ChannelTun for each config. 159 for i := range pair { 160 p := &pair[i] 161 p.tun = tuntest.NewChannelTUN() 162 p.ip = netip.AddrFrom4([4]byte{1, 0, 0, byte(i + 1)}) 163 level := LogLevelVerbose 164 if _, ok := tb.(*testing.B); ok && !testing.Verbose() { 165 level = LogLevelError 166 } 167 p.dev = NewDevice(p.tun.TUN(), binds[i], NewLogger(level, fmt.Sprintf("dev%d: ", i))) 168 if err := p.dev.IpcSet(cfg[i]); err != nil { 169 tb.Errorf("failed to configure device %d: %v", i, err) 170 p.dev.Close() 171 continue 172 } 173 if err := p.dev.Up(); err != nil { 174 tb.Errorf("failed to bring up device %d: %v", i, err) 175 p.dev.Close() 176 continue 177 } 178 endpointCfg[i^1] = fmt.Sprintf(endpointCfg[i^1], p.dev.net.port) 179 } 180 for i := range pair { 181 p := &pair[i] 182 if err := p.dev.IpcSet(endpointCfg[i]); err != nil { 183 tb.Errorf("failed to configure device endpoint %d: %v", i, err) 184 p.dev.Close() 185 continue 186 } 187 // The device is ready. Close it when the test completes. 188 tb.Cleanup(p.dev.Close) 189 } 190 return 191 } 192 193 func TestTwoDevicePing(t *testing.T) { 194 goroutineLeakCheck(t) 195 pair := genTestPair(t, true) 196 t.Run("ping 1.0.0.1", func(t *testing.T) { 197 pair.Send(t, Ping, nil) 198 }) 199 t.Run("ping 1.0.0.2", func(t *testing.T) { 200 pair.Send(t, Pong, nil) 201 }) 202 } 203 204 func TestUpDown(t *testing.T) { 205 goroutineLeakCheck(t) 206 const itrials = 50 207 const otrials = 10 208 209 for n := 0; n < otrials; n++ { 210 pair := genTestPair(t, false) 211 for i := range pair { 212 for k := range pair[i].dev.peers.keyMap { 213 pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:]))) 214 } 215 } 216 var wg sync.WaitGroup 217 wg.Add(len(pair)) 218 for i := range pair { 219 go func(d *Device) { 220 defer wg.Done() 221 for i := 0; i < itrials; i++ { 222 if err := d.Up(); err != nil { 223 t.Errorf("failed up bring up device: %v", err) 224 } 225 time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1))))) 226 if err := d.Down(); err != nil { 227 t.Errorf("failed to bring down device: %v", err) 228 } 229 time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1))))) 230 } 231 }(pair[i].dev) 232 } 233 wg.Wait() 234 for i := range pair { 235 pair[i].dev.Up() 236 pair[i].dev.Close() 237 } 238 } 239 } 240 241 // TestConcurrencySafety does other things concurrently with tunnel use. 242 // It is intended to be used with the race detector to catch data races. 243 func TestConcurrencySafety(t *testing.T) { 244 pair := genTestPair(t, true) 245 done := make(chan struct{}) 246 247 const warmupIters = 10 248 var warmup sync.WaitGroup 249 warmup.Add(warmupIters) 250 go func() { 251 // Send data continuously back and forth until we're done. 252 // Note that we may continue to attempt to send data 253 // even after done is closed. 254 i := warmupIters 255 for ping := Ping; ; ping = !ping { 256 pair.Send(t, ping, done) 257 select { 258 case <-done: 259 return 260 default: 261 } 262 if i > 0 { 263 warmup.Done() 264 i-- 265 } 266 } 267 }() 268 warmup.Wait() 269 270 applyCfg := func(cfg string) { 271 err := pair[0].dev.IpcSet(cfg) 272 if err != nil { 273 t.Fatal(err) 274 } 275 } 276 277 // Change persistent_keepalive_interval concurrently with tunnel use. 278 t.Run("persistentKeepaliveInterval", func(t *testing.T) { 279 var pub NoisePublicKey 280 for key := range pair[0].dev.peers.keyMap { 281 pub = key 282 break 283 } 284 cfg := uapiCfg( 285 "public_key", hex.EncodeToString(pub[:]), 286 "persistent_keepalive_interval", "1", 287 ) 288 for i := 0; i < 1000; i++ { 289 applyCfg(cfg) 290 } 291 }) 292 293 // Change private keys concurrently with tunnel use. 294 t.Run("privateKey", func(t *testing.T) { 295 bad := uapiCfg("private_key", "7777777777777777777777777777777777777777777777777777777777777777") 296 good := uapiCfg("private_key", hex.EncodeToString(pair[0].dev.staticIdentity.privateKey[:])) 297 // Set iters to a large number like 1000 to flush out data races quickly. 298 // Don't leave it large. That can cause logical races 299 // in which the handshake is interleaved with key changes 300 // such that the private key appears to be unchanging but 301 // other state gets reset, which can cause handshake failures like 302 // "Received packet with invalid mac1". 303 const iters = 1 304 for i := 0; i < iters; i++ { 305 applyCfg(bad) 306 applyCfg(good) 307 } 308 }) 309 310 close(done) 311 } 312 313 func BenchmarkLatency(b *testing.B) { 314 pair := genTestPair(b, true) 315 316 // Establish a connection. 317 pair.Send(b, Ping, nil) 318 pair.Send(b, Pong, nil) 319 320 b.ResetTimer() 321 for i := 0; i < b.N; i++ { 322 pair.Send(b, Ping, nil) 323 pair.Send(b, Pong, nil) 324 } 325 } 326 327 func BenchmarkThroughput(b *testing.B) { 328 pair := genTestPair(b, true) 329 330 // Establish a connection. 331 pair.Send(b, Ping, nil) 332 pair.Send(b, Pong, nil) 333 334 // Measure how long it takes to receive b.N packets, 335 // starting when we receive the first packet. 336 var recv uint64 337 var elapsed time.Duration 338 var wg sync.WaitGroup 339 wg.Add(1) 340 go func() { 341 defer wg.Done() 342 var start time.Time 343 for { 344 <-pair[0].tun.Inbound 345 new := atomic.AddUint64(&recv, 1) 346 if new == 1 { 347 start = time.Now() 348 } 349 // Careful! Don't change this to else if; b.N can be equal to 1. 350 if new == uint64(b.N) { 351 elapsed = time.Since(start) 352 return 353 } 354 } 355 }() 356 357 // Send packets as fast as we can until we've received enough. 358 ping := tuntest.Ping(pair[0].ip, pair[1].ip) 359 pingc := pair[1].tun.Outbound 360 var sent uint64 361 for atomic.LoadUint64(&recv) != uint64(b.N) { 362 sent++ 363 pingc <- ping 364 } 365 wg.Wait() 366 367 b.ReportMetric(float64(elapsed)/float64(b.N), "ns/op") 368 b.ReportMetric(1-float64(b.N)/float64(sent), "packet-loss") 369 } 370 371 func BenchmarkUAPIGet(b *testing.B) { 372 pair := genTestPair(b, true) 373 pair.Send(b, Ping, nil) 374 pair.Send(b, Pong, nil) 375 b.ReportAllocs() 376 b.ResetTimer() 377 for i := 0; i < b.N; i++ { 378 pair[0].dev.IpcGetOperation(io.Discard) 379 } 380 } 381 382 func goroutineLeakCheck(t *testing.T) { 383 goroutines := func() (int, []byte) { 384 p := pprof.Lookup("goroutine") 385 b := new(bytes.Buffer) 386 p.WriteTo(b, 1) 387 return p.Count(), b.Bytes() 388 } 389 390 startGoroutines, startStacks := goroutines() 391 t.Cleanup(func() { 392 if t.Failed() { 393 return 394 } 395 // Give goroutines time to exit, if they need it. 396 for i := 0; i < 10000; i++ { 397 if runtime.NumGoroutine() <= startGoroutines { 398 return 399 } 400 time.Sleep(1 * time.Millisecond) 401 } 402 endGoroutines, endStacks := goroutines() 403 t.Logf("starting stacks:\n%s\n", startStacks) 404 t.Logf("ending stacks:\n%s\n", endStacks) 405 t.Fatalf("expected %d goroutines, got %d, leak?", startGoroutines, endGoroutines) 406 }) 407 }