github.com/amnezia-vpn/amneziawg-go@v0.2.8/device/device_test.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package device 7 8 import ( 9 "bytes" 10 "encoding/hex" 11 "fmt" 12 "io" 13 "math/rand" 14 "net/netip" 15 "os" 16 "runtime" 17 "runtime/pprof" 18 "sync" 19 "sync/atomic" 20 "testing" 21 "time" 22 23 "github.com/amnezia-vpn/amneziawg-go/conn" 24 "github.com/amnezia-vpn/amneziawg-go/conn/bindtest" 25 "github.com/amnezia-vpn/amneziawg-go/tun" 26 "github.com/amnezia-vpn/amneziawg-go/tun/tuntest" 27 ) 28 29 // uapiCfg returns a string that contains cfg formatted use with IpcSet. 30 // cfg is a series of alternating key/value strings. 31 // uapiCfg exists because editors and humans like to insert 32 // whitespace into configs, which can cause failures, some of which are silent. 33 // For example, a leading blank newline causes the remainder 34 // of the config to be silently ignored. 35 func uapiCfg(cfg ...string) string { 36 if len(cfg)%2 != 0 { 37 panic("odd number of args to uapiReader") 38 } 39 buf := new(bytes.Buffer) 40 for i, s := range cfg { 41 buf.WriteString(s) 42 sep := byte('\n') 43 if i%2 == 0 { 44 sep = '=' 45 } 46 buf.WriteByte(sep) 47 } 48 return buf.String() 49 } 50 51 // genConfigs generates a pair of configs that connect to each other. 52 // The configs use distinct, probably-usable ports. 
func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
	tb.Helper()
	return genConfigPairWithDeviceParams(tb)
}

// genASecurityConfigs is like genConfigs, but additionally sets the same
// AmneziaWG device parameters (jc, jmin, jmax, s1, s2, h1-h4) on both devices.
func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
	tb.Helper()
	return genConfigPairWithDeviceParams(tb,
		"jc", "5",
		"jmin", "500",
		"jmax", "501",
		"s1", "30",
		"s2", "40",
		"h1", "123456",
		"h2", "67543",
		"h4", "32345",
		"h3", "123123",
	)
}

// genConfigPairWithDeviceParams builds IpcSet configs for two devices that
// peer with each other. Device i gets private key i, a single peer (the other
// device) and allows that peer's tunnel address 1.0.0.(peer+1)/32.
// deviceParams are extra device-level key/value pairs; they are inserted after
// replace_peers and before the peer section (public_key ...).
// Each endpointCfgs[i] names device i's peer and carries a "%d" placeholder
// for that peer's listen port.
func genConfigPairWithDeviceParams(tb testing.TB, deviceParams ...string) (cfgs, endpointCfgs [2]string) {
	tb.Helper()
	var keys [2]NoisePrivateKey
	for i := range keys {
		if _, err := rand.Read(keys[i][:]); err != nil {
			tb.Fatalf("unable to generate private key random bytes: %v", err)
		}
	}
	pubs := [2]NoisePublicKey{keys[0].publicKey(), keys[1].publicKey()}

	for i := range keys {
		peer := i ^ 1
		peerPub := hex.EncodeToString(pubs[peer][:])

		cfg := []string{
			"private_key", hex.EncodeToString(keys[i][:]),
			"listen_port", "0",
			"replace_peers", "true",
		}
		cfg = append(cfg, deviceParams...)
		cfg = append(cfg,
			"public_key", peerPub,
			"protocol_version", "1",
			"replace_allowed_ips", "true",
			"allowed_ip", fmt.Sprintf("1.0.0.%d/32", peer+1),
		)
		cfgs[i] = uapiCfg(cfg...)
		endpointCfgs[i] = uapiCfg(
			"public_key", peerPub,
			"endpoint", "127.0.0.1:%d",
		)
	}
	return cfgs, endpointCfgs
}

// A testPair is a pair of testPeers.
type testPair [2]testPeer

