github.com/psiphon-Labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/common/quic/quic_test.go (about) 1 //go:build !PSIPHON_DISABLE_QUIC 2 // +build !PSIPHON_DISABLE_QUIC 3 4 /* 5 * Copyright (c) 2018, Psiphon Inc. 6 * All rights reserved. 7 * 8 * This program is free software: you can redistribute it and/or modify 9 * it under the terms of the GNU General Public License as published by 10 * the Free Software Foundation, either version 3 of the License, or 11 * (at your option) any later version. 12 * 13 * This program is distributed in the hope that it will be useful, 14 * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 * GNU General Public License for more details. 17 * 18 * You should have received a copy of the GNU General Public License 19 * along with this program. If not, see <http://www.gnu.org/licenses/>. 20 * 21 */ 22 23 package quic 24 25 import ( 26 "context" 27 "fmt" 28 "io" 29 "net" 30 "runtime" 31 "strings" 32 "sync/atomic" 33 "testing" 34 "time" 35 36 "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common" 37 "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors" 38 "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng" 39 "golang.org/x/sync/errgroup" 40 ) 41 42 func TestQUIC(t *testing.T) { 43 for quicVersion := range supportedVersionNumbers { 44 t.Run(fmt.Sprintf("%s", quicVersion), func(t *testing.T) { 45 if isGQUIC(quicVersion) && !GQUICEnabled() { 46 t.Skipf("gQUIC is not enabled") 47 } 48 runQUIC(t, quicVersion, GQUICEnabled(), false) 49 }) 50 if isIETF(quicVersion) { 51 t.Run(fmt.Sprintf("%s (invoke anti-probing)", quicVersion), func(t *testing.T) { 52 runQUIC(t, quicVersion, GQUICEnabled(), true) 53 }) 54 } 55 if isIETF(quicVersion) { 56 t.Run(fmt.Sprintf("%s (disable gQUIC)", quicVersion), func(t *testing.T) { 57 runQUIC(t, quicVersion, false, false) 58 }) 59 } 60 } 61 } 62 63 func runQUIC( 64 t *testing.T, 65 quicVersion string, 66 enableGQUIC bool, 67 invokeAntiProbing bool) { 68 69 initGoroutines := getGoroutines() 70 71 clients := 10 72 bytesToSend := 1 << 20 73 74 serverReceivedBytes := int64(0) 75 clientReceivedBytes := int64(0) 76 77 // Intermittently, on some platforms, the client connection termination 78 // packet is not received even when sent/received locally; set a brief 79 // idle timeout to ensure the server-side client handler doesn't block too 80 // long on Read, causing the test to fail. 81 // 82 // In realistic network conditions, and especially under adversarial 83 // network conditions, we should not expect to regularly receive client 84 // connection termination packets. 85 serverIdleTimeout = 1 * time.Second 86 87 irregularTunnelLogger := func(_ string, err error, _ common.LogFields) { 88 if !invokeAntiProbing { 89 t.Errorf("unexpected irregular tunnel event: %v", err) 90 } 91 } 92 93 obfuscationKey := prng.HexString(32) 94 95 listener, err := Listen( 96 nil, 97 irregularTunnelLogger, 98 "127.0.0.1:0", 99 obfuscationKey, 100 enableGQUIC) 101 if err != nil { 102 t.Fatalf("Listen failed: %s", err) 103 } 104 105 serverAddress := listener.Addr().String() 106 107 testGroup, testCtx := errgroup.WithContext(context.Background()) 108 109 testGroup.Go(func() error { 110 111 if invokeAntiProbing { 112 // The quic-go server can still handshake new sessions even if 113 // Accept isn't called. 114 return nil 115 } 116 117 var serverGroup errgroup.Group 118 119 for i := 0; i < clients; i++ { 120 121 conn, err := listener.Accept() 122 if err != nil { 123 return errors.Trace(err) 124 } 125 126 serverGroup.Go(func() error { 127 b := make([]byte, 1024) 128 for { 129 n, err := conn.Read(b) 130 atomic.AddInt64(&serverReceivedBytes, int64(n)) 131 if err == io.EOF { 132 return nil 133 } else if err != nil { 134 return errors.Trace(err) 135 } 136 _, err = conn.Write(b[:n]) 137 if err != nil { 138 return errors.Trace(err) 139 } 140 } 141 }) 142 } 143 144 err := serverGroup.Wait() 145 if err != nil { 146 return errors.Trace(err) 147 } 148 149 return nil 150 }) 151 152 for i := 0; i < clients; i++ { 153 154 disablePathMTUDiscovery := i%2 == 0 155 156 testGroup.Go(func() error { 157 158 ctx, cancelFunc := context.WithTimeout( 159 context.Background(), 1*time.Second) 160 defer cancelFunc() 161 162 remoteAddr, err := net.ResolveUDPAddr("udp", serverAddress) 163 if err != nil { 164 return errors.Trace(err) 165 } 166 167 packetConn, err := net.ListenPacket("udp4", "127.0.0.1:0") 168 if err != nil { 169 return errors.Trace(err) 170 } 171 172 clientObfuscationKey := obfuscationKey 173 if invokeAntiProbing { 174 clientObfuscationKey = prng.HexString(32) 175 packetConn = &countReadsConn{PacketConn: packetConn} 176 } 177 178 obfuscationPaddingSeed, err := prng.NewSeed() 179 if err != nil { 180 return errors.Trace(err) 181 } 182 183 var clientHelloSeed *prng.Seed 184 if isClientHelloRandomized(quicVersion) { 185 clientHelloSeed, err = prng.NewSeed() 186 if err != nil { 187 return errors.Trace(err) 188 } 189 } 190 191 conn, err := Dial( 192 ctx, 193 packetConn, 194 remoteAddr, 195 serverAddress, 196 quicVersion, 197 clientHelloSeed, 198 clientObfuscationKey, 199 obfuscationPaddingSeed, 200 disablePathMTUDiscovery) 201 202 if invokeAntiProbing { 203 204 if err == nil { 205 return errors.TraceNew( 206 "unexpected dial success with invalid client hello random") 207 } 208 209 readCount := packetConn.(*countReadsConn).getReadCount() 210 211 if readCount > 0 { 212 return errors.Tracef( 213 "unexpected %d read packets with invalid client hello random", 214 readCount) 215 } 216 217 return nil 218 } 219 220 if err != nil { 221 return errors.Trace(err) 222 } 223 224 // Cancel should interrupt dialing only 225 cancelFunc() 226 227 var clientGroup errgroup.Group 228 229 clientGroup.Go(func() error { 230 defer conn.Close() 231 b := make([]byte, 1024) 232 bytesRead := 0 233 for bytesRead < bytesToSend { 234 n, err := conn.Read(b) 235 bytesRead += n 236 atomic.AddInt64(&clientReceivedBytes, int64(n)) 237 if err == io.EOF { 238 break 239 } else if err != nil { 240 return errors.Trace(err) 241 } 242 } 243 return nil 244 }) 245 246 clientGroup.Go(func() error { 247 b := make([]byte, bytesToSend) 248 _, err := conn.Write(b) 249 if err != nil { 250 return errors.Trace(err) 251 } 252 return nil 253 }) 254 255 return clientGroup.Wait() 256 }) 257 258 } 259 260 go func() { 261 testGroup.Wait() 262 }() 263 264 <-testCtx.Done() 265 listener.Close() 266 267 err = testGroup.Wait() 268 if err != nil { 269 t.Errorf("goroutine failed: %s", err) 270 } 271 272 bytes := atomic.LoadInt64(&serverReceivedBytes) 273 expectedBytes := int64(clients * bytesToSend) 274 if invokeAntiProbing { 275 expectedBytes = 0 276 } 277 if bytes != expectedBytes { 278 t.Errorf("unexpected serverReceivedBytes: %d vs. %d", bytes, expectedBytes) 279 } 280 281 bytes = atomic.LoadInt64(&clientReceivedBytes) 282 if bytes != expectedBytes { 283 t.Errorf("unexpected clientReceivedBytes: %d vs. %d", bytes, expectedBytes) 284 } 285 286 _, err = listener.Accept() 287 if err == nil { 288 t.Error("unexpected Accept after Close") 289 } 290 291 // Check for unexpected dangling goroutines after shutdown. 292 // 293 // quic-go.packetHandlerMap.listen shutdown is async and some quic-go 294 // goroutines and/or timers dangle so this test makes allowances for these 295 // known dangling goroutinees. 296 297 expectedDanglingGoroutines := []string{ 298 "quic-go.(*packetHandlerMap).Retire.func1", 299 "quic-go.(*packetHandlerMap).ReplaceWithClosed.func1", 300 "quic-go.(*packetHandlerMap).RetireResetToken.func1", 301 "gquic-go.(*packetHandlerMap).removeByConnectionIDAsString.func1", 302 } 303 304 sleepTime := 100 * time.Millisecond 305 306 // The longest expected dangling goroutine is in gquic-go and is launched by a timer 307 // that fires after ClosedSessionDeleteTimeout, which is 1m. Allow one extra second 308 // to ensure this period elapses and the time.AfterFunc runs. 309 // 310 // To avoid taking 1m to run this test every time, the dangling goroutine check exits 311 // early once no dangling goroutines are found. Note that this doesn't account for 312 // any timers still pending at the early exit time. 313 n := int((61 * time.Second) / sleepTime) 314 315 for i := 0; i < n; i++ { 316 317 // Sleep before making any checks, since quic-go.packetHandlerMap.listen 318 // shutdown is asynchronous. 319 time.Sleep(100 * time.Millisecond) 320 321 // After the full 61s, no dangling goroutines are expected. 322 if i == n-1 { 323 expectedDanglingGoroutines = []string{} 324 } 325 326 hasDangling, onlyExpectedDangling := checkDanglingGoroutines( 327 t, initGoroutines, expectedDanglingGoroutines) 328 if !hasDangling { 329 break 330 } else if !onlyExpectedDangling { 331 t.Fatalf("unexpected dangling goroutines") 332 } 333 } 334 } 335 336 type countReadsConn struct { 337 net.PacketConn 338 readCount int32 339 } 340 341 func (conn *countReadsConn) ReadFrom(p []byte) (int, net.Addr, error) { 342 n, addr, err := conn.PacketConn.ReadFrom(p) 343 if n > 0 { 344 atomic.AddInt32(&conn.readCount, 1) 345 } 346 return n, addr, err 347 } 348 349 func (conn *countReadsConn) getReadCount() int { 350 return int(atomic.LoadInt32(&conn.readCount)) 351 } 352 353 func getGoroutines() []runtime.StackRecord { 354 n, _ := runtime.GoroutineProfile(nil) 355 r := make([]runtime.StackRecord, n) 356 runtime.GoroutineProfile(r) 357 return r 358 } 359 360 func checkDanglingGoroutines( 361 t *testing.T, 362 initGoroutines []runtime.StackRecord, 363 expectedDanglingGoroutines []string) (bool, bool) { 364 365 hasDangling := false 366 onlyExpectedDangling := true 367 current := getGoroutines() 368 for _, g := range current { 369 found := false 370 for _, h := range initGoroutines { 371 if g == h { 372 found = true 373 break 374 } 375 } 376 if !found { 377 stack := g.Stack() 378 funcNames := make([]string, len(stack)) 379 skip := false 380 isExpected := false 381 for i := 0; i < len(stack); i++ { 382 funcNames[i] = getFunctionName(stack[i]) 383 384 // The current goroutine won't have the same stack as in initGoroutines. 385 if strings.Contains(funcNames[i], "checkDanglingGoroutines") { 386 skip = true 387 break 388 } 389 390 // testing.T.Run runs the the test function, f, in another goroutine. f is 391 // the current goroutine, which captures initGoroutines. 392 // https://github.com/golang/go/blob/release-branch.go1.13/src/testing/testing.go#L960-L961: 393 // 394 // go tRunner(t, f) 395 // if !<-t.signal { 396 // ... 397 // 398 // f may capture initGoroutines before or after testing.T.Run advances to 399 // the channel receive, so the stack of the testing.T.Run goroutine may or 400 // may not match initGoroutines. Skip it. 401 if strings.Contains(funcNames[i], "testing.(*T).Run") { 402 skip = true 403 break 404 } 405 406 // This goroutine, created by Listener.clientRandomHistory, 407 // terminates nondeterministically, based on garbage 408 // collection. Skip it. 409 if strings.Contains(funcNames[i], "go-cache-lru.(*janitor).Run") { 410 skip = true 411 break 412 } 413 414 for _, expected := range expectedDanglingGoroutines { 415 if strings.Contains(funcNames[i], expected) { 416 isExpected = true 417 break 418 } 419 } 420 if isExpected { 421 break 422 } 423 } 424 if !skip { 425 hasDangling = true 426 if !isExpected { 427 onlyExpectedDangling = false 428 s := strings.Join(funcNames, " <- ") 429 t.Logf("found unexpected dangling goroutine: %s", s) 430 } 431 } 432 } 433 } 434 return hasDangling, onlyExpectedDangling 435 } 436 437 func getFunctionName(pc uintptr) string { 438 funcName := runtime.FuncForPC(pc).Name() 439 index := strings.LastIndex(funcName, "/") 440 if index != -1 { 441 funcName = funcName[index+1:] 442 } 443 return funcName 444 }