github.com/MerlinKodo/quic-go@v0.39.2/transport_test.go (about)

     1  package quic
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/rand"
     7  	"crypto/tls"
     8  	"errors"
     9  	"net"
    10  	"syscall"
    11  	"time"
    12  
    13  	mocklogging "github.com/MerlinKodo/quic-go/internal/mocks/logging"
    14  	"github.com/MerlinKodo/quic-go/internal/protocol"
    15  	"github.com/MerlinKodo/quic-go/internal/wire"
    16  	"github.com/MerlinKodo/quic-go/logging"
    17  
    18  	. "github.com/onsi/ginkgo/v2"
    19  	. "github.com/onsi/gomega"
    20  	"go.uber.org/mock/gomock"
    21  )
    22  
    23  var _ = Describe("Transport", func() {
    24  	type packetToRead struct {
    25  		addr net.Addr
    26  		data []byte
    27  		err  error
    28  	}
    29  
    30  	getPacketWithPacketType := func(connID protocol.ConnectionID, t protocol.PacketType, length protocol.ByteCount) []byte {
    31  		b, err := (&wire.ExtendedHeader{
    32  			Header: wire.Header{
    33  				Type:             t,
    34  				DestConnectionID: connID,
    35  				Length:           length,
    36  				Version:          protocol.Version1,
    37  			},
    38  			PacketNumberLen: protocol.PacketNumberLen2,
    39  		}).Append(nil, protocol.Version1)
    40  		Expect(err).ToNot(HaveOccurred())
    41  		return b
    42  	}
    43  
    44  	getPacket := func(connID protocol.ConnectionID) []byte {
    45  		return getPacketWithPacketType(connID, protocol.PacketTypeHandshake, 2)
    46  	}
    47  
    48  	newMockPacketConn := func(packetChan <-chan packetToRead) *MockPacketConn {
    49  		conn := NewMockPacketConn(mockCtrl)
    50  		conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
    51  		conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(b []byte) (int, net.Addr, error) {
    52  			p, ok := <-packetChan
    53  			if !ok {
    54  				return 0, nil, errors.New("closed")
    55  			}
    56  			return copy(b, p.data), p.addr, p.err
    57  		}).AnyTimes()
    58  		// for shutdown
    59  		conn.EXPECT().SetReadDeadline(gomock.Any()).AnyTimes()
    60  		return conn
    61  	}
    62  
    63  	It("handles packets for different packet handlers on the same packet conn", func() {
    64  		packetChan := make(chan packetToRead)
    65  		tr := &Transport{Conn: newMockPacketConn(packetChan)}
    66  		tr.init(true)
    67  		phm := NewMockPacketHandlerManager(mockCtrl)
    68  		tr.handlerMap = phm
    69  		connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
    70  		connID2 := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1})
    71  
    72  		handled := make(chan struct{}, 2)
    73  		phm.EXPECT().Get(connID1).DoAndReturn(func(protocol.ConnectionID) (packetHandler, bool) {
    74  			h := NewMockPacketHandler(mockCtrl)
    75  			h.EXPECT().handlePacket(gomock.Any()).Do(func(p receivedPacket) {
    76  				defer GinkgoRecover()
    77  				connID, err := wire.ParseConnectionID(p.data, 0)
    78  				Expect(err).ToNot(HaveOccurred())
    79  				Expect(connID).To(Equal(connID1))
    80  				handled <- struct{}{}
    81  			})
    82  			return h, true
    83  		})
    84  		phm.EXPECT().Get(connID2).DoAndReturn(func(protocol.ConnectionID) (packetHandler, bool) {
    85  			h := NewMockPacketHandler(mockCtrl)
    86  			h.EXPECT().handlePacket(gomock.Any()).Do(func(p receivedPacket) {
    87  				defer GinkgoRecover()
    88  				connID, err := wire.ParseConnectionID(p.data, 0)
    89  				Expect(err).ToNot(HaveOccurred())
    90  				Expect(connID).To(Equal(connID2))
    91  				handled <- struct{}{}
    92  			})
    93  			return h, true
    94  		})
    95  
    96  		packetChan <- packetToRead{data: getPacket(connID1)}
    97  		packetChan <- packetToRead{data: getPacket(connID2)}
    98  
    99  		Eventually(handled).Should(Receive())
   100  		Eventually(handled).Should(Receive())
   101  
   102  		// shutdown
   103  		phm.EXPECT().Close(gomock.Any())
   104  		close(packetChan)
   105  		tr.Close()
   106  	})
   107  
   108  	It("closes listeners", func() {
   109  		packetChan := make(chan packetToRead)
   110  		tr := &Transport{Conn: newMockPacketConn(packetChan)}
   111  		defer tr.Close()
   112  		ln, err := tr.Listen(&tls.Config{}, nil)
   113  		Expect(err).ToNot(HaveOccurred())
   114  		phm := NewMockPacketHandlerManager(mockCtrl)
   115  		tr.handlerMap = phm
   116  
   117  		phm.EXPECT().CloseServer()
   118  		Expect(ln.Close()).To(Succeed())
   119  
   120  		// shutdown
   121  		phm.EXPECT().Close(gomock.Any())
   122  		close(packetChan)
   123  		tr.Close()
   124  	})
   125  
   126  	It("drops unparseable QUIC packets", func() {
   127  		addr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234}
   128  		packetChan := make(chan packetToRead)
   129  		t, tracer := mocklogging.NewMockTracer(mockCtrl)
   130  		tr := &Transport{
   131  			Conn:               newMockPacketConn(packetChan),
   132  			ConnectionIDLength: 10,
   133  			Tracer:             t,
   134  		}
   135  		tr.init(true)
   136  		dropped := make(chan struct{})
   137  		tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(dropped) })
   138  		packetChan <- packetToRead{
   139  			addr: addr,
   140  			data: []byte{0x40 /* set the QUIC bit */, 1, 2, 3},
   141  		}
   142  		Eventually(dropped).Should(BeClosed())
   143  
   144  		// shutdown
   145  		close(packetChan)
   146  		tr.Close()
   147  	})
   148  
   149  	It("closes when reading from the conn fails", func() {
   150  		packetChan := make(chan packetToRead)
   151  		tr := Transport{Conn: newMockPacketConn(packetChan)}
   152  		defer tr.Close()
   153  		phm := NewMockPacketHandlerManager(mockCtrl)
   154  		tr.init(true)
   155  		tr.handlerMap = phm
   156  
   157  		done := make(chan struct{})
   158  		phm.EXPECT().Close(gomock.Any()).Do(func(error) { close(done) })
   159  		packetChan <- packetToRead{err: errors.New("read failed")}
   160  		Eventually(done).Should(BeClosed())
   161  
   162  		// shutdown
   163  		close(packetChan)
   164  		tr.Close()
   165  	})
   166  
   167  	It("continues listening after temporary errors", func() {
   168  		packetChan := make(chan packetToRead)
   169  		tr := Transport{Conn: newMockPacketConn(packetChan)}
   170  		defer tr.Close()
   171  		phm := NewMockPacketHandlerManager(mockCtrl)
   172  		tr.init(true)
   173  		tr.handlerMap = phm
   174  
   175  		tempErr := deadlineError{}
   176  		Expect(tempErr.Temporary()).To(BeTrue())
   177  		packetChan <- packetToRead{err: tempErr}
   178  		// don't expect any calls to phm.Close
   179  		time.Sleep(50 * time.Millisecond)
   180  
   181  		// shutdown
   182  		phm.EXPECT().Close(gomock.Any())
   183  		close(packetChan)
   184  		tr.Close()
   185  	})
   186  
   187  	It("handles short header packets resets", func() {
   188  		connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5})
   189  		packetChan := make(chan packetToRead)
   190  		tr := Transport{
   191  			Conn:               newMockPacketConn(packetChan),
   192  			ConnectionIDLength: connID.Len(),
   193  		}
   194  		tr.init(true)
   195  		defer tr.Close()
   196  		phm := NewMockPacketHandlerManager(mockCtrl)
   197  		tr.handlerMap = phm
   198  
   199  		var token protocol.StatelessResetToken
   200  		rand.Read(token[:])
   201  
   202  		var b []byte
   203  		b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne)
   204  		Expect(err).ToNot(HaveOccurred())
   205  		b = append(b, token[:]...)
   206  		conn := NewMockPacketHandler(mockCtrl)
   207  		gomock.InOrder(
   208  			phm.EXPECT().GetByResetToken(token),
   209  			phm.EXPECT().Get(connID).Return(conn, true),
   210  			conn.EXPECT().handlePacket(gomock.Any()).Do(func(p receivedPacket) {
   211  				Expect(p.data).To(Equal(b))
   212  				Expect(p.rcvTime).To(BeTemporally("~", time.Now(), time.Second))
   213  			}),
   214  		)
   215  		packetChan <- packetToRead{data: b}
   216  
   217  		// shutdown
   218  		phm.EXPECT().Close(gomock.Any())
   219  		close(packetChan)
   220  		tr.Close()
   221  	})
   222  
   223  	It("handles stateless resets", func() {
   224  		connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5})
   225  		packetChan := make(chan packetToRead)
   226  		tr := Transport{Conn: newMockPacketConn(packetChan)}
   227  		tr.init(true)
   228  		defer tr.Close()
   229  		phm := NewMockPacketHandlerManager(mockCtrl)
   230  		tr.handlerMap = phm
   231  
   232  		var token protocol.StatelessResetToken
   233  		rand.Read(token[:])
   234  
   235  		var b []byte
   236  		b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne)
   237  		Expect(err).ToNot(HaveOccurred())
   238  		b = append(b, token[:]...)
   239  		conn := NewMockPacketHandler(mockCtrl)
   240  		destroyed := make(chan struct{})
   241  		gomock.InOrder(
   242  			phm.EXPECT().GetByResetToken(token).Return(conn, true),
   243  			conn.EXPECT().destroy(gomock.Any()).Do(func(err error) {
   244  				Expect(err).To(MatchError(&StatelessResetError{Token: token}))
   245  				close(destroyed)
   246  			}),
   247  		)
   248  		packetChan <- packetToRead{data: b}
   249  		Eventually(destroyed).Should(BeClosed())
   250  
   251  		// shutdown
   252  		phm.EXPECT().Close(gomock.Any())
   253  		close(packetChan)
   254  		tr.Close()
   255  	})
   256  
   257  	It("sends stateless resets", func() {
   258  		connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5})
   259  		packetChan := make(chan packetToRead)
   260  		conn := newMockPacketConn(packetChan)
   261  		tr := Transport{
   262  			Conn:               conn,
   263  			StatelessResetKey:  &StatelessResetKey{1, 2, 3, 4},
   264  			ConnectionIDLength: connID.Len(),
   265  		}
   266  		tr.init(true)
   267  		defer tr.Close()
   268  		phm := NewMockPacketHandlerManager(mockCtrl)
   269  		tr.handlerMap = phm
   270  
   271  		var b []byte
   272  		b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne)
   273  		Expect(err).ToNot(HaveOccurred())
   274  		b = append(b, make([]byte, protocol.MinStatelessResetSize-len(b)+1)...)
   275  
   276  		var token protocol.StatelessResetToken
   277  		rand.Read(token[:])
   278  		written := make(chan struct{})
   279  		gomock.InOrder(
   280  			phm.EXPECT().GetByResetToken(gomock.Any()),
   281  			phm.EXPECT().Get(connID),
   282  			phm.EXPECT().GetStatelessResetToken(connID).Return(token),
   283  			conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).Do(func(b []byte, _ net.Addr) {
   284  				defer close(written)
   285  				Expect(bytes.Contains(b, token[:])).To(BeTrue())
   286  			}),
   287  		)
   288  		packetChan <- packetToRead{data: b}
   289  		Eventually(written).Should(BeClosed())
   290  
   291  		// shutdown
   292  		phm.EXPECT().Close(gomock.Any())
   293  		close(packetChan)
   294  		tr.Close()
   295  	})
   296  
   297  	It("closes uninitialized Transport and closes underlying PacketConn", func() {
   298  		packetChan := make(chan packetToRead)
   299  		pconn := newMockPacketConn(packetChan)
   300  
   301  		tr := &Transport{
   302  			Conn:        pconn,
   303  			createdConn: true, // owns pconn
   304  		}
   305  		// NO init
   306  
   307  		// shutdown
   308  		close(packetChan)
   309  		pconn.EXPECT().Close()
   310  		Expect(tr.Close()).To(Succeed())
   311  	})
   312  
   313  	It("doesn't add the PacketConn to the multiplexer if (*Transport).init fails", func() {
   314  		packetChan := make(chan packetToRead)
   315  		pconn := newMockPacketConn(packetChan)
   316  		syscallconn := &mockSyscallConn{pconn}
   317  
   318  		tr := &Transport{
   319  			Conn: syscallconn,
   320  		}
   321  
   322  		err := tr.init(false)
   323  		Expect(err).To(HaveOccurred())
   324  		conns := getMultiplexer().(*connMultiplexer).conns
   325  		Expect(len(conns)).To(BeZero())
   326  	})
   327  
   328  	It("allows receiving non-QUIC packets", func() {
   329  		remoteAddr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234}
   330  		packetChan := make(chan packetToRead)
   331  		tr := &Transport{
   332  			Conn:               newMockPacketConn(packetChan),
   333  			ConnectionIDLength: 10,
   334  		}
   335  		tr.init(true)
   336  		receivedPacketChan := make(chan []byte)
   337  		go func() {
   338  			defer GinkgoRecover()
   339  			b := make([]byte, 100)
   340  			n, addr, err := tr.ReadNonQUICPacket(context.Background(), b)
   341  			Expect(err).ToNot(HaveOccurred())
   342  			Expect(addr).To(Equal(remoteAddr))
   343  			receivedPacketChan <- b[:n]
   344  		}()
   345  		// Receiving of non-QUIC packets is enabled when ReadNonQUICPacket is called.
   346  		// Give the Go routine some time to spin up.
   347  		time.Sleep(scaleDuration(50 * time.Millisecond))
   348  		packetChan <- packetToRead{
   349  			addr: remoteAddr,
   350  			data: []byte{0 /* don't set the QUIC bit */, 1, 2, 3},
   351  		}
   352  
   353  		Eventually(receivedPacketChan).Should(Receive(Equal([]byte{0, 1, 2, 3})))
   354  
   355  		// shutdown
   356  		close(packetChan)
   357  		tr.Close()
   358  	})
   359  
   360  	It("drops non-QUIC packet if the application doesn't process them quickly enough", func() {
   361  		remoteAddr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234}
   362  		packetChan := make(chan packetToRead)
   363  		t, tracer := mocklogging.NewMockTracer(mockCtrl)
   364  		tr := &Transport{
   365  			Conn:               newMockPacketConn(packetChan),
   366  			ConnectionIDLength: 10,
   367  			Tracer:             t,
   368  		}
   369  		tr.init(true)
   370  
   371  		ctx, cancel := context.WithCancel(context.Background())
   372  		cancel()
   373  		_, _, err := tr.ReadNonQUICPacket(ctx, make([]byte, 10))
   374  		Expect(err).To(MatchError(context.Canceled))
   375  
   376  		for i := 0; i < maxQueuedNonQUICPackets; i++ {
   377  			packetChan <- packetToRead{
   378  				addr: remoteAddr,
   379  				data: []byte{0 /* don't set the QUIC bit */, 1, 2, 3},
   380  			}
   381  		}
   382  
   383  		done := make(chan struct{})
   384  		tracer.EXPECT().DroppedPacket(remoteAddr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
   385  			close(done)
   386  		})
   387  		packetChan <- packetToRead{
   388  			addr: remoteAddr,
   389  			data: []byte{0 /* don't set the QUIC bit */, 1, 2, 3},
   390  		}
   391  		Eventually(done).Should(BeClosed())
   392  
   393  		// shutdown
   394  		close(packetChan)
   395  		tr.Close()
   396  	})
   397  
   398  	remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 3, 5, 7), Port: 1234}
   399  	DescribeTable("setting the tls.Config.ServerName",
   400  		func(expected string, conf *tls.Config, addr net.Addr, host string) {
   401  			setTLSConfigServerName(conf, addr, host)
   402  			Expect(conf.ServerName).To(Equal(expected))
   403  		},
   404  		Entry("uses the value from the config", "foo.bar", &tls.Config{ServerName: "foo.bar"}, remoteAddr, "baz.foo"),
   405  		Entry("uses the hostname", "golang.org", &tls.Config{}, remoteAddr, "golang.org"),
   406  		Entry("removes the port from the hostname", "golang.org", &tls.Config{}, remoteAddr, "golang.org:1234"),
   407  		Entry("uses the IP", "1.3.5.7", &tls.Config{}, remoteAddr, ""),
   408  	)
   409  })
   410  
   411  type mockSyscallConn struct {
   412  	net.PacketConn
   413  }
   414  
   415  func (c *mockSyscallConn) SyscallConn() (syscall.RawConn, error) {
   416  	return nil, errors.New("mocked")
   417  }