// A testPeer is a peer used for testing.
type testPeer struct {
	tun *tuntest.ChannelTUN
	dev *Device
	ip  netip.Addr
}

// SendDirection selects which way testPair.Send transmits a packet.
type SendDirection bool

const (
	Ping SendDirection = true  // packet enters pair[1]'s TUN, exits pair[0]'s TUN
	Pong SendDirection = false // packet enters pair[0]'s TUN, exits pair[1]'s TUN
)

func (d SendDirection) String() string {
	if d == Ping {
		return "ping"
	}
	return "pong"
}

// Send builds a packet with tuntest.Ping, writes it into one peer's TUN, and
// waits up to five seconds for it to come out, byte-identical, of the other
// peer's TUN. With Ping the packet travels from pair[1] to pair[0]; with Pong
// the direction is reversed.
//
// done may be nil. Once done is closed, Send stops waiting and suppresses any
// failure, because the devices may be in the middle of being torn down.
// Failures are reported with tb.Error (not Fatal), so Send may be called from
// a goroutine other than the one running the test.
func (pair *testPair) Send(
	tb testing.TB,
	ping SendDirection,
	done chan struct{},
) {
	tb.Helper()
	dst, src := pair[0], pair[1]
	if !ping {
		// pong is the new ping
		dst, src = src, dst
	}
	msg := tuntest.Ping(dst.ip, src.ip)
	src.tun.Outbound <- msg

	timer := time.NewTimer(5 * time.Second)
	defer timer.Stop()

	var err error
	select {
	case got := <-dst.tun.Inbound:
		if !bytes.Equal(msg, got) {
			err = fmt.Errorf("%s did not transit correctly", ping)
		}
	case <-timer.C:
		err = fmt.Errorf("%s did not transit", ping)
	case <-done:
	}
	if err == nil {
		return
	}
	// The error may have occurred because the test is done.
	select {
	case <-done:
		return
	default:
	}
	// Real error.
	tb.Error(err)
}

// genTestPair creates a testPair of two devices configured to talk to each other.
// If realSocket is true the devices use conn.NewDefaultBind; otherwise they use
// the in-memory binds from bindtest.NewChannelBinds. If withASecurity is true
// the configs come from genASecurityConfigs, otherwise from genConfigs.
//
// Each device is closed automatically when the test completes. Any setup
// failure aborts the test via tb.Fatalf, so genTestPair must be called from
// the goroutine running the test.
func genTestPair(
	tb testing.TB,
	realSocket, withASecurity bool,
) (pair testPair) {
	tb.Helper()
	var cfg, endpointCfg [2]string
	if withASecurity {
		cfg, endpointCfg = genASecurityConfigs(tb)
	} else {
		cfg, endpointCfg = genConfigs(tb)
	}
	var binds [2]conn.Bind
	if realSocket {
		binds[0], binds[1] = conn.NewDefaultBind(), conn.NewDefaultBind()
	} else {
		binds = bindtest.NewChannelBinds()
	}
	// Keep benchmark output quiet unless running verbosely.
	level := LogLevelVerbose
	if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
		level = LogLevelError
	}
	// Bring up a ChannelTUN for each config.
	for i := range pair {
		p := &pair[i]
		p.tun = tuntest.NewChannelTUN()
		p.ip = netip.AddrFrom4([4]byte{1, 0, 0, byte(i + 1)})
		p.dev = NewDevice(p.tun.TUN(), binds[i], NewLogger(level, fmt.Sprintf("dev%d: ", i)))
		// Register the close right away so the device is released even if a
		// later setup step aborts the test.
		tb.Cleanup(p.dev.Close)
		if err := p.dev.IpcSet(cfg[i]); err != nil {
			tb.Fatalf("failed to configure device %d: %v", i, err)
		}
		if err := p.dev.Up(); err != nil {
			tb.Fatalf("failed to bring up device %d: %v", i, err)
		}
		// The other device's endpoint config needs this device's listen port.
		endpointCfg[i^1] = fmt.Sprintf(endpointCfg[i^1], p.dev.net.port)
	}
	for i := range pair {
		if err := pair[i].dev.IpcSet(endpointCfg[i]); err != nil {
			tb.Fatalf("failed to configure device endpoint %d: %v", i, err)
		}
	}
	return pair
}

