github.com/danielpfeifer02/quic-go-prio-packs@v0.41.0-28/transport_test.go (about)

     1  package quic
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/rand"
     7  	"crypto/tls"
     8  	"errors"
     9  	"net"
    10  	"syscall"
    11  	"time"
    12  
    13  	mocklogging "github.com/danielpfeifer02/quic-go-prio-packs/internal/mocks/logging"
    14  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/protocol"
    15  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/wire"
    16  	"github.com/danielpfeifer02/quic-go-prio-packs/logging"
    17  
    18  	. "github.com/onsi/ginkgo/v2"
    19  	. "github.com/onsi/gomega"
    20  	"go.uber.org/mock/gomock"
    21  )
    22  
    23  var _ = Describe("Transport", func() {
    24  	type packetToRead struct {
    25  		addr net.Addr
    26  		data []byte
    27  		err  error
    28  	}
    29  
    30  	getPacketWithPacketType := func(connID protocol.ConnectionID, t protocol.PacketType, length protocol.ByteCount) []byte {
    31  		b, err := (&wire.ExtendedHeader{
    32  			Header: wire.Header{
    33  				Type:             t,
    34  				DestConnectionID: connID,
    35  				Length:           length,
    36  				Version:          protocol.Version1,
    37  			},
    38  			PacketNumberLen: protocol.PacketNumberLen2,
    39  		}).Append(nil, protocol.Version1)
    40  		Expect(err).ToNot(HaveOccurred())
    41  		return b
    42  	}
    43  
    44  	getPacket := func(connID protocol.ConnectionID) []byte {
    45  		return getPacketWithPacketType(connID, protocol.PacketTypeHandshake, 2)
    46  	}
    47  
    48  	newMockPacketConn := func(packetChan <-chan packetToRead) *MockPacketConn {
    49  		conn := NewMockPacketConn(mockCtrl)
    50  		conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
    51  		conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(b []byte) (int, net.Addr, error) {
    52  			p, ok := <-packetChan
    53  			if !ok {
    54  				return 0, nil, errors.New("closed")
    55  			}
    56  			return copy(b, p.data), p.addr, p.err
    57  		}).AnyTimes()
    58  		// for shutdown
    59  		conn.EXPECT().SetReadDeadline(gomock.Any()).AnyTimes()
    60  		return conn
    61  	}
    62  
    63  	It("handles packets for different packet handlers on the same packet conn", func() {
    64  		packetChan := make(chan packetToRead)
    65  		tr := &Transport{Conn: newMockPacketConn(packetChan)}
    66  		tr.init(true)
    67  		phm := NewMockPacketHandlerManager(mockCtrl)
    68  		tr.handlerMap = phm
    69  		connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
    70  		connID2 := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1})
    71  
    72  		handled := make(chan struct{}, 2)
    73  		phm.EXPECT().Get(connID1).DoAndReturn(func(protocol.ConnectionID) (packetHandler, bool) {
    74  			h := NewMockPacketHandler(mockCtrl)
    75  			h.EXPECT().handlePacket(gomock.Any()).Do(func(p receivedPacket) {
    76  				defer GinkgoRecover()
    77  				connID, err := wire.ParseConnectionID(p.data, 0)
    78  				Expect(err).ToNot(HaveOccurred())
    79  				Expect(connID).To(Equal(connID1))
    80  				handled <- struct{}{}
    81  			})
    82  			return h, true
    83  		})
    84  		phm.EXPECT().Get(connID2).DoAndReturn(func(protocol.ConnectionID) (packetHandler, bool) {
    85  			h := NewMockPacketHandler(mockCtrl)
    86  			h.EXPECT().handlePacket(gomock.Any()).Do(func(p receivedPacket) {
    87  				defer GinkgoRecover()
    88  				connID, err := wire.ParseConnectionID(p.data, 0)
    89  				Expect(err).ToNot(HaveOccurred())
    90  				Expect(connID).To(Equal(connID2))
    91  				handled <- struct{}{}
    92  			})
    93  			return h, true
    94  		})
    95  
    96  		packetChan <- packetToRead{data: getPacket(connID1)}
    97  		packetChan <- packetToRead{data: getPacket(connID2)}
    98  
    99  		Eventually(handled).Should(Receive())
   100  		Eventually(handled).Should(Receive())
   101  
   102  		// shutdown
   103  		phm.EXPECT().Close(gomock.Any())
   104  		close(packetChan)
   105  		tr.Close()
   106  	})
   107  
   108  	It("closes listeners", func() {
   109  		packetChan := make(chan packetToRead)
   110  		tr := &Transport{Conn: newMockPacketConn(packetChan)}
   111  		defer tr.Close()
   112  		ln, err := tr.Listen(&tls.Config{}, nil)
   113  		Expect(err).ToNot(HaveOccurred())
   114  		phm := NewMockPacketHandlerManager(mockCtrl)
   115  		tr.handlerMap = phm
   116  
   117  		Expect(ln.Close()).To(Succeed())
   118  
   119  		// shutdown
   120  		phm.EXPECT().Close(gomock.Any())
   121  		close(packetChan)
   122  		tr.Close()
   123  	})
   124  
   125  	It("closes transport concurrently with listener", func() {
   126  		// try 10 times to trigger race conditions
   127  		for i := 0; i < 10; i++ {
   128  			packetChan := make(chan packetToRead)
   129  			tr := &Transport{Conn: newMockPacketConn(packetChan)}
   130  			ln, err := tr.Listen(&tls.Config{}, nil)
   131  			Expect(err).ToNot(HaveOccurred())
   132  			ch := make(chan bool)
   133  			// Close transport and listener concurrently.
   134  			go func() {
   135  				ch <- true
   136  				Expect(ln.Close()).To(Succeed())
   137  				ch <- true
   138  			}()
   139  			<-ch
   140  			close(packetChan)
   141  			Expect(tr.Close()).To(Succeed())
   142  			<-ch
   143  		}
   144  	})
   145  
   146  	It("drops unparseable QUIC packets", func() {
   147  		addr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234}
   148  		packetChan := make(chan packetToRead)
   149  		t, tracer := mocklogging.NewMockTracer(mockCtrl)
   150  		tr := &Transport{
   151  			Conn:               newMockPacketConn(packetChan),
   152  			ConnectionIDLength: 10,
   153  			Tracer:             t,
   154  		}
   155  		tr.init(true)
   156  		dropped := make(chan struct{})
   157  		tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(dropped) })
   158  		packetChan <- packetToRead{
   159  			addr: addr,
   160  			data: []byte{0x40 /* set the QUIC bit */, 1, 2, 3},
   161  		}
   162  		Eventually(dropped).Should(BeClosed())
   163  
   164  		// shutdown
   165  		tracer.EXPECT().Close()
   166  		close(packetChan)
   167  		tr.Close()
   168  	})
   169  
   170  	It("closes when reading from the conn fails", func() {
   171  		packetChan := make(chan packetToRead)
   172  		tr := Transport{Conn: newMockPacketConn(packetChan)}
   173  		defer tr.Close()
   174  		phm := NewMockPacketHandlerManager(mockCtrl)
   175  		tr.init(true)
   176  		tr.handlerMap = phm
   177  
   178  		done := make(chan struct{})
   179  		phm.EXPECT().Close(gomock.Any()).Do(func(error) { close(done) })
   180  		packetChan <- packetToRead{err: errors.New("read failed")}
   181  		Eventually(done).Should(BeClosed())
   182  
   183  		// shutdown
   184  		close(packetChan)
   185  		tr.Close()
   186  	})
   187  
   188  	It("continues listening after temporary errors", func() {
   189  		packetChan := make(chan packetToRead)
   190  		tr := Transport{Conn: newMockPacketConn(packetChan)}
   191  		defer tr.Close()
   192  		phm := NewMockPacketHandlerManager(mockCtrl)
   193  		tr.init(true)
   194  		tr.handlerMap = phm
   195  
   196  		tempErr := deadlineError{}
   197  		Expect(tempErr.Temporary()).To(BeTrue())
   198  		packetChan <- packetToRead{err: tempErr}
   199  		// don't expect any calls to phm.Close
   200  		time.Sleep(50 * time.Millisecond)
   201  
   202  		// shutdown
   203  		phm.EXPECT().Close(gomock.Any())
   204  		close(packetChan)
   205  		tr.Close()
   206  	})
   207  
   208  	It("handles short header packets resets", func() {
   209  		connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5})
   210  		packetChan := make(chan packetToRead)
   211  		tr := Transport{
   212  			Conn:               newMockPacketConn(packetChan),
   213  			ConnectionIDLength: connID.Len(),
   214  		}
   215  		tr.init(true)
   216  		defer tr.Close()
   217  		phm := NewMockPacketHandlerManager(mockCtrl)
   218  		tr.handlerMap = phm
   219  
   220  		var token protocol.StatelessResetToken
   221  		rand.Read(token[:])
   222  
   223  		var b []byte
   224  		b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne)
   225  		Expect(err).ToNot(HaveOccurred())
   226  		b = append(b, token[:]...)
   227  		conn := NewMockPacketHandler(mockCtrl)
   228  		gomock.InOrder(
   229  			phm.EXPECT().Get(connID).Return(conn, true),
   230  			conn.EXPECT().handlePacket(gomock.Any()).Do(func(p receivedPacket) {
   231  				Expect(p.data).To(Equal(b))
   232  				Expect(p.rcvTime).To(BeTemporally("~", time.Now(), time.Second))
   233  			}),
   234  		)
   235  		packetChan <- packetToRead{data: b}
   236  
   237  		// shutdown
   238  		phm.EXPECT().Close(gomock.Any())
   239  		close(packetChan)
   240  		tr.Close()
   241  	})
   242  
   243  	It("handles stateless resets", func() {
   244  		connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5})
   245  		packetChan := make(chan packetToRead)
   246  		tr := Transport{
   247  			Conn:               newMockPacketConn(packetChan),
   248  			ConnectionIDLength: connID.Len(),
   249  		}
   250  		tr.init(true)
   251  		defer tr.Close()
   252  		phm := NewMockPacketHandlerManager(mockCtrl)
   253  		tr.handlerMap = phm
   254  
   255  		var token protocol.StatelessResetToken
   256  		rand.Read(token[:])
   257  
   258  		var b []byte
   259  		b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne)
   260  		Expect(err).ToNot(HaveOccurred())
   261  		b = append(b, token[:]...)
   262  		conn := NewMockPacketHandler(mockCtrl)
   263  		destroyed := make(chan struct{})
   264  		gomock.InOrder(
   265  			phm.EXPECT().Get(connID),
   266  			phm.EXPECT().GetByResetToken(token).Return(conn, true),
   267  			conn.EXPECT().destroy(gomock.Any()).Do(func(err error) {
   268  				Expect(err).To(MatchError(&StatelessResetError{Token: token}))
   269  				close(destroyed)
   270  			}),
   271  		)
   272  		packetChan <- packetToRead{data: b}
   273  		Eventually(destroyed).Should(BeClosed())
   274  
   275  		// shutdown
   276  		phm.EXPECT().Close(gomock.Any())
   277  		close(packetChan)
   278  		tr.Close()
   279  	})
   280  
   281  	It("sends stateless resets", func() {
   282  		connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5})
   283  		packetChan := make(chan packetToRead)
   284  		conn := newMockPacketConn(packetChan)
   285  		tr := Transport{
   286  			Conn:               conn,
   287  			StatelessResetKey:  &StatelessResetKey{1, 2, 3, 4},
   288  			ConnectionIDLength: connID.Len(),
   289  		}
   290  		tr.init(true)
   291  		defer tr.Close()
   292  		phm := NewMockPacketHandlerManager(mockCtrl)
   293  		tr.handlerMap = phm
   294  
   295  		var b []byte
   296  		b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne)
   297  		Expect(err).ToNot(HaveOccurred())
   298  		b = append(b, make([]byte, protocol.MinStatelessResetSize-len(b)+1)...)
   299  
   300  		var token protocol.StatelessResetToken
   301  		rand.Read(token[:])
   302  		written := make(chan struct{})
   303  		gomock.InOrder(
   304  			phm.EXPECT().Get(connID),
   305  			phm.EXPECT().GetByResetToken(gomock.Any()),
   306  			phm.EXPECT().GetStatelessResetToken(connID).Return(token),
   307  			conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).Do(func(b []byte, _ net.Addr) (int, error) {
   308  				defer close(written)
   309  				Expect(bytes.Contains(b, token[:])).To(BeTrue())
   310  				return len(b), nil
   311  			}),
   312  		)
   313  		packetChan <- packetToRead{data: b}
   314  		Eventually(written).Should(BeClosed())
   315  
   316  		// shutdown
   317  		phm.EXPECT().Close(gomock.Any())
   318  		close(packetChan)
   319  		tr.Close()
   320  	})
   321  
   322  	It("closes uninitialized Transport and closes underlying PacketConn", func() {
   323  		packetChan := make(chan packetToRead)
   324  		pconn := newMockPacketConn(packetChan)
   325  
   326  		tr := &Transport{
   327  			Conn:        pconn,
   328  			createdConn: true, // owns pconn
   329  		}
   330  		// NO init
   331  
   332  		// shutdown
   333  		close(packetChan)
   334  		pconn.EXPECT().Close()
   335  		Expect(tr.Close()).To(Succeed())
   336  	})
   337  
   338  	It("doesn't add the PacketConn to the multiplexer if (*Transport).init fails", func() {
   339  		packetChan := make(chan packetToRead)
   340  		pconn := newMockPacketConn(packetChan)
   341  		syscallconn := &mockSyscallConn{pconn}
   342  
   343  		tr := &Transport{
   344  			Conn: syscallconn,
   345  		}
   346  
   347  		err := tr.init(false)
   348  		Expect(err).To(HaveOccurred())
   349  		conns := getMultiplexer().(*connMultiplexer).conns
   350  		Expect(len(conns)).To(BeZero())
   351  	})
   352  
   353  	It("allows receiving non-QUIC packets", func() {
   354  		remoteAddr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234}
   355  		packetChan := make(chan packetToRead)
   356  		tr := &Transport{
   357  			Conn:               newMockPacketConn(packetChan),
   358  			ConnectionIDLength: 10,
   359  		}
   360  		tr.init(true)
   361  		receivedPacketChan := make(chan []byte)
   362  		go func() {
   363  			defer GinkgoRecover()
   364  			b := make([]byte, 100)
   365  			n, addr, err := tr.ReadNonQUICPacket(context.Background(), b)
   366  			Expect(err).ToNot(HaveOccurred())
   367  			Expect(addr).To(Equal(remoteAddr))
   368  			receivedPacketChan <- b[:n]
   369  		}()
   370  		// Receiving of non-QUIC packets is enabled when ReadNonQUICPacket is called.
   371  		// Give the Go routine some time to spin up.
   372  		time.Sleep(scaleDuration(50 * time.Millisecond))
   373  		packetChan <- packetToRead{
   374  			addr: remoteAddr,
   375  			data: []byte{0 /* don't set the QUIC bit */, 1, 2, 3},
   376  		}
   377  
   378  		Eventually(receivedPacketChan).Should(Receive(Equal([]byte{0, 1, 2, 3})))
   379  
   380  		// shutdown
   381  		close(packetChan)
   382  		tr.Close()
   383  	})
   384  
   385  	It("drops non-QUIC packet if the application doesn't process them quickly enough", func() {
   386  		remoteAddr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234}
   387  		packetChan := make(chan packetToRead)
   388  		t, tracer := mocklogging.NewMockTracer(mockCtrl)
   389  		tr := &Transport{
   390  			Conn:               newMockPacketConn(packetChan),
   391  			ConnectionIDLength: 10,
   392  			Tracer:             t,
   393  		}
   394  		tr.init(true)
   395  
   396  		ctx, cancel := context.WithCancel(context.Background())
   397  		cancel()
   398  		_, _, err := tr.ReadNonQUICPacket(ctx, make([]byte, 10))
   399  		Expect(err).To(MatchError(context.Canceled))
   400  
   401  		for i := 0; i < maxQueuedNonQUICPackets; i++ {
   402  			packetChan <- packetToRead{
   403  				addr: remoteAddr,
   404  				data: []byte{0 /* don't set the QUIC bit */, 1, 2, 3},
   405  			}
   406  		}
   407  
   408  		done := make(chan struct{})
   409  		tracer.EXPECT().DroppedPacket(remoteAddr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
   410  			close(done)
   411  		})
   412  		packetChan <- packetToRead{
   413  			addr: remoteAddr,
   414  			data: []byte{0 /* don't set the QUIC bit */, 1, 2, 3},
   415  		}
   416  		Eventually(done).Should(BeClosed())
   417  
   418  		// shutdown
   419  		tracer.EXPECT().Close()
   420  		close(packetChan)
   421  		tr.Close()
   422  	})
   423  
   424  	remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 3, 5, 7), Port: 1234}
   425  	DescribeTable("setting the tls.Config.ServerName",
   426  		func(expected string, conf *tls.Config, addr net.Addr, host string) {
   427  			setTLSConfigServerName(conf, addr, host)
   428  			Expect(conf.ServerName).To(Equal(expected))
   429  		},
   430  		Entry("uses the value from the config", "foo.bar", &tls.Config{ServerName: "foo.bar"}, remoteAddr, "baz.foo"),
   431  		Entry("uses the hostname", "golang.org", &tls.Config{}, remoteAddr, "golang.org"),
   432  		Entry("removes the port from the hostname", "golang.org", &tls.Config{}, remoteAddr, "golang.org:1234"),
   433  		Entry("uses the IP", "1.3.5.7", &tls.Config{}, remoteAddr, ""),
   434  	)
   435  })
   436  
   437  type mockSyscallConn struct {
   438  	net.PacketConn
   439  }
   440  
   441  func (c *mockSyscallConn) SyscallConn() (syscall.RawConn, error) {
   442  	return nil, errors.New("mocked")
   443  }