github.com/GFW-knocker/wireguard@v1.0.1/device/device_test.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package device 7 8 import ( 9 "bytes" 10 "encoding/hex" 11 "fmt" 12 "io" 13 "math/rand" 14 "net/netip" 15 "os" 16 "runtime" 17 "runtime/pprof" 18 "sync" 19 "sync/atomic" 20 "testing" 21 "time" 22 23 "github.com/GFW-knocker/wireguard/conn" 24 "github.com/GFW-knocker/wireguard/conn/bindtest" 25 "github.com/GFW-knocker/wireguard/tun" 26 "github.com/GFW-knocker/wireguard/tun/tuntest" 27 ) 28 29 // uapiCfg returns a string that contains cfg formatted use with IpcSet. 30 // cfg is a series of alternating key/value strings. 31 // uapiCfg exists because editors and humans like to insert 32 // whitespace into configs, which can cause failures, some of which are silent. 33 // For example, a leading blank newline causes the remainder 34 // of the config to be silently ignored. 35 func uapiCfg(cfg ...string) string { 36 if len(cfg)%2 != 0 { 37 panic("odd number of args to uapiReader") 38 } 39 buf := new(bytes.Buffer) 40 for i, s := range cfg { 41 buf.WriteString(s) 42 sep := byte('\n') 43 if i%2 == 0 { 44 sep = '=' 45 } 46 buf.WriteByte(sep) 47 } 48 return buf.String() 49 } 50 51 // genConfigs generates a pair of configs that connect to each other. 52 // The configs use distinct, probably-usable ports. 53 func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) { 54 var key1, key2 NoisePrivateKey 55 _, err := rand.Read(key1[:]) 56 if err != nil { 57 tb.Errorf("unable to generate private key random bytes: %v", err) 58 } 59 _, err = rand.Read(key2[:]) 60 if err != nil { 61 tb.Errorf("unable to generate private key random bytes: %v", err) 62 } 63 pub1, pub2 := key1.publicKey(), key2.publicKey() 64 65 cfgs[0] = uapiCfg( 66 "private_key", hex.EncodeToString(key1[:]), 67 "listen_port", "0", 68 "replace_peers", "true", 69 "public_key", hex.EncodeToString(pub2[:]), 70 "protocol_version", "1", 71 "replace_allowed_ips", "true", 72 "allowed_ip", "1.0.0.2/32", 73 ) 74 endpointCfgs[0] = uapiCfg( 75 "public_key", hex.EncodeToString(pub2[:]), 76 "endpoint", "127.0.0.1:%d", 77 ) 78 cfgs[1] = uapiCfg( 79 "private_key", hex.EncodeToString(key2[:]), 80 "listen_port", "0", 81 "replace_peers", "true", 82 "public_key", hex.EncodeToString(pub1[:]), 83 "protocol_version", "1", 84 "replace_allowed_ips", "true", 85 "allowed_ip", "1.0.0.1/32", 86 ) 87 endpointCfgs[1] = uapiCfg( 88 "public_key", hex.EncodeToString(pub1[:]), 89 "endpoint", "127.0.0.1:%d", 90 ) 91 return 92 } 93 94 // A testPair is a pair of testPeers. 95 type testPair [2]testPeer 96 97 // A testPeer is a peer used for testing. 98 type testPeer struct { 99 tun *tuntest.ChannelTUN 100 dev *Device 101 ip netip.Addr 102 } 103 104 type SendDirection bool 105 106 const ( 107 Ping SendDirection = true 108 Pong SendDirection = false 109 ) 110 111 func (d SendDirection) String() string { 112 if d == Ping { 113 return "ping" 114 } 115 return "pong" 116 } 117 118 func (pair *testPair) Send(tb testing.TB, ping SendDirection, done chan struct{}) { 119 tb.Helper() 120 p0, p1 := pair[0], pair[1] 121 if !ping { 122 // pong is the new ping 123 p0, p1 = p1, p0 124 } 125 msg := tuntest.Ping(p0.ip, p1.ip) 126 p1.tun.Outbound <- msg 127 timer := time.NewTimer(5 * time.Second) 128 defer timer.Stop() 129 var err error 130 select { 131 case msgRecv := <-p0.tun.Inbound: 132 if !bytes.Equal(msg, msgRecv) { 133 err = fmt.Errorf("%s did not transit correctly", ping) 134 } 135 case <-timer.C: 136 err = fmt.Errorf("%s did not transit", ping) 137 case <-done: 138 } 139 if err != nil { 140 // The error may have occurred because the test is done. 141 select { 142 case <-done: 143 return 144 default: 145 } 146 // Real error. 147 tb.Error(err) 148 } 149 } 150 151 // genTestPair creates a testPair. 152 func genTestPair(tb testing.TB, realSocket bool) (pair testPair) { 153 cfg, endpointCfg := genConfigs(tb) 154 var binds [2]conn.Bind 155 if realSocket { 156 binds[0], binds[1] = conn.NewDefaultBind(), conn.NewDefaultBind() 157 } else { 158 binds = bindtest.NewChannelBinds() 159 } 160 // Bring up a ChannelTun for each config. 161 for i := range pair { 162 p := &pair[i] 163 p.tun = tuntest.NewChannelTUN() 164 p.ip = netip.AddrFrom4([4]byte{1, 0, 0, byte(i + 1)}) 165 level := LogLevelVerbose 166 if _, ok := tb.(*testing.B); ok && !testing.Verbose() { 167 level = LogLevelError 168 } 169 p.dev = NewDevice(p.tun.TUN(), binds[i], NewLogger(level, fmt.Sprintf("dev%d: ", i))) 170 if err := p.dev.IpcSet(cfg[i]); err != nil { 171 tb.Errorf("failed to configure device %d: %v", i, err) 172 p.dev.Close() 173 continue 174 } 175 if err := p.dev.Up(); err != nil { 176 tb.Errorf("failed to bring up device %d: %v", i, err) 177 p.dev.Close() 178 continue 179 } 180 endpointCfg[i^1] = fmt.Sprintf(endpointCfg[i^1], p.dev.net.port) 181 } 182 for i := range pair { 183 p := &pair[i] 184 if err := p.dev.IpcSet(endpointCfg[i]); err != nil { 185 tb.Errorf("failed to configure device endpoint %d: %v", i, err) 186 p.dev.Close() 187 continue 188 } 189 // The device is ready. Close it when the test completes. 190 tb.Cleanup(p.dev.Close) 191 } 192 return 193 } 194 195 func TestTwoDevicePing(t *testing.T) { 196 goroutineLeakCheck(t) 197 pair := genTestPair(t, true) 198 t.Run("ping 1.0.0.1", func(t *testing.T) { 199 pair.Send(t, Ping, nil) 200 }) 201 t.Run("ping 1.0.0.2", func(t *testing.T) { 202 pair.Send(t, Pong, nil) 203 }) 204 } 205 206 func TestUpDown(t *testing.T) { 207 goroutineLeakCheck(t) 208 const itrials = 50 209 const otrials = 10 210 211 for n := 0; n < otrials; n++ { 212 pair := genTestPair(t, false) 213 for i := range pair { 214 for k := range pair[i].dev.peers.keyMap { 215 pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:]))) 216 } 217 } 218 var wg sync.WaitGroup 219 wg.Add(len(pair)) 220 for i := range pair { 221 go func(d *Device) { 222 defer wg.Done() 223 for i := 0; i < itrials; i++ { 224 if err := d.Up(); err != nil { 225 t.Errorf("failed up bring up device: %v", err) 226 } 227 time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1))))) 228 if err := d.Down(); err != nil { 229 t.Errorf("failed to bring down device: %v", err) 230 } 231 time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1))))) 232 } 233 }(pair[i].dev) 234 } 235 wg.Wait() 236 for i := range pair { 237 pair[i].dev.Up() 238 pair[i].dev.Close() 239 } 240 } 241 } 242 243 // TestConcurrencySafety does other things concurrently with tunnel use. 244 // It is intended to be used with the race detector to catch data races. 245 func TestConcurrencySafety(t *testing.T) { 246 pair := genTestPair(t, true) 247 done := make(chan struct{}) 248 249 const warmupIters = 10 250 var warmup sync.WaitGroup 251 warmup.Add(warmupIters) 252 go func() { 253 // Send data continuously back and forth until we're done. 254 // Note that we may continue to attempt to send data 255 // even after done is closed. 256 i := warmupIters 257 for ping := Ping; ; ping = !ping { 258 pair.Send(t, ping, done) 259 select { 260 case <-done: 261 return 262 default: 263 } 264 if i > 0 { 265 warmup.Done() 266 i-- 267 } 268 } 269 }() 270 warmup.Wait() 271 272 applyCfg := func(cfg string) { 273 err := pair[0].dev.IpcSet(cfg) 274 if err != nil { 275 t.Fatal(err) 276 } 277 } 278 279 // Change persistent_keepalive_interval concurrently with tunnel use. 280 t.Run("persistentKeepaliveInterval", func(t *testing.T) { 281 var pub NoisePublicKey 282 for key := range pair[0].dev.peers.keyMap { 283 pub = key 284 break 285 } 286 cfg := uapiCfg( 287 "public_key", hex.EncodeToString(pub[:]), 288 "persistent_keepalive_interval", "1", 289 ) 290 for i := 0; i < 1000; i++ { 291 applyCfg(cfg) 292 } 293 }) 294 295 // Change private keys concurrently with tunnel use. 296 t.Run("privateKey", func(t *testing.T) { 297 bad := uapiCfg("private_key", "7777777777777777777777777777777777777777777777777777777777777777") 298 good := uapiCfg("private_key", hex.EncodeToString(pair[0].dev.staticIdentity.privateKey[:])) 299 // Set iters to a large number like 1000 to flush out data races quickly. 300 // Don't leave it large. That can cause logical races 301 // in which the handshake is interleaved with key changes 302 // such that the private key appears to be unchanging but 303 // other state gets reset, which can cause handshake failures like 304 // "Received packet with invalid mac1". 305 const iters = 1 306 for i := 0; i < iters; i++ { 307 applyCfg(bad) 308 applyCfg(good) 309 } 310 }) 311 312 // Perform bind updates and keepalive sends concurrently with tunnel use. 313 t.Run("bindUpdate and keepalive", func(t *testing.T) { 314 const iters = 10 315 for i := 0; i < iters; i++ { 316 for _, peer := range pair { 317 peer.dev.BindUpdate() 318 peer.dev.SendKeepalivesToPeersWithCurrentKeypair() 319 } 320 } 321 }) 322 323 close(done) 324 } 325 326 func BenchmarkLatency(b *testing.B) { 327 pair := genTestPair(b, true) 328 329 // Establish a connection. 330 pair.Send(b, Ping, nil) 331 pair.Send(b, Pong, nil) 332 333 b.ResetTimer() 334 for i := 0; i < b.N; i++ { 335 pair.Send(b, Ping, nil) 336 pair.Send(b, Pong, nil) 337 } 338 } 339 340 func BenchmarkThroughput(b *testing.B) { 341 pair := genTestPair(b, true) 342 343 // Establish a connection. 344 pair.Send(b, Ping, nil) 345 pair.Send(b, Pong, nil) 346 347 // Measure how long it takes to receive b.N packets, 348 // starting when we receive the first packet. 349 var recv atomic.Uint64 350 var elapsed time.Duration 351 var wg sync.WaitGroup 352 wg.Add(1) 353 go func() { 354 defer wg.Done() 355 var start time.Time 356 for { 357 <-pair[0].tun.Inbound 358 new := recv.Add(1) 359 if new == 1 { 360 start = time.Now() 361 } 362 // Careful! Don't change this to else if; b.N can be equal to 1. 363 if new == uint64(b.N) { 364 elapsed = time.Since(start) 365 return 366 } 367 } 368 }() 369 370 // Send packets as fast as we can until we've received enough. 371 ping := tuntest.Ping(pair[0].ip, pair[1].ip) 372 pingc := pair[1].tun.Outbound 373 var sent uint64 374 for recv.Load() != uint64(b.N) { 375 sent++ 376 pingc <- ping 377 } 378 wg.Wait() 379 380 b.ReportMetric(float64(elapsed)/float64(b.N), "ns/op") 381 b.ReportMetric(1-float64(b.N)/float64(sent), "packet-loss") 382 } 383 384 func BenchmarkUAPIGet(b *testing.B) { 385 pair := genTestPair(b, true) 386 pair.Send(b, Ping, nil) 387 pair.Send(b, Pong, nil) 388 b.ReportAllocs() 389 b.ResetTimer() 390 for i := 0; i < b.N; i++ { 391 pair[0].dev.IpcGetOperation(io.Discard) 392 } 393 } 394 395 func goroutineLeakCheck(t *testing.T) { 396 goroutines := func() (int, []byte) { 397 p := pprof.Lookup("goroutine") 398 b := new(bytes.Buffer) 399 p.WriteTo(b, 1) 400 return p.Count(), b.Bytes() 401 } 402 403 startGoroutines, startStacks := goroutines() 404 t.Cleanup(func() { 405 if t.Failed() { 406 return 407 } 408 // Give goroutines time to exit, if they need it. 409 for i := 0; i < 10000; i++ { 410 if runtime.NumGoroutine() <= startGoroutines { 411 return 412 } 413 time.Sleep(1 * time.Millisecond) 414 } 415 endGoroutines, endStacks := goroutines() 416 t.Logf("starting stacks:\n%s\n", startStacks) 417 t.Logf("ending stacks:\n%s\n", endStacks) 418 t.Fatalf("expected %d goroutines, got %d, leak?", startGoroutines, endGoroutines) 419 }) 420 } 421 422 type fakeBindSized struct { 423 size int 424 } 425 426 func (b *fakeBindSized) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { 427 return nil, 0, nil 428 } 429 func (b *fakeBindSized) Close() error { return nil } 430 func (b *fakeBindSized) SetMark(mark uint32) error { return nil } 431 func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint) error { return nil } 432 func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil } 433 func (b *fakeBindSized) BatchSize() int { return b.size } 434 435 type fakeTUNDeviceSized struct { 436 size int 437 } 438 439 func (t *fakeTUNDeviceSized) File() *os.File { return nil } 440 func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { 441 return 0, nil 442 } 443 func (t *fakeTUNDeviceSized) Write(bufs [][]byte, offset int) (int, error) { return 0, nil } 444 func (t *fakeTUNDeviceSized) MTU() (int, error) { return 0, nil } 445 func (t *fakeTUNDeviceSized) Name() (string, error) { return "", nil } 446 func (t *fakeTUNDeviceSized) Events() <-chan tun.Event { return nil } 447 func (t *fakeTUNDeviceSized) Close() error { return nil } 448 func (t *fakeTUNDeviceSized) BatchSize() int { return t.size } 449 450 func TestBatchSize(t *testing.T) { 451 d := Device{} 452 453 d.net.bind = &fakeBindSized{1} 454 d.tun.device = &fakeTUNDeviceSized{1} 455 if want, got := 1, d.BatchSize(); got != want { 456 t.Errorf("expected batch size %d, got %d", want, got) 457 } 458 459 d.net.bind = &fakeBindSized{1} 460 d.tun.device = &fakeTUNDeviceSized{128} 461 if want, got := 128, d.BatchSize(); got != want { 462 t.Errorf("expected batch size %d, got %d", want, got) 463 } 464 465 d.net.bind = &fakeBindSized{128} 466 d.tun.device = &fakeTUNDeviceSized{1} 467 if want, got := 128, d.BatchSize(); got != want { 468 t.Errorf("expected batch size %d, got %d", want, got) 469 } 470 471 d.net.bind = &fakeBindSized{128} 472 d.tun.device = &fakeTUNDeviceSized{128} 473 if want, got := 128, d.BatchSize(); got != want { 474 t.Errorf("expected batch size %d, got %d", want, got) 475 } 476 }