github.com/apernet/quic-go@v0.43.1-0.20240515053213-5e9e635fd9f0/integrationtests/self/mitm_test.go (about) 1 package self_test 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "io" 8 "math" 9 "net" 10 "sync/atomic" 11 "time" 12 13 "golang.org/x/exp/rand" 14 15 "github.com/apernet/quic-go" 16 quicproxy "github.com/apernet/quic-go/integrationtests/tools/proxy" 17 "github.com/apernet/quic-go/internal/protocol" 18 "github.com/apernet/quic-go/internal/wire" 19 "github.com/apernet/quic-go/testutils" 20 21 . "github.com/onsi/ginkgo/v2" 22 . "github.com/onsi/gomega" 23 ) 24 25 var _ = Describe("MITM test", func() { 26 const connIDLen = 6 // explicitly set the connection ID length, so the proxy can parse it 27 28 var ( 29 clientUDPConn net.PacketConn 30 serverTransport, clientTransport *quic.Transport 31 serverConn quic.Connection 32 serverConfig *quic.Config 33 ) 34 35 startServerAndProxy := func(delayCb quicproxy.DelayCallback, dropCb quicproxy.DropCallback, forceAddressValidation bool) (proxyPort int, closeFn func()) { 36 addr, err := net.ResolveUDPAddr("udp", "localhost:0") 37 Expect(err).ToNot(HaveOccurred()) 38 c, err := net.ListenUDP("udp", addr) 39 Expect(err).ToNot(HaveOccurred()) 40 serverTransport = &quic.Transport{ 41 Conn: c, 42 ConnectionIDLength: connIDLen, 43 } 44 addTracer(serverTransport) 45 if forceAddressValidation { 46 serverTransport.VerifySourceAddress = func(net.Addr) bool { return true } 47 } 48 ln, err := serverTransport.Listen(getTLSConfig(), serverConfig) 49 Expect(err).ToNot(HaveOccurred()) 50 done := make(chan struct{}) 51 go func() { 52 defer GinkgoRecover() 53 defer close(done) 54 var err error 55 serverConn, err = ln.Accept(context.Background()) 56 if err != nil { 57 return 58 } 59 str, err := serverConn.OpenUniStream() 60 Expect(err).ToNot(HaveOccurred()) 61 _, err = str.Write(PRData) 62 Expect(err).ToNot(HaveOccurred()) 63 Expect(str.Close()).To(Succeed()) 64 }() 65 serverPort := ln.Addr().(*net.UDPAddr).Port 66 proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ 67 RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), 68 DelayPacket: delayCb, 69 DropPacket: dropCb, 70 }) 71 Expect(err).ToNot(HaveOccurred()) 72 return proxy.LocalPort(), func() { 73 proxy.Close() 74 ln.Close() 75 serverTransport.Close() 76 <-done 77 } 78 } 79 80 BeforeEach(func() { 81 serverConfig = getQuicConfig(nil) 82 addr, err := net.ResolveUDPAddr("udp", "localhost:0") 83 Expect(err).ToNot(HaveOccurred()) 84 clientUDPConn, err = net.ListenUDP("udp", addr) 85 Expect(err).ToNot(HaveOccurred()) 86 clientTransport = &quic.Transport{ 87 Conn: clientUDPConn, 88 ConnectionIDLength: connIDLen, 89 } 90 addTracer(clientTransport) 91 }) 92 93 Context("unsuccessful attacks", func() { 94 AfterEach(func() { 95 Eventually(serverConn.Context().Done()).Should(BeClosed()) 96 // Test shutdown is tricky due to the proxy. Just wait for a bit. 97 time.Sleep(50 * time.Millisecond) 98 Expect(clientUDPConn.Close()).To(Succeed()) 99 Expect(clientTransport.Close()).To(Succeed()) 100 }) 101 102 Context("injecting invalid packets", func() { 103 const rtt = 20 * time.Millisecond 104 105 sendRandomPacketsOfSameType := func(conn *quic.Transport, remoteAddr net.Addr, raw []byte) { 106 defer GinkgoRecover() 107 const numPackets = 10 108 ticker := time.NewTicker(rtt / numPackets) 109 defer ticker.Stop() 110 111 if wire.IsLongHeaderPacket(raw[0]) { 112 hdr, _, _, err := wire.ParsePacket(raw) 113 Expect(err).ToNot(HaveOccurred()) 114 replyHdr := &wire.ExtendedHeader{ 115 Header: wire.Header{ 116 DestConnectionID: hdr.DestConnectionID, 117 SrcConnectionID: hdr.SrcConnectionID, 118 Type: hdr.Type, 119 Version: hdr.Version, 120 }, 121 PacketNumber: protocol.PacketNumber(rand.Int31n(math.MaxInt32 / 4)), 122 PacketNumberLen: protocol.PacketNumberLen(rand.Int31n(4) + 1), 123 } 124 125 for i := 0; i < numPackets; i++ { 126 payloadLen := rand.Int31n(100) 127 replyHdr.Length = protocol.ByteCount(rand.Int31n(payloadLen + 1)) 128 b, err := replyHdr.Append(nil, hdr.Version) 129 Expect(err).ToNot(HaveOccurred()) 130 r := make([]byte, payloadLen) 131 rand.Read(r) 132 b = append(b, r...) 133 if _, err := conn.WriteTo(b, remoteAddr); err != nil { 134 return 135 } 136 <-ticker.C 137 } 138 } else { 139 connID, err := wire.ParseConnectionID(raw, connIDLen) 140 Expect(err).ToNot(HaveOccurred()) 141 _, pn, pnLen, _, err := wire.ParseShortHeader(raw, connIDLen) 142 if err != nil { // normally, ParseShortHeader is called after decrypting the header 143 Expect(err).To(MatchError(wire.ErrInvalidReservedBits)) 144 } 145 for i := 0; i < numPackets; i++ { 146 b, err := wire.AppendShortHeader(nil, connID, pn, pnLen, protocol.KeyPhaseBit(rand.Intn(2))) 147 Expect(err).ToNot(HaveOccurred()) 148 payloadLen := rand.Int31n(100) 149 r := make([]byte, payloadLen) 150 rand.Read(r) 151 b = append(b, r...) 152 if _, err := conn.WriteTo(b, remoteAddr); err != nil { 153 return 154 } 155 <-ticker.C 156 } 157 } 158 } 159 160 runTest := func(delayCb quicproxy.DelayCallback) { 161 proxyPort, closeFn := startServerAndProxy(delayCb, nil, false) 162 defer closeFn() 163 raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort)) 164 Expect(err).ToNot(HaveOccurred()) 165 conn, err := clientTransport.Dial( 166 context.Background(), 167 raddr, 168 getTLSClientConfig(), 169 getQuicConfig(nil), 170 ) 171 Expect(err).ToNot(HaveOccurred()) 172 str, err := conn.AcceptUniStream(context.Background()) 173 Expect(err).ToNot(HaveOccurred()) 174 data, err := io.ReadAll(str) 175 Expect(err).ToNot(HaveOccurred()) 176 Expect(data).To(Equal(PRData)) 177 Expect(conn.CloseWithError(0, "")).To(Succeed()) 178 } 179 180 It("downloads a message when the packets are injected towards the server", func() { 181 delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration { 182 if dir == quicproxy.DirectionIncoming { 183 defer GinkgoRecover() 184 go sendRandomPacketsOfSameType(clientTransport, serverTransport.Conn.LocalAddr(), raw) 185 } 186 return rtt / 2 187 } 188 runTest(delayCb) 189 }) 190 191 It("downloads a message when the packets are injected towards the client", func() { 192 delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration { 193 if dir == quicproxy.DirectionOutgoing { 194 defer GinkgoRecover() 195 go sendRandomPacketsOfSameType(serverTransport, clientTransport.Conn.LocalAddr(), raw) 196 } 197 return rtt / 2 198 } 199 runTest(delayCb) 200 }) 201 }) 202 203 runTest := func(dropCb quicproxy.DropCallback) { 204 proxyPort, closeFn := startServerAndProxy(nil, dropCb, false) 205 defer closeFn() 206 raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort)) 207 Expect(err).ToNot(HaveOccurred()) 208 conn, err := clientTransport.Dial( 209 context.Background(), 210 raddr, 211 getTLSClientConfig(), 212 getQuicConfig(nil), 213 ) 214 Expect(err).ToNot(HaveOccurred()) 215 str, err := conn.AcceptUniStream(context.Background()) 216 Expect(err).ToNot(HaveOccurred()) 217 data, err := io.ReadAll(str) 218 Expect(err).ToNot(HaveOccurred()) 219 Expect(data).To(Equal(PRData)) 220 Expect(conn.CloseWithError(0, "")).To(Succeed()) 221 } 222 223 Context("duplicating packets", func() { 224 It("downloads a message when packets are duplicated towards the server", func() { 225 dropCb := func(dir quicproxy.Direction, raw []byte) bool { 226 defer GinkgoRecover() 227 if dir == quicproxy.DirectionIncoming { 228 _, err := clientTransport.WriteTo(raw, serverTransport.Conn.LocalAddr()) 229 Expect(err).ToNot(HaveOccurred()) 230 } 231 return false 232 } 233 runTest(dropCb) 234 }) 235 236 It("downloads a message when packets are duplicated towards the client", func() { 237 dropCb := func(dir quicproxy.Direction, raw []byte) bool { 238 defer GinkgoRecover() 239 if dir == quicproxy.DirectionOutgoing { 240 _, err := serverTransport.WriteTo(raw, clientTransport.Conn.LocalAddr()) 241 Expect(err).ToNot(HaveOccurred()) 242 } 243 return false 244 } 245 runTest(dropCb) 246 }) 247 }) 248 249 Context("corrupting packets", func() { 250 const idleTimeout = time.Second 251 252 var numCorrupted, numPackets atomic.Int32 253 254 BeforeEach(func() { 255 numCorrupted.Store(0) 256 numPackets.Store(0) 257 serverConfig.MaxIdleTimeout = idleTimeout 258 }) 259 260 AfterEach(func() { 261 num := numCorrupted.Load() 262 fmt.Fprintf(GinkgoWriter, "Corrupted %d of %d packets.", num, numPackets.Load()) 263 Expect(num).To(BeNumerically(">=", 1)) 264 // If the packet containing the CONNECTION_CLOSE is corrupted, 265 // we have to wait for the connection to time out. 266 Eventually(serverConn.Context().Done(), 3*idleTimeout).Should(BeClosed()) 267 }) 268 269 It("downloads a message when packet are corrupted towards the server", func() { 270 const interval = 4 // corrupt every 4th packet (stochastically) 271 dropCb := func(dir quicproxy.Direction, raw []byte) bool { 272 defer GinkgoRecover() 273 if dir == quicproxy.DirectionIncoming { 274 numPackets.Add(1) 275 if rand.Intn(interval) == 0 { 276 pos := rand.Intn(len(raw)) 277 raw[pos] = byte(rand.Intn(256)) 278 _, err := clientTransport.WriteTo(raw, serverTransport.Conn.LocalAddr()) 279 Expect(err).ToNot(HaveOccurred()) 280 numCorrupted.Add(1) 281 return true 282 } 283 } 284 return false 285 } 286 runTest(dropCb) 287 }) 288 289 It("downloads a message when packet are corrupted towards the client", func() { 290 const interval = 10 // corrupt every 10th packet (stochastically) 291 dropCb := func(dir quicproxy.Direction, raw []byte) bool { 292 defer GinkgoRecover() 293 if dir == quicproxy.DirectionOutgoing { 294 numPackets.Add(1) 295 if rand.Intn(interval) == 0 { 296 pos := rand.Intn(len(raw)) 297 raw[pos] = byte(rand.Intn(256)) 298 _, err := serverTransport.WriteTo(raw, clientTransport.Conn.LocalAddr()) 299 Expect(err).ToNot(HaveOccurred()) 300 numCorrupted.Add(1) 301 return true 302 } 303 } 304 return false 305 } 306 runTest(dropCb) 307 }) 308 }) 309 }) 310 311 Context("successful injection attacks", func() { 312 // These tests demonstrate that the QUIC protocol is vulnerable to injection attacks before the handshake 313 // finishes. In particular, an adversary who can intercept packets coming from one endpoint and send a reply 314 // that arrives before the real reply can tear down the connection in multiple ways. 315 316 const rtt = 20 * time.Millisecond 317 318 runTest := func(proxyPort int) (closeFn func(), err error) { 319 raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort)) 320 Expect(err).ToNot(HaveOccurred()) 321 _, err = clientTransport.Dial( 322 context.Background(), 323 raddr, 324 getTLSClientConfig(), 325 getQuicConfig(&quic.Config{HandshakeIdleTimeout: scaleDuration(200 * time.Millisecond)}), 326 ) 327 return func() { clientTransport.Close() }, err 328 } 329 330 // fails immediately because client connection closes when it can't find compatible version 331 It("fails when a forged version negotiation packet is sent to client", func() { 332 done := make(chan struct{}) 333 delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration { 334 if dir == quicproxy.DirectionIncoming { 335 defer GinkgoRecover() 336 337 hdr, _, _, err := wire.ParsePacket(raw) 338 Expect(err).ToNot(HaveOccurred()) 339 340 if hdr.Type != protocol.PacketTypeInitial { 341 return 0 342 } 343 344 // Create fake version negotiation packet with no supported versions 345 versions := []protocol.Version{} 346 packet := wire.ComposeVersionNegotiation( 347 protocol.ArbitraryLenConnectionID(hdr.SrcConnectionID.Bytes()), 348 protocol.ArbitraryLenConnectionID(hdr.DestConnectionID.Bytes()), 349 versions, 350 ) 351 352 // Send the packet 353 _, err = serverTransport.WriteTo(packet, clientTransport.Conn.LocalAddr()) 354 Expect(err).ToNot(HaveOccurred()) 355 close(done) 356 } 357 return rtt / 2 358 } 359 proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil, false) 360 defer serverCloseFn() 361 closeFn, err := runTest(proxyPort) 362 defer closeFn() 363 Expect(err).To(HaveOccurred()) 364 vnErr := &quic.VersionNegotiationError{} 365 Expect(errors.As(err, &vnErr)).To(BeTrue()) 366 Eventually(done).Should(BeClosed()) 367 }) 368 369 // times out, because client doesn't accept subsequent real retry packets from server 370 // as it has already accepted a retry. 371 // TODO: determine behavior when server does not send Retry packets 372 It("fails when a forged Retry packet with modified Source Connection ID is sent to client", func() { 373 var initialPacketIntercepted bool 374 done := make(chan struct{}) 375 delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration { 376 if dir == quicproxy.DirectionIncoming && !initialPacketIntercepted { 377 defer GinkgoRecover() 378 defer close(done) 379 380 hdr, _, _, err := wire.ParsePacket(raw) 381 Expect(err).ToNot(HaveOccurred()) 382 383 if hdr.Type != protocol.PacketTypeInitial { 384 return 0 385 } 386 387 initialPacketIntercepted = true 388 fakeSrcConnID := protocol.ParseConnectionID([]byte{0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12}) 389 retryPacket := testutils.ComposeRetryPacket(fakeSrcConnID, hdr.SrcConnectionID, hdr.DestConnectionID, []byte("token"), hdr.Version) 390 391 _, err = serverTransport.WriteTo(retryPacket, clientTransport.Conn.LocalAddr()) 392 Expect(err).ToNot(HaveOccurred()) 393 } 394 return rtt / 2 395 } 396 proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil, true) 397 defer serverCloseFn() 398 closeFn, err := runTest(proxyPort) 399 defer closeFn() 400 Expect(err).To(HaveOccurred()) 401 Expect(err.(net.Error).Timeout()).To(BeTrue()) 402 Eventually(done).Should(BeClosed()) 403 }) 404 405 // times out, because client doesn't accept real retry packets from server because 406 // it has already accepted an initial. 407 // TODO: determine behavior when server does not send Retry packets 408 It("fails when a forged initial packet is sent to client", func() { 409 done := make(chan struct{}) 410 var injected bool 411 delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration { 412 if dir == quicproxy.DirectionIncoming { 413 defer GinkgoRecover() 414 415 hdr, _, _, err := wire.ParsePacket(raw) 416 Expect(err).ToNot(HaveOccurred()) 417 if hdr.Type != protocol.PacketTypeInitial || injected { 418 return 0 419 } 420 defer close(done) 421 injected = true 422 initialPacket := testutils.ComposeInitialPacket(hdr.DestConnectionID, hdr.SrcConnectionID, hdr.DestConnectionID, nil, nil, protocol.PerspectiveServer, hdr.Version) 423 _, err = serverTransport.WriteTo(initialPacket, clientTransport.Conn.LocalAddr()) 424 Expect(err).ToNot(HaveOccurred()) 425 } 426 return rtt 427 } 428 proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil, false) 429 defer serverCloseFn() 430 closeFn, err := runTest(proxyPort) 431 defer closeFn() 432 Expect(err).To(HaveOccurred()) 433 Expect(err.(net.Error).Timeout()).To(BeTrue()) 434 Eventually(done).Should(BeClosed()) 435 }) 436 437 // client connection closes immediately on receiving ack for unsent packet 438 It("fails when a forged initial packet with ack for unsent packet is sent to client", func() { 439 done := make(chan struct{}) 440 var injected bool 441 delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration { 442 if dir == quicproxy.DirectionIncoming { 443 defer GinkgoRecover() 444 445 hdr, _, _, err := wire.ParsePacket(raw) 446 Expect(err).ToNot(HaveOccurred()) 447 if hdr.Type != protocol.PacketTypeInitial || injected { 448 return 0 449 } 450 defer close(done) 451 injected = true 452 // Fake Initial with ACK for packet 2 (unsent) 453 ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} 454 initialPacket := testutils.ComposeInitialPacket(hdr.DestConnectionID, hdr.SrcConnectionID, hdr.DestConnectionID, nil, []wire.Frame{ack}, protocol.PerspectiveServer, hdr.Version) 455 _, err = serverTransport.WriteTo(initialPacket, clientTransport.Conn.LocalAddr()) 456 Expect(err).ToNot(HaveOccurred()) 457 } 458 return rtt 459 } 460 proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil, false) 461 defer serverCloseFn() 462 closeFn, err := runTest(proxyPort) 463 defer closeFn() 464 Expect(err).To(HaveOccurred()) 465 var transportErr *quic.TransportError 466 Expect(errors.As(err, &transportErr)).To(BeTrue()) 467 Expect(transportErr.ErrorCode).To(Equal(quic.ProtocolViolation)) 468 Expect(transportErr.ErrorMessage).To(ContainSubstring("received ACK for an unsent packet")) 469 Eventually(done).Should(BeClosed()) 470 }) 471 }) 472 })