github.com/MerlinKodo/quic-go@v0.39.2/transport_test.go (about) 1 package quic 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/rand" 7 "crypto/tls" 8 "errors" 9 "net" 10 "syscall" 11 "time" 12 13 mocklogging "github.com/MerlinKodo/quic-go/internal/mocks/logging" 14 "github.com/MerlinKodo/quic-go/internal/protocol" 15 "github.com/MerlinKodo/quic-go/internal/wire" 16 "github.com/MerlinKodo/quic-go/logging" 17 18 . "github.com/onsi/ginkgo/v2" 19 . "github.com/onsi/gomega" 20 "go.uber.org/mock/gomock" 21 ) 22 23 var _ = Describe("Transport", func() { 24 type packetToRead struct { 25 addr net.Addr 26 data []byte 27 err error 28 } 29 30 getPacketWithPacketType := func(connID protocol.ConnectionID, t protocol.PacketType, length protocol.ByteCount) []byte { 31 b, err := (&wire.ExtendedHeader{ 32 Header: wire.Header{ 33 Type: t, 34 DestConnectionID: connID, 35 Length: length, 36 Version: protocol.Version1, 37 }, 38 PacketNumberLen: protocol.PacketNumberLen2, 39 }).Append(nil, protocol.Version1) 40 Expect(err).ToNot(HaveOccurred()) 41 return b 42 } 43 44 getPacket := func(connID protocol.ConnectionID) []byte { 45 return getPacketWithPacketType(connID, protocol.PacketTypeHandshake, 2) 46 } 47 48 newMockPacketConn := func(packetChan <-chan packetToRead) *MockPacketConn { 49 conn := NewMockPacketConn(mockCtrl) 50 conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes() 51 conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(b []byte) (int, net.Addr, error) { 52 p, ok := <-packetChan 53 if !ok { 54 return 0, nil, errors.New("closed") 55 } 56 return copy(b, p.data), p.addr, p.err 57 }).AnyTimes() 58 // for shutdown 59 conn.EXPECT().SetReadDeadline(gomock.Any()).AnyTimes() 60 return conn 61 } 62 63 It("handles packets for different packet handlers on the same packet conn", func() { 64 packetChan := make(chan packetToRead) 65 tr := &Transport{Conn: newMockPacketConn(packetChan)} 66 tr.init(true) 67 phm := NewMockPacketHandlerManager(mockCtrl) 68 tr.handlerMap = phm 69 connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) 70 connID2 := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}) 71 72 handled := make(chan struct{}, 2) 73 phm.EXPECT().Get(connID1).DoAndReturn(func(protocol.ConnectionID) (packetHandler, bool) { 74 h := NewMockPacketHandler(mockCtrl) 75 h.EXPECT().handlePacket(gomock.Any()).Do(func(p receivedPacket) { 76 defer GinkgoRecover() 77 connID, err := wire.ParseConnectionID(p.data, 0) 78 Expect(err).ToNot(HaveOccurred()) 79 Expect(connID).To(Equal(connID1)) 80 handled <- struct{}{} 81 }) 82 return h, true 83 }) 84 phm.EXPECT().Get(connID2).DoAndReturn(func(protocol.ConnectionID) (packetHandler, bool) { 85 h := NewMockPacketHandler(mockCtrl) 86 h.EXPECT().handlePacket(gomock.Any()).Do(func(p receivedPacket) { 87 defer GinkgoRecover() 88 connID, err := wire.ParseConnectionID(p.data, 0) 89 Expect(err).ToNot(HaveOccurred()) 90 Expect(connID).To(Equal(connID2)) 91 handled <- struct{}{} 92 }) 93 return h, true 94 }) 95 96 packetChan <- packetToRead{data: getPacket(connID1)} 97 packetChan <- packetToRead{data: getPacket(connID2)} 98 99 Eventually(handled).Should(Receive()) 100 Eventually(handled).Should(Receive()) 101 102 // shutdown 103 phm.EXPECT().Close(gomock.Any()) 104 close(packetChan) 105 tr.Close() 106 }) 107 108 It("closes listeners", func() { 109 packetChan := make(chan packetToRead) 110 tr := &Transport{Conn: newMockPacketConn(packetChan)} 111 defer tr.Close() 112 ln, err := tr.Listen(&tls.Config{}, nil) 113 Expect(err).ToNot(HaveOccurred()) 114 phm := NewMockPacketHandlerManager(mockCtrl) 115 tr.handlerMap = phm 116 117 phm.EXPECT().CloseServer() 118 Expect(ln.Close()).To(Succeed()) 119 120 // shutdown 121 phm.EXPECT().Close(gomock.Any()) 122 close(packetChan) 123 tr.Close() 124 }) 125 126 It("drops unparseable QUIC packets", func() { 127 addr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234} 128 packetChan := make(chan packetToRead) 129 t, tracer := mocklogging.NewMockTracer(mockCtrl) 130 tr := &Transport{ 131 Conn: newMockPacketConn(packetChan), 132 ConnectionIDLength: 10, 133 Tracer: t, 134 } 135 tr.init(true) 136 dropped := make(chan struct{}) 137 tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(dropped) }) 138 packetChan <- packetToRead{ 139 addr: addr, 140 data: []byte{0x40 /* set the QUIC bit */, 1, 2, 3}, 141 } 142 Eventually(dropped).Should(BeClosed()) 143 144 // shutdown 145 close(packetChan) 146 tr.Close() 147 }) 148 149 It("closes when reading from the conn fails", func() { 150 packetChan := make(chan packetToRead) 151 tr := Transport{Conn: newMockPacketConn(packetChan)} 152 defer tr.Close() 153 phm := NewMockPacketHandlerManager(mockCtrl) 154 tr.init(true) 155 tr.handlerMap = phm 156 157 done := make(chan struct{}) 158 phm.EXPECT().Close(gomock.Any()).Do(func(error) { close(done) }) 159 packetChan <- packetToRead{err: errors.New("read failed")} 160 Eventually(done).Should(BeClosed()) 161 162 // shutdown 163 close(packetChan) 164 tr.Close() 165 }) 166 167 It("continues listening after temporary errors", func() { 168 packetChan := make(chan packetToRead) 169 tr := Transport{Conn: newMockPacketConn(packetChan)} 170 defer tr.Close() 171 phm := NewMockPacketHandlerManager(mockCtrl) 172 tr.init(true) 173 tr.handlerMap = phm 174 175 tempErr := deadlineError{} 176 Expect(tempErr.Temporary()).To(BeTrue()) 177 packetChan <- packetToRead{err: tempErr} 178 // don't expect any calls to phm.Close 179 time.Sleep(50 * time.Millisecond) 180 181 // shutdown 182 phm.EXPECT().Close(gomock.Any()) 183 close(packetChan) 184 tr.Close() 185 }) 186 187 It("handles short header packets resets", func() { 188 connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5}) 189 packetChan := make(chan packetToRead) 190 tr := Transport{ 191 Conn: newMockPacketConn(packetChan), 192 ConnectionIDLength: connID.Len(), 193 } 194 tr.init(true) 195 defer tr.Close() 196 phm := NewMockPacketHandlerManager(mockCtrl) 197 tr.handlerMap = phm 198 199 var token protocol.StatelessResetToken 200 rand.Read(token[:]) 201 202 var b []byte 203 b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne) 204 Expect(err).ToNot(HaveOccurred()) 205 b = append(b, token[:]...) 206 conn := NewMockPacketHandler(mockCtrl) 207 gomock.InOrder( 208 phm.EXPECT().GetByResetToken(token), 209 phm.EXPECT().Get(connID).Return(conn, true), 210 conn.EXPECT().handlePacket(gomock.Any()).Do(func(p receivedPacket) { 211 Expect(p.data).To(Equal(b)) 212 Expect(p.rcvTime).To(BeTemporally("~", time.Now(), time.Second)) 213 }), 214 ) 215 packetChan <- packetToRead{data: b} 216 217 // shutdown 218 phm.EXPECT().Close(gomock.Any()) 219 close(packetChan) 220 tr.Close() 221 }) 222 223 It("handles stateless resets", func() { 224 connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5}) 225 packetChan := make(chan packetToRead) 226 tr := Transport{Conn: newMockPacketConn(packetChan)} 227 tr.init(true) 228 defer tr.Close() 229 phm := NewMockPacketHandlerManager(mockCtrl) 230 tr.handlerMap = phm 231 232 var token protocol.StatelessResetToken 233 rand.Read(token[:]) 234 235 var b []byte 236 b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne) 237 Expect(err).ToNot(HaveOccurred()) 238 b = append(b, token[:]...) 239 conn := NewMockPacketHandler(mockCtrl) 240 destroyed := make(chan struct{}) 241 gomock.InOrder( 242 phm.EXPECT().GetByResetToken(token).Return(conn, true), 243 conn.EXPECT().destroy(gomock.Any()).Do(func(err error) { 244 Expect(err).To(MatchError(&StatelessResetError{Token: token})) 245 close(destroyed) 246 }), 247 ) 248 packetChan <- packetToRead{data: b} 249 Eventually(destroyed).Should(BeClosed()) 250 251 // shutdown 252 phm.EXPECT().Close(gomock.Any()) 253 close(packetChan) 254 tr.Close() 255 }) 256 257 It("sends stateless resets", func() { 258 connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5}) 259 packetChan := make(chan packetToRead) 260 conn := newMockPacketConn(packetChan) 261 tr := Transport{ 262 Conn: conn, 263 StatelessResetKey: &StatelessResetKey{1, 2, 3, 4}, 264 ConnectionIDLength: connID.Len(), 265 } 266 tr.init(true) 267 defer tr.Close() 268 phm := NewMockPacketHandlerManager(mockCtrl) 269 tr.handlerMap = phm 270 271 var b []byte 272 b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne) 273 Expect(err).ToNot(HaveOccurred()) 274 b = append(b, make([]byte, protocol.MinStatelessResetSize-len(b)+1)...) 275 276 var token protocol.StatelessResetToken 277 rand.Read(token[:]) 278 written := make(chan struct{}) 279 gomock.InOrder( 280 phm.EXPECT().GetByResetToken(gomock.Any()), 281 phm.EXPECT().Get(connID), 282 phm.EXPECT().GetStatelessResetToken(connID).Return(token), 283 conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).Do(func(b []byte, _ net.Addr) { 284 defer close(written) 285 Expect(bytes.Contains(b, token[:])).To(BeTrue()) 286 }), 287 ) 288 packetChan <- packetToRead{data: b} 289 Eventually(written).Should(BeClosed()) 290 291 // shutdown 292 phm.EXPECT().Close(gomock.Any()) 293 close(packetChan) 294 tr.Close() 295 }) 296 297 It("closes uninitialized Transport and closes underlying PacketConn", func() { 298 packetChan := make(chan packetToRead) 299 pconn := newMockPacketConn(packetChan) 300 301 tr := &Transport{ 302 Conn: pconn, 303 createdConn: true, // owns pconn 304 } 305 // NO init 306 307 // shutdown 308 close(packetChan) 309 pconn.EXPECT().Close() 310 Expect(tr.Close()).To(Succeed()) 311 }) 312 313 It("doesn't add the PacketConn to the multiplexer if (*Transport).init fails", func() { 314 packetChan := make(chan packetToRead) 315 pconn := newMockPacketConn(packetChan) 316 syscallconn := &mockSyscallConn{pconn} 317 318 tr := &Transport{ 319 Conn: syscallconn, 320 } 321 322 err := tr.init(false) 323 Expect(err).To(HaveOccurred()) 324 conns := getMultiplexer().(*connMultiplexer).conns 325 Expect(len(conns)).To(BeZero()) 326 }) 327 328 It("allows receiving non-QUIC packets", func() { 329 remoteAddr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234} 330 packetChan := make(chan packetToRead) 331 tr := &Transport{ 332 Conn: newMockPacketConn(packetChan), 333 ConnectionIDLength: 10, 334 } 335 tr.init(true) 336 receivedPacketChan := make(chan []byte) 337 go func() { 338 defer GinkgoRecover() 339 b := make([]byte, 100) 340 n, addr, err := tr.ReadNonQUICPacket(context.Background(), b) 341 Expect(err).ToNot(HaveOccurred()) 342 Expect(addr).To(Equal(remoteAddr)) 343 receivedPacketChan <- b[:n] 344 }() 345 // Receiving of non-QUIC packets is enabled when ReadNonQUICPacket is called. 346 // Give the Go routine some time to spin up. 347 time.Sleep(scaleDuration(50 * time.Millisecond)) 348 packetChan <- packetToRead{ 349 addr: remoteAddr, 350 data: []byte{0 /* don't set the QUIC bit */, 1, 2, 3}, 351 } 352 353 Eventually(receivedPacketChan).Should(Receive(Equal([]byte{0, 1, 2, 3}))) 354 355 // shutdown 356 close(packetChan) 357 tr.Close() 358 }) 359 360 It("drops non-QUIC packet if the application doesn't process them quickly enough", func() { 361 remoteAddr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234} 362 packetChan := make(chan packetToRead) 363 t, tracer := mocklogging.NewMockTracer(mockCtrl) 364 tr := &Transport{ 365 Conn: newMockPacketConn(packetChan), 366 ConnectionIDLength: 10, 367 Tracer: t, 368 } 369 tr.init(true) 370 371 ctx, cancel := context.WithCancel(context.Background()) 372 cancel() 373 _, _, err := tr.ReadNonQUICPacket(ctx, make([]byte, 10)) 374 Expect(err).To(MatchError(context.Canceled)) 375 376 for i := 0; i < maxQueuedNonQUICPackets; i++ { 377 packetChan <- packetToRead{ 378 addr: remoteAddr, 379 data: []byte{0 /* don't set the QUIC bit */, 1, 2, 3}, 380 } 381 } 382 383 done := make(chan struct{}) 384 tracer.EXPECT().DroppedPacket(remoteAddr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { 385 close(done) 386 }) 387 packetChan <- packetToRead{ 388 addr: remoteAddr, 389 data: []byte{0 /* don't set the QUIC bit */, 1, 2, 3}, 390 } 391 Eventually(done).Should(BeClosed()) 392 393 // shutdown 394 close(packetChan) 395 tr.Close() 396 }) 397 398 remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 3, 5, 7), Port: 1234} 399 DescribeTable("setting the tls.Config.ServerName", 400 func(expected string, conf *tls.Config, addr net.Addr, host string) { 401 setTLSConfigServerName(conf, addr, host) 402 Expect(conf.ServerName).To(Equal(expected)) 403 }, 404 Entry("uses the value from the config", "foo.bar", &tls.Config{ServerName: "foo.bar"}, remoteAddr, "baz.foo"), 405 Entry("uses the hostname", "golang.org", &tls.Config{}, remoteAddr, "golang.org"), 406 Entry("removes the port from the hostname", "golang.org", &tls.Config{}, remoteAddr, "golang.org:1234"), 407 Entry("uses the IP", "1.3.5.7", &tls.Config{}, remoteAddr, ""), 408 ) 409 }) 410 411 type mockSyscallConn struct { 412 net.PacketConn 413 } 414 415 func (c *mockSyscallConn) SyscallConn() (syscall.RawConn, error) { 416 return nil, errors.New("mocked") 417 }