// runTwoDevicePing brings up a pair of devices over real sockets and checks
// that a packet transits in each direction, then checks for goroutine leaks.
func runTwoDevicePing(t *testing.T, withASecurity bool) {
	goroutineLeakCheck(t)
	pair := genTestPair(t, true, withASecurity)
	t.Run("ping 1.0.0.1", func(t *testing.T) {
		pair.Send(t, Ping, nil)
	})
	t.Run("ping 1.0.0.2", func(t *testing.T) {
		pair.Send(t, Pong, nil)
	})
}

func TestTwoDevicePing(t *testing.T) {
	runTwoDevicePing(t, false)
}

func TestTwoDevicePingASecurity(t *testing.T) {
	runTwoDevicePing(t, true)
}

// TestUpDown creates otrials device pairs (with in-memory binds). For each
// pair it sets persistent_keepalive_interval=1 on every peer, then brings both
// devices up and down itrials times concurrently with short random pauses.
// Finally it brings each device up once more and closes it.
func TestUpDown(t *testing.T) {
	goroutineLeakCheck(t)
	const itrials = 50
	const otrials = 10

	// randomPause sleeps for a random duration below 0xffff nanoseconds.
	randomPause := func() {
		time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1)))))
	}

	for n := 0; n < otrials; n++ {
		pair := genTestPair(t, false, false)
		for i := range pair {
			for k := range pair[i].dev.peers.keyMap {
				cfg := uapiCfg(
					"public_key", hex.EncodeToString(k[:]),
					"persistent_keepalive_interval", "1",
				)
				if err := pair[i].dev.IpcSet(cfg); err != nil {
					t.Errorf("failed to set persistent keepalive on device %d: %v", i, err)
				}
			}
		}
		var wg sync.WaitGroup
		wg.Add(len(pair))
		for i := range pair {
			go func(d *Device) {
				defer wg.Done()
				for trial := 0; trial < itrials; trial++ {
					if err := d.Up(); err != nil {
						t.Errorf("failed to bring up device: %v", err)
					}
					randomPause()
					if err := d.Down(); err != nil {
						t.Errorf("failed to bring down device: %v", err)
					}
					randomPause()
				}
			}(pair[i].dev)
		}
		wg.Wait()
		for i := range pair {
			if err := pair[i].dev.Up(); err != nil {
				t.Errorf("failed to bring up device %d before closing: %v", i, err)
			}
			pair[i].dev.Close()
		}
	}
}

