github.com/tumi8/quic-go@v0.37.4-tum/transport_test.go (about)

     1  package quic
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/rand"
     6  	"crypto/tls"
     7  	"errors"
     8  	"net"
     9  	"syscall"
    10  	"time"
    11  
    12  	mocklogging "github.com/tumi8/quic-go/noninternal/mocks/logging"
    13  	"github.com/tumi8/quic-go/noninternal/protocol"
    14  	"github.com/tumi8/quic-go/noninternal/wire"
    15  	"github.com/tumi8/quic-go/logging"
    16  
    17  	"github.com/golang/mock/gomock"
    18  	. "github.com/onsi/ginkgo/v2"
    19  	. "github.com/onsi/gomega"
    20  )
    21  
    22  var _ = Describe("Transport", func() {
	// packetToRead describes the outcome of a single ReadFrom call on the
	// mocked packet conn built by newMockPacketConn.
	type packetToRead struct {
		// addr is returned to the reader as the packet's source address.
		addr net.Addr
		// data is copied into the buffer handed to ReadFrom.
		data []byte
		// err is returned to the reader as the read error.
		err  error
	}
    28  
    29  	getPacketWithPacketType := func(connID protocol.ConnectionID, t protocol.PacketType, length protocol.ByteCount) []byte {
    30  		b, err := (&wire.ExtendedHeader{
    31  			Header: wire.Header{
    32  				Type:             t,
    33  				DestConnectionID: connID,
    34  				Length:           length,
    35  				Version:          protocol.Version1,
    36  			},
    37  			PacketNumberLen: protocol.PacketNumberLen2,
    38  		}).Append(nil, protocol.Version1)
    39  		Expect(err).ToNot(HaveOccurred())
    40  		return b
    41  	}
    42  
    43  	getPacket := func(connID protocol.ConnectionID) []byte {
    44  		return getPacketWithPacketType(connID, protocol.PacketTypeHandshake, 2)
    45  	}
    46  
    47  	newMockPacketConn := func(packetChan <-chan packetToRead) *MockPacketConn {
    48  		conn := NewMockPacketConn(mockCtrl)
    49  		conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
    50  		conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(b []byte) (int, net.Addr, error) {
    51  			p, ok := <-packetChan
    52  			if !ok {
    53  				return 0, nil, errors.New("closed")
    54  			}
    55  			return copy(b, p.data), p.addr, p.err
    56  		}).AnyTimes()
    57  		// for shutdown
    58  		conn.EXPECT().SetReadDeadline(gomock.Any()).AnyTimes()
    59  		return conn
    60  	}
    61  
    62  	It("handles packets for different packet handlers on the same packet conn", func() {
    63  		packetChan := make(chan packetToRead)
    64  		tr := &Transport{Conn: newMockPacketConn(packetChan)}
    65  		tr.init(true)
    66  		phm := NewMockPacketHandlerManager(mockCtrl)
    67  		tr.handlerMap = phm
    68  		connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
    69  		connID2 := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1})
    70  
    71  		handled := make(chan struct{}, 2)
    72  		phm.EXPECT().Get(connID1).DoAndReturn(func(protocol.ConnectionID) (packetHandler, bool) {
    73  			h := NewMockPacketHandler(mockCtrl)
    74  			h.EXPECT().handlePacket(gomock.Any()).Do(func(p receivedPacket) {
    75  				defer GinkgoRecover()
    76  				connID, err := wire.ParseConnectionID(p.data, 0)
    77  				Expect(err).ToNot(HaveOccurred())
    78  				Expect(connID).To(Equal(connID1))
    79  				handled <- struct{}{}
    80  			})
    81  			return h, true
    82  		})
    83  		phm.EXPECT().Get(connID2).DoAndReturn(func(protocol.ConnectionID) (packetHandler, bool) {
    84  			h := NewMockPacketHandler(mockCtrl)
    85  			h.EXPECT().handlePacket(gomock.Any()).Do(func(p receivedPacket) {
    86  				defer GinkgoRecover()
    87  				connID, err := wire.ParseConnectionID(p.data, 0)
    88  				Expect(err).ToNot(HaveOccurred())
    89  				Expect(connID).To(Equal(connID2))
    90  				handled <- struct{}{}
    91  			})
    92  			return h, true
    93  		})
    94  
    95  		packetChan <- packetToRead{data: getPacket(connID1)}
    96  		packetChan <- packetToRead{data: getPacket(connID2)}
    97  
    98  		Eventually(handled).Should(Receive())
    99  		Eventually(handled).Should(Receive())
   100  
   101  		// shutdown
   102  		phm.EXPECT().Close(gomock.Any())
   103  		close(packetChan)
   104  		tr.Close()
   105  	})
   106  
   107  	It("closes listeners", func() {
   108  		packetChan := make(chan packetToRead)
   109  		tr := &Transport{Conn: newMockPacketConn(packetChan)}
   110  		defer tr.Close()
   111  		ln, err := tr.Listen(&tls.Config{}, nil)
   112  		Expect(err).ToNot(HaveOccurred())
   113  		phm := NewMockPacketHandlerManager(mockCtrl)
   114  		tr.handlerMap = phm
   115  
   116  		phm.EXPECT().CloseServer()
   117  		Expect(ln.Close()).To(Succeed())
   118  
   119  		// shutdown
   120  		phm.EXPECT().Close(gomock.Any())
   121  		close(packetChan)
   122  		tr.Close()
   123  	})
   124  
   125  	It("drops unparseable packets", func() {
   126  		addr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234}
   127  		packetChan := make(chan packetToRead)
   128  		tracer := mocklogging.NewMockTracer(mockCtrl)
   129  		tr := &Transport{
   130  			Conn:               newMockPacketConn(packetChan),
   131  			ConnectionIDLength: 10,
   132  			Tracer:             tracer,
   133  		}
   134  		tr.init(true)
   135  		dropped := make(chan struct{})
   136  		tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(dropped) })
   137  		packetChan <- packetToRead{
   138  			addr: addr,
   139  			data: []byte{0, 1, 2, 3},
   140  		}
   141  		Eventually(dropped).Should(BeClosed())
   142  
   143  		// shutdown
   144  		close(packetChan)
   145  		tr.Close()
   146  	})
   147  
   148  	It("closes when reading from the conn fails", func() {
   149  		packetChan := make(chan packetToRead)
   150  		tr := Transport{Conn: newMockPacketConn(packetChan)}
   151  		defer tr.Close()
   152  		phm := NewMockPacketHandlerManager(mockCtrl)
   153  		tr.init(true)
   154  		tr.handlerMap = phm
   155  
   156  		done := make(chan struct{})
   157  		phm.EXPECT().Close(gomock.Any()).Do(func(error) { close(done) })
   158  		packetChan <- packetToRead{err: errors.New("read failed")}
   159  		Eventually(done).Should(BeClosed())
   160  
   161  		// shutdown
   162  		close(packetChan)
   163  		tr.Close()
   164  	})
   165  
   166  	It("continues listening after temporary errors", func() {
   167  		packetChan := make(chan packetToRead)
   168  		tr := Transport{Conn: newMockPacketConn(packetChan)}
   169  		defer tr.Close()
   170  		phm := NewMockPacketHandlerManager(mockCtrl)
   171  		tr.init(true)
   172  		tr.handlerMap = phm
   173  
   174  		tempErr := deadlineError{}
   175  		Expect(tempErr.Temporary()).To(BeTrue())
   176  		packetChan <- packetToRead{err: tempErr}
   177  		// don't expect any calls to phm.Close
   178  		time.Sleep(50 * time.Millisecond)
   179  
   180  		// shutdown
   181  		phm.EXPECT().Close(gomock.Any())
   182  		close(packetChan)
   183  		tr.Close()
   184  	})
   185  
   186  	It("handles short header packets resets", func() {
   187  		connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5})
   188  		packetChan := make(chan packetToRead)
   189  		tr := Transport{
   190  			Conn:               newMockPacketConn(packetChan),
   191  			ConnectionIDLength: connID.Len(),
   192  		}
   193  		tr.init(true)
   194  		defer tr.Close()
   195  		phm := NewMockPacketHandlerManager(mockCtrl)
   196  		tr.handlerMap = phm
   197  
   198  		var token protocol.StatelessResetToken
   199  		rand.Read(token[:])
   200  
   201  		var b []byte
   202  		b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne)
   203  		Expect(err).ToNot(HaveOccurred())
   204  		b = append(b, token[:]...)
   205  		conn := NewMockPacketHandler(mockCtrl)
   206  		gomock.InOrder(
   207  			phm.EXPECT().GetByResetToken(token),
   208  			phm.EXPECT().Get(connID).Return(conn, true),
   209  			conn.EXPECT().handlePacket(gomock.Any()).Do(func(p receivedPacket) {
   210  				Expect(p.data).To(Equal(b))
   211  				Expect(p.rcvTime).To(BeTemporally("~", time.Now(), time.Second))
   212  			}),
   213  		)
   214  		packetChan <- packetToRead{data: b}
   215  
   216  		// shutdown
   217  		phm.EXPECT().Close(gomock.Any())
   218  		close(packetChan)
   219  		tr.Close()
   220  	})
   221  
   222  	It("handles stateless resets", func() {
   223  		connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5})
   224  		packetChan := make(chan packetToRead)
   225  		tr := Transport{Conn: newMockPacketConn(packetChan)}
   226  		tr.init(true)
   227  		defer tr.Close()
   228  		phm := NewMockPacketHandlerManager(mockCtrl)
   229  		tr.handlerMap = phm
   230  
   231  		var token protocol.StatelessResetToken
   232  		rand.Read(token[:])
   233  
   234  		var b []byte
   235  		b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne)
   236  		Expect(err).ToNot(HaveOccurred())
   237  		b = append(b, token[:]...)
   238  		conn := NewMockPacketHandler(mockCtrl)
   239  		destroyed := make(chan struct{})
   240  		gomock.InOrder(
   241  			phm.EXPECT().GetByResetToken(token).Return(conn, true),
   242  			conn.EXPECT().destroy(gomock.Any()).Do(func(err error) {
   243  				Expect(err).To(MatchError(&StatelessResetError{Token: token}))
   244  				close(destroyed)
   245  			}),
   246  		)
   247  		packetChan <- packetToRead{data: b}
   248  		Eventually(destroyed).Should(BeClosed())
   249  
   250  		// shutdown
   251  		phm.EXPECT().Close(gomock.Any())
   252  		close(packetChan)
   253  		tr.Close()
   254  	})
   255  
   256  	It("sends stateless resets", func() {
   257  		connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5})
   258  		packetChan := make(chan packetToRead)
   259  		conn := newMockPacketConn(packetChan)
   260  		tr := Transport{
   261  			Conn:               conn,
   262  			StatelessResetKey:  &StatelessResetKey{1, 2, 3, 4},
   263  			ConnectionIDLength: connID.Len(),
   264  		}
   265  		tr.init(true)
   266  		defer tr.Close()
   267  		phm := NewMockPacketHandlerManager(mockCtrl)
   268  		tr.handlerMap = phm
   269  
   270  		var b []byte
   271  		b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne)
   272  		Expect(err).ToNot(HaveOccurred())
   273  		b = append(b, make([]byte, protocol.MinStatelessResetSize-len(b)+1)...)
   274  
   275  		var token protocol.StatelessResetToken
   276  		rand.Read(token[:])
   277  		written := make(chan struct{})
   278  		gomock.InOrder(
   279  			phm.EXPECT().GetByResetToken(gomock.Any()),
   280  			phm.EXPECT().Get(connID),
   281  			phm.EXPECT().GetStatelessResetToken(connID).Return(token),
   282  			conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).Do(func(b []byte, _ net.Addr) {
   283  				defer close(written)
   284  				Expect(bytes.Contains(b, token[:])).To(BeTrue())
   285  			}),
   286  		)
   287  		packetChan <- packetToRead{data: b}
   288  		Eventually(written).Should(BeClosed())
   289  
   290  		// shutdown
   291  		phm.EXPECT().Close(gomock.Any())
   292  		close(packetChan)
   293  		tr.Close()
   294  	})
   295  
   296  	It("closes uninitialized Transport and closes underlying PacketConn", func() {
   297  		packetChan := make(chan packetToRead)
   298  		pconn := newMockPacketConn(packetChan)
   299  
   300  		tr := &Transport{
   301  			Conn:        pconn,
   302  			createdConn: true, // owns pconn
   303  		}
   304  		// NO init
   305  
   306  		// shutdown
   307  		close(packetChan)
   308  		pconn.EXPECT().Close()
   309  		Expect(tr.Close()).To(Succeed())
   310  	})
   311  
   312  	It("doesn't add the PacketConn to the multiplexer if (*Transport).init fails", func() {
   313  		packetChan := make(chan packetToRead)
   314  		pconn := newMockPacketConn(packetChan)
   315  		syscallconn := &mockSyscallConn{pconn}
   316  
   317  		tr := &Transport{
   318  			Conn: syscallconn,
   319  		}
   320  
   321  		err := tr.init(false)
   322  		Expect(err).To(HaveOccurred())
   323  		conns := getMultiplexer().(*connMultiplexer).conns
   324  		Expect(len(conns)).To(BeZero())
   325  	})
   326  })
   327  
   328  type mockSyscallConn struct {
   329  	net.PacketConn
   330  }
   331  
   332  func (c *mockSyscallConn) SyscallConn() (syscall.RawConn, error) {
   333  	return nil, errors.New("mocked")
   334  }