github.com/apernet/quic-go@v0.43.1-0.20240515053213-5e9e635fd9f0/transport_test.go (about) 1 package quic 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/rand" 7 "crypto/tls" 8 "errors" 9 "net" 10 "syscall" 11 "time" 12 13 mocklogging "github.com/apernet/quic-go/internal/mocks/logging" 14 "github.com/apernet/quic-go/internal/protocol" 15 "github.com/apernet/quic-go/internal/wire" 16 "github.com/apernet/quic-go/logging" 17 18 . "github.com/onsi/ginkgo/v2" 19 . "github.com/onsi/gomega" 20 "go.uber.org/mock/gomock" 21 ) 22 23 var _ = Describe("Transport", func() { 24 type packetToRead struct { 25 addr net.Addr 26 data []byte 27 err error 28 } 29 30 getPacketWithPacketType := func(connID protocol.ConnectionID, t protocol.PacketType, length protocol.ByteCount) []byte { 31 b, err := (&wire.ExtendedHeader{ 32 Header: wire.Header{ 33 Type: t, 34 DestConnectionID: connID, 35 Length: length, 36 Version: protocol.Version1, 37 }, 38 PacketNumberLen: protocol.PacketNumberLen2, 39 }).Append(nil, protocol.Version1) 40 Expect(err).ToNot(HaveOccurred()) 41 return b 42 } 43 44 getPacket := func(connID protocol.ConnectionID) []byte { 45 return getPacketWithPacketType(connID, protocol.PacketTypeHandshake, 2) 46 } 47 48 newMockPacketConn := func(packetChan <-chan packetToRead) *MockPacketConn { 49 conn := NewMockPacketConn(mockCtrl) 50 conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes() 51 conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(b []byte) (int, net.Addr, error) { 52 p, ok := <-packetChan 53 if !ok { 54 return 0, nil, errors.New("closed") 55 } 56 return copy(b, p.data), p.addr, p.err 57 }).AnyTimes() 58 // for shutdown 59 conn.EXPECT().SetReadDeadline(gomock.Any()).AnyTimes() 60 return conn 61 } 62 63 It("handles packets for different packet handlers on the same packet conn", func() { 64 packetChan := make(chan packetToRead) 65 tr := &Transport{Conn: newMockPacketConn(packetChan)} 66 tr.init(true) 67 phm := NewMockPacketHandlerManager(mockCtrl) 68 tr.handlerMap = phm 69 connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) 70 connID2 := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}) 71 72 handled := make(chan struct{}, 2) 73 phm.EXPECT().Get(connID1).DoAndReturn(func(protocol.ConnectionID) (packetHandler, bool) { 74 h := NewMockPacketHandler(mockCtrl) 75 h.EXPECT().handlePacket(gomock.Any()).Do(func(p receivedPacket) { 76 defer GinkgoRecover() 77 connID, err := wire.ParseConnectionID(p.data, 0) 78 Expect(err).ToNot(HaveOccurred()) 79 Expect(connID).To(Equal(connID1)) 80 handled <- struct{}{} 81 }) 82 return h, true 83 }) 84 phm.EXPECT().Get(connID2).DoAndReturn(func(protocol.ConnectionID) (packetHandler, bool) { 85 h := NewMockPacketHandler(mockCtrl) 86 h.EXPECT().handlePacket(gomock.Any()).Do(func(p receivedPacket) { 87 defer GinkgoRecover() 88 connID, err := wire.ParseConnectionID(p.data, 0) 89 Expect(err).ToNot(HaveOccurred()) 90 Expect(connID).To(Equal(connID2)) 91 handled <- struct{}{} 92 }) 93 return h, true 94 }) 95 96 packetChan <- packetToRead{data: getPacket(connID1)} 97 packetChan <- packetToRead{data: getPacket(connID2)} 98 99 Eventually(handled).Should(Receive()) 100 Eventually(handled).Should(Receive()) 101 102 // shutdown 103 phm.EXPECT().Close(gomock.Any()) 104 close(packetChan) 105 tr.Close() 106 }) 107 108 It("closes listeners", func() { 109 packetChan := make(chan packetToRead) 110 tr := &Transport{Conn: newMockPacketConn(packetChan)} 111 defer tr.Close() 112 ln, err := tr.Listen(&tls.Config{}, nil) 113 Expect(err).ToNot(HaveOccurred()) 114 phm := NewMockPacketHandlerManager(mockCtrl) 115 tr.handlerMap = phm 116 117 Expect(ln.Close()).To(Succeed()) 118 119 // shutdown 120 phm.EXPECT().Close(gomock.Any()) 121 close(packetChan) 122 tr.Close() 123 }) 124 125 It("closes transport concurrently with listener", func() { 126 // try 10 times to trigger race conditions 127 for i := 0; i < 10; i++ { 128 packetChan := make(chan packetToRead) 129 tr := &Transport{Conn: newMockPacketConn(packetChan)} 130 ln, err := tr.Listen(&tls.Config{}, nil) 131 Expect(err).ToNot(HaveOccurred()) 132 ch := make(chan bool) 133 // Close transport and listener concurrently. 134 go func() { 135 ch <- true 136 Expect(ln.Close()).To(Succeed()) 137 ch <- true 138 }() 139 <-ch 140 close(packetChan) 141 Expect(tr.Close()).To(Succeed()) 142 <-ch 143 } 144 }) 145 146 It("drops unparseable QUIC packets", func() { 147 addr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234} 148 packetChan := make(chan packetToRead) 149 t, tracer := mocklogging.NewMockTracer(mockCtrl) 150 tr := &Transport{ 151 Conn: newMockPacketConn(packetChan), 152 ConnectionIDLength: 10, 153 Tracer: t, 154 } 155 tr.init(true) 156 dropped := make(chan struct{}) 157 tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(dropped) }) 158 packetChan <- packetToRead{ 159 addr: addr, 160 data: []byte{0x40 /* set the QUIC bit */, 1, 2, 3}, 161 } 162 Eventually(dropped).Should(BeClosed()) 163 164 // shutdown 165 tracer.EXPECT().Close() 166 close(packetChan) 167 tr.Close() 168 }) 169 170 It("closes when reading from the conn fails", func() { 171 packetChan := make(chan packetToRead) 172 tr := Transport{Conn: newMockPacketConn(packetChan)} 173 defer tr.Close() 174 phm := NewMockPacketHandlerManager(mockCtrl) 175 tr.init(true) 176 tr.handlerMap = phm 177 178 done := make(chan struct{}) 179 phm.EXPECT().Close(gomock.Any()).Do(func(error) { close(done) }) 180 packetChan <- packetToRead{err: errors.New("read failed")} 181 Eventually(done).Should(BeClosed()) 182 183 // shutdown 184 close(packetChan) 185 tr.Close() 186 }) 187 188 It("continues listening after temporary errors", func() { 189 packetChan := make(chan packetToRead) 190 tr := Transport{Conn: newMockPacketConn(packetChan)} 191 defer tr.Close() 192 phm := NewMockPacketHandlerManager(mockCtrl) 193 tr.init(true) 194 tr.handlerMap = phm 195 196 tempErr := deadlineError{} 197 Expect(tempErr.Temporary()).To(BeTrue()) 198 packetChan <- packetToRead{err: tempErr} 199 // don't expect any calls to phm.Close 200 time.Sleep(50 * time.Millisecond) 201 202 // shutdown 203 phm.EXPECT().Close(gomock.Any()) 204 close(packetChan) 205 tr.Close() 206 }) 207 208 It("handles short header packets resets", func() { 209 connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5}) 210 packetChan := make(chan packetToRead) 211 tr := Transport{ 212 Conn: newMockPacketConn(packetChan), 213 ConnectionIDLength: connID.Len(), 214 } 215 tr.init(true) 216 defer tr.Close() 217 phm := NewMockPacketHandlerManager(mockCtrl) 218 tr.handlerMap = phm 219 220 var token protocol.StatelessResetToken 221 rand.Read(token[:]) 222 223 var b []byte 224 b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne) 225 Expect(err).ToNot(HaveOccurred()) 226 b = append(b, token[:]...) 227 conn := NewMockPacketHandler(mockCtrl) 228 gomock.InOrder( 229 phm.EXPECT().Get(connID).Return(conn, true), 230 conn.EXPECT().handlePacket(gomock.Any()).Do(func(p receivedPacket) { 231 Expect(p.data).To(Equal(b)) 232 Expect(p.rcvTime).To(BeTemporally("~", time.Now(), time.Second)) 233 }), 234 ) 235 packetChan <- packetToRead{data: b} 236 237 // shutdown 238 phm.EXPECT().Close(gomock.Any()) 239 close(packetChan) 240 tr.Close() 241 }) 242 243 It("handles stateless resets", func() { 244 connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5}) 245 packetChan := make(chan packetToRead) 246 tr := Transport{ 247 Conn: newMockPacketConn(packetChan), 248 ConnectionIDLength: connID.Len(), 249 } 250 tr.init(true) 251 defer tr.Close() 252 phm := NewMockPacketHandlerManager(mockCtrl) 253 tr.handlerMap = phm 254 255 var token protocol.StatelessResetToken 256 rand.Read(token[:]) 257 258 var b []byte 259 b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne) 260 Expect(err).ToNot(HaveOccurred()) 261 b = append(b, token[:]...) 262 conn := NewMockPacketHandler(mockCtrl) 263 destroyed := make(chan struct{}) 264 gomock.InOrder( 265 phm.EXPECT().Get(connID), 266 phm.EXPECT().GetByResetToken(token).Return(conn, true), 267 conn.EXPECT().destroy(gomock.Any()).Do(func(err error) { 268 Expect(err).To(MatchError(&StatelessResetError{Token: token})) 269 close(destroyed) 270 }), 271 ) 272 packetChan <- packetToRead{data: b} 273 Eventually(destroyed).Should(BeClosed()) 274 275 // shutdown 276 phm.EXPECT().Close(gomock.Any()) 277 close(packetChan) 278 tr.Close() 279 }) 280 281 It("sends stateless resets", func() { 282 connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5}) 283 packetChan := make(chan packetToRead) 284 conn := newMockPacketConn(packetChan) 285 tr := Transport{ 286 Conn: conn, 287 StatelessResetKey: &StatelessResetKey{1, 2, 3, 4}, 288 ConnectionIDLength: connID.Len(), 289 } 290 tr.init(true) 291 defer tr.Close() 292 phm := NewMockPacketHandlerManager(mockCtrl) 293 tr.handlerMap = phm 294 295 var b []byte 296 b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne) 297 Expect(err).ToNot(HaveOccurred()) 298 b = append(b, make([]byte, protocol.MinStatelessResetSize-len(b)+1)...) 299 300 var token protocol.StatelessResetToken 301 rand.Read(token[:]) 302 written := make(chan struct{}) 303 gomock.InOrder( 304 phm.EXPECT().Get(connID), 305 phm.EXPECT().GetByResetToken(gomock.Any()), 306 phm.EXPECT().GetStatelessResetToken(connID).Return(token), 307 conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).Do(func(b []byte, _ net.Addr) (int, error) { 308 defer close(written) 309 Expect(bytes.Contains(b, token[:])).To(BeTrue()) 310 return len(b), nil 311 }), 312 ) 313 packetChan <- packetToRead{data: b} 314 Eventually(written).Should(BeClosed()) 315 316 // shutdown 317 phm.EXPECT().Close(gomock.Any()) 318 close(packetChan) 319 tr.Close() 320 }) 321 322 It("closes uninitialized Transport and closes underlying PacketConn", func() { 323 packetChan := make(chan packetToRead) 324 pconn := newMockPacketConn(packetChan) 325 326 tr := &Transport{ 327 Conn: pconn, 328 createdConn: true, // owns pconn 329 } 330 // NO init 331 332 // shutdown 333 close(packetChan) 334 pconn.EXPECT().Close() 335 Expect(tr.Close()).To(Succeed()) 336 }) 337 338 It("doesn't add the PacketConn to the multiplexer if (*Transport).init fails", func() { 339 packetChan := make(chan packetToRead) 340 pconn := newMockPacketConn(packetChan) 341 syscallconn := &mockSyscallConn{pconn} 342 343 tr := &Transport{ 344 Conn: syscallconn, 345 } 346 347 err := tr.init(false) 348 Expect(err).To(HaveOccurred()) 349 conns := getMultiplexer().(*connMultiplexer).conns 350 Expect(len(conns)).To(BeZero()) 351 }) 352 353 It("allows receiving non-QUIC packets", func() { 354 remoteAddr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234} 355 packetChan := make(chan packetToRead) 356 tr := &Transport{ 357 Conn: newMockPacketConn(packetChan), 358 ConnectionIDLength: 10, 359 } 360 tr.init(true) 361 receivedPacketChan := make(chan []byte) 362 go func() { 363 defer GinkgoRecover() 364 b := make([]byte, 100) 365 n, addr, err := tr.ReadNonQUICPacket(context.Background(), b) 366 Expect(err).ToNot(HaveOccurred()) 367 Expect(addr).To(Equal(remoteAddr)) 368 receivedPacketChan <- b[:n] 369 }() 370 // Receiving of non-QUIC packets is enabled when ReadNonQUICPacket is called. 371 // Give the Go routine some time to spin up. 372 time.Sleep(scaleDuration(50 * time.Millisecond)) 373 packetChan <- packetToRead{ 374 addr: remoteAddr, 375 data: []byte{0 /* don't set the QUIC bit */, 1, 2, 3}, 376 } 377 378 Eventually(receivedPacketChan).Should(Receive(Equal([]byte{0, 1, 2, 3}))) 379 380 // shutdown 381 close(packetChan) 382 tr.Close() 383 }) 384 385 It("drops non-QUIC packet if the application doesn't process them quickly enough", func() { 386 remoteAddr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234} 387 packetChan := make(chan packetToRead) 388 t, tracer := mocklogging.NewMockTracer(mockCtrl) 389 tr := &Transport{ 390 Conn: newMockPacketConn(packetChan), 391 ConnectionIDLength: 10, 392 Tracer: t, 393 } 394 tr.init(true) 395 396 ctx, cancel := context.WithCancel(context.Background()) 397 cancel() 398 _, _, err := tr.ReadNonQUICPacket(ctx, make([]byte, 10)) 399 Expect(err).To(MatchError(context.Canceled)) 400 401 for i := 0; i < maxQueuedNonQUICPackets; i++ { 402 packetChan <- packetToRead{ 403 addr: remoteAddr, 404 data: []byte{0 /* don't set the QUIC bit */, 1, 2, 3}, 405 } 406 } 407 408 done := make(chan struct{}) 409 tracer.EXPECT().DroppedPacket(remoteAddr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { 410 close(done) 411 }) 412 packetChan <- packetToRead{ 413 addr: remoteAddr, 414 data: []byte{0 /* don't set the QUIC bit */, 1, 2, 3}, 415 } 416 Eventually(done).Should(BeClosed()) 417 418 // shutdown 419 tracer.EXPECT().Close() 420 close(packetChan) 421 tr.Close() 422 }) 423 424 remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 3, 5, 7), Port: 1234} 425 DescribeTable("setting the tls.Config.ServerName", 426 func(expected string, conf *tls.Config, addr net.Addr, host string) { 427 setTLSConfigServerName(conf, addr, host) 428 Expect(conf.ServerName).To(Equal(expected)) 429 }, 430 Entry("uses the value from the config", "foo.bar", &tls.Config{ServerName: "foo.bar"}, remoteAddr, "baz.foo"), 431 Entry("uses the hostname", "golang.org", &tls.Config{}, remoteAddr, "golang.org"), 432 Entry("removes the port from the hostname", "golang.org", &tls.Config{}, remoteAddr, "golang.org:1234"), 433 Entry("uses the IP", "1.3.5.7", &tls.Config{}, remoteAddr, ""), 434 ) 435 }) 436 437 type mockSyscallConn struct { 438 net.PacketConn 439 } 440 441 func (c *mockSyscallConn) SyscallConn() (syscall.RawConn, error) { 442 return nil, errors.New("mocked") 443 }