// TestConcurrencySafety reconfigures and pokes at the devices (keepalive
// interval, private key, bind updates, keepalive sends) while packets flow
// through the tunnel. It is intended to be run with the race detector
// (go test -race) to catch data races.
func TestConcurrencySafety(t *testing.T) {
	pair := genTestPair(t, true, false)
	done := make(chan struct{})

	const warmupIters = 10
	var warmup sync.WaitGroup
	warmup.Add(warmupIters)

	// sender tracks the traffic goroutine so it can be joined before the
	// test returns instead of outliving it.
	var sender sync.WaitGroup
	sender.Add(1)
	go func() {
		defer sender.Done()
		// Send data continuously back and forth until we're done.
		// Note that we may continue to attempt to send data
		// even after done is closed.
		remaining := warmupIters
		for ping := Ping; ; ping = !ping {
			pair.Send(t, ping, done)
			select {
			case <-done:
				return
			default:
			}
			if remaining > 0 {
				warmup.Done()
				remaining--
			}
		}
	}()
	warmup.Wait()

	// applyCfg applies cfg to pair[0]'s device and fails the given test on
	// error. t is passed explicitly because t.Fatal must be called from the
	// goroutine running that test, and each t.Run subtest has its own goroutine.
	applyCfg := func(t *testing.T, cfg string) {
		t.Helper()
		if err := pair[0].dev.IpcSet(cfg); err != nil {
			t.Fatal(err)
		}
	}

	// Change persistent_keepalive_interval concurrently with tunnel use.
	t.Run("persistentKeepaliveInterval", func(t *testing.T) {
		var pub NoisePublicKey
		for key := range pair[0].dev.peers.keyMap {
			pub = key
			break
		}
		cfg := uapiCfg(
			"public_key", hex.EncodeToString(pub[:]),
			"persistent_keepalive_interval", "1",
		)
		for i := 0; i < 1000; i++ {
			applyCfg(t, cfg)
		}
	})

	// Change private keys concurrently with tunnel use.
	t.Run("privateKey", func(t *testing.T) {
		bad := uapiCfg("private_key", "7777777777777777777777777777777777777777777777777777777777777777")
		good := uapiCfg("private_key", hex.EncodeToString(pair[0].dev.staticIdentity.privateKey[:]))
		// Set iters to a large number like 1000 to flush out data races quickly.
		// Don't leave it large. That can cause logical races
		// in which the handshake is interleaved with key changes
		// such that the private key appears to be unchanging but
		// other state gets reset, which can cause handshake failures like
		// "Received packet with invalid mac1".
		const iters = 1
		for i := 0; i < iters; i++ {
			applyCfg(t, bad)
			applyCfg(t, good)
		}
	})

	// Perform bind updates and keepalive sends concurrently with tunnel use.
	t.Run("bindUpdate and keepalive", func(t *testing.T) {
		const iters = 10
		for i := 0; i < iters; i++ {
			for _, p := range pair {
				p.dev.BindUpdate()
				p.dev.SendKeepalivesToPeersWithCurrentKeypair()
			}
		}
	})

	close(done)
	// Join the traffic goroutine. The devices are still running at this
	// point (they are closed by cleanup), so any in-flight Send completes
	// and then observes done.
	sender.Wait()
}

// BenchmarkLatency measures the round trip of one Ping followed by one Pong
// through an established real-socket tunnel.
func BenchmarkLatency(b *testing.B) {
	pair := genTestPair(b, true, false)

	// Establish a connection.
	pair.Send(b, Ping, nil)
	pair.Send(b, Pong, nil)

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		pair.Send(b, Ping, nil)
		pair.Send(b, Pong, nil)
	}
}

// BenchmarkThroughput floods pair[1]'s TUN with the same ping packet and
// measures how long pair[0] takes to receive b.N of them. It reports ns/op
// and packet-loss (the fraction of sent packets not counted as received).
func BenchmarkThroughput(b *testing.B) {
	pair := genTestPair(b, true, false)

	// Establish a connection.
	pair.Send(b, Ping, nil)
	pair.Send(b, Pong, nil)

	// Measure how long it takes to receive b.N packets,
	// starting when we receive the first packet.
	var recv atomic.Uint64
	var elapsed time.Duration
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		var start time.Time
		for {
			<-pair[0].tun.Inbound
			n := recv.Add(1)
			if n == 1 {
				start = time.Now()
			}
			// Careful! Don't change this to else if; b.N can be equal to 1.
			if n == uint64(b.N) {
				elapsed = time.Since(start)
				return
			}
		}
	}()

	// Send packets as fast as we can until we've received enough.
	ping := tuntest.Ping(pair[0].ip, pair[1].ip)
	pingc := pair[1].tun.Outbound
	var sent uint64
	for recv.Load() != uint64(b.N) {
		sent++
		pingc <- ping
	}
	wg.Wait()

	b.ReportMetric(float64(elapsed)/float64(b.N), "ns/op")
	b.ReportMetric(1-float64(b.N)/float64(sent), "packet-loss")
}

// BenchmarkUAPIGet measures IpcGetOperation on a connected device, discarding
// the output.
func BenchmarkUAPIGet(b *testing.B) {
	pair := genTestPair(b, true, false)
	pair.Send(b, Ping, nil)
	pair.Send(b, Pong, nil)
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		if err := pair[0].dev.IpcGetOperation(io.Discard); err != nil {
			b.Fatal(err)
		}
	}
}

// goroutineLeakCheck records the current goroutine count and registers a
// cleanup that, unless the test has already failed, polls up to 10000 times
// (sleeping 1ms between polls) for the count to drop back to that level.
// If it never does, the cleanup logs the goroutine stacks captured at the
// start and at the end and fails the test.
func goroutineLeakCheck(t *testing.T) {
	t.Helper()
	snapshot := func() (int, []byte) {
		p := pprof.Lookup("goroutine")
		var buf bytes.Buffer
		// bytes.Buffer writes never return an error, so the result is ignored.
		_ = p.WriteTo(&buf, 1)
		return p.Count(), buf.Bytes()
	}

	startGoroutines, startStacks := snapshot()
	t.Cleanup(func() {
		if t.Failed() {
			return
		}
		// Give goroutines time to exit, if they need it.
		for i := 0; i < 10000; i++ {
			if runtime.NumGoroutine() <= startGoroutines {
				return
			}
			time.Sleep(1 * time.Millisecond)
		}
		endGoroutines, endStacks := snapshot()
		t.Logf("starting stacks:\n%s\n", startStacks)
		t.Logf("ending stacks:\n%s\n", endStacks)
		t.Fatalf("expected %d goroutines, got %d, leak?", startGoroutines, endGoroutines)
	})
}

// fakeBindSized is a no-op conn.Bind whose BatchSize is configurable.
type fakeBindSized struct {
	size int
}

var _ conn.Bind = (*fakeBindSized)(nil)

func (b *fakeBindSized) Open(
	port uint16,
) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
	return nil, 0, nil
}

func (b *fakeBindSized) Close() error { return nil }

func (b *fakeBindSized) SetMark(mark uint32) error { return nil }

func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint) error { return nil }

func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil }

func (b *fakeBindSized) BatchSize() int { return b.size }

// fakeTUNDeviceSized is a no-op tun.Device whose BatchSize is configurable.
type fakeTUNDeviceSized struct {
	size int
}

var _ tun.Device = (*fakeTUNDeviceSized)(nil)

func (t *fakeTUNDeviceSized) File() *os.File { return nil }

func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
	return 0, nil
}

func (t *fakeTUNDeviceSized) Write(bufs [][]byte, offset int) (int, error) { return 0, nil }

func (t *fakeTUNDeviceSized) MTU() (int, error) { return 0, nil }

func (t *fakeTUNDeviceSized) Name() (string, error) { return "", nil }

func (t *fakeTUNDeviceSized) Events() <-chan tun.Event { return nil }

func (t *fakeTUNDeviceSized) Close() error { return nil }

func (t *fakeTUNDeviceSized) BatchSize() int { return t.size }

// TestBatchSize checks Device.BatchSize for every combination of bind and
// TUN batch sizes 1 and 128; in each case the expected value is the larger
// of the two.
func TestBatchSize(t *testing.T) {
	tests := []struct {
		bindSize, tunSize, want int
	}{
		{bindSize: 1, tunSize: 1, want: 1},
		{bindSize: 1, tunSize: 128, want: 128},
		{bindSize: 128, tunSize: 1, want: 128},
		{bindSize: 128, tunSize: 128, want: 128},
	}

	var d Device
	for _, tt := range tests {
		d.net.bind = &fakeBindSized{tt.bindSize}
		d.tun.device = &fakeTUNDeviceSized{tt.tunSize}
		if got := d.BatchSize(); got != tt.want {
			t.Errorf("expected batch size %d, got %d", tt.want, got)
		}
	}
}