github.com/MerlinKodo/quic-go@v0.39.2/client_test.go (about)

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"errors"
     7  	"net"
     8  	"time"
     9  
    10  	mocklogging "github.com/MerlinKodo/quic-go/internal/mocks/logging"
    11  	"github.com/MerlinKodo/quic-go/internal/protocol"
    12  	"github.com/MerlinKodo/quic-go/internal/utils"
    13  	"github.com/MerlinKodo/quic-go/logging"
    14  
    15  	. "github.com/onsi/ginkgo/v2"
    16  	. "github.com/onsi/gomega"
    17  	"go.uber.org/mock/gomock"
    18  )
    19  
    20  type nullMultiplexer struct{}
    21  
    22  func (n nullMultiplexer) AddConn(indexableConn)          {}
    23  func (n nullMultiplexer) RemoveConn(indexableConn) error { return nil }
    24  
    25  var _ = Describe("Client", func() {
	// Fixtures shared by the specs in this suite. cl, packetConn, connID,
	// tlsConf, tracer and config are assigned afresh in BeforeEach;
	// origMultiplexer holds the connMuxer value that BeforeEach replaces so
	// that AfterEach can restore it.
	var (
		cl              *client
		packetConn      *MockSendConn
		connID          protocol.ConnectionID
		origMultiplexer multiplexer
		tlsConf         *tls.Config
		tracer          *mocklogging.MockConnectionTracer
		config          *Config

		// Saved value of the package-level newClientConnection constructor.
		// Specs overwrite newClientConnection with stubs; AfterEach restores it
		// from here.
		originalClientConnConstructor func(
			conn sendConn,
			runner connRunner,
			destConnID protocol.ConnectionID,
			srcConnID protocol.ConnectionID,
			connIDGenerator ConnectionIDGenerator,
			conf *Config,
			tlsConf *tls.Config,
			initialPacketNumber protocol.PacketNumber,
			enable0RTT bool,
			hasNegotiatedVersion bool,
			tracer *logging.ConnectionTracer,
			tracingID uint64,
			logger utils.Logger,
			v protocol.VersionNumber,
		) quicConn
	)
    52  
	// Rebuilds the shared fixtures before every spec and swaps the
	// package-level connMuxer for a no-op implementation.
	BeforeEach(func() {
		tlsConf = &tls.Config{NextProtos: []string{"proto1"}}
		connID = protocol.ParseConnectionID([]byte{0, 0, 0, 0, 0, 0, 0x13, 0x37})
		// Remember the current constructor; specs overwrite newClientConnection
		// and the AfterEach hook puts this value back.
		originalClientConnConstructor = newClientConnection
		var tr *logging.ConnectionTracer
		tr, tracer = mocklogging.NewMockConnectionTracer(mockCtrl)
		config = &Config{
			// Always hands out the same tracer, whatever the arguments are.
			Tracer: func(ctx context.Context, perspective logging.Perspective, id ConnectionID) *logging.ConnectionTracer {
				return tr
			},
			Versions: []protocol.VersionNumber{protocol.Version1},
		}
		// Wait until areConnsRunning reports false before building new fixtures.
		Eventually(areConnsRunning).Should(BeFalse())
		packetConn = NewMockSendConn(mockCtrl)
		// Address lookups may happen any number of times; both return an empty
		// UDP address.
		packetConn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
		packetConn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}).AnyTimes()
		cl = &client{
			srcConnID:  connID,
			destConnID: connID,
			version:    protocol.Version1,
			sendConn:   packetConn,
			tracer:     tr,
			logger:     utils.DefaultLogger,
		}
		getMultiplexer() // make the sync.Once execute
		// Replace connMuxer. getMultiplexer will now return the nullMultiplexer.
		// The previous value is kept so that AfterEach can restore it.
		origMultiplexer = connMuxer
		connMuxer = &nullMultiplexer{}
	})
    82  
	// Restores the package-level connMuxer and newClientConnection values that
	// were saved in BeforeEach.
	AfterEach(func() {
		connMuxer = origMultiplexer
		newClientConnection = originalClientConnConstructor
	})
    87  
	// Shuts down the suite-level client's connection when it is a real
	// *connection, then waits until areConnsRunning reports false.
	AfterEach(func() {
		// The assertion is false for anything other than a *connection (including
		// a nil conn), in which case there is nothing to shut down.
		if s, ok := cl.conn.(*connection); ok {
			s.shutdown()
		}
		Eventually(areConnsRunning).Should(BeFalse())
	})
    94  
    95  	Context("Dialing", func() {
		// Saved value of the package-level generateConnectionIDForInitial hook.
		var origGenerateConnectionIDForInitial func() (protocol.ConnectionID, error)

		// Makes generateConnectionIDForInitial return the fixed suite connID
		// with a nil error, so the specs in this context see a known ID.
		BeforeEach(func() {
			origGenerateConnectionIDForInitial = generateConnectionIDForInitial
			generateConnectionIDForInitial = func() (protocol.ConnectionID, error) {
				return connID, nil
			}
		})

		// Puts the saved generator back.
		AfterEach(func() {
			generateConnectionIDForInitial = origGenerateConnectionIDForInitial
		})
   108  
   109  		It("returns after the handshake is complete", func() {
   110  			manager := NewMockPacketHandlerManager(mockCtrl)
   111  			manager.EXPECT().Add(gomock.Any(), gomock.Any())
   112  
   113  			run := make(chan struct{})
   114  			newClientConnection = func(
   115  				_ sendConn,
   116  				_ connRunner,
   117  				_ protocol.ConnectionID,
   118  				_ protocol.ConnectionID,
   119  				_ ConnectionIDGenerator,
   120  				_ *Config,
   121  				_ *tls.Config,
   122  				_ protocol.PacketNumber,
   123  				enable0RTT bool,
   124  				_ bool,
   125  				_ *logging.ConnectionTracer,
   126  				_ uint64,
   127  				_ utils.Logger,
   128  				_ protocol.VersionNumber,
   129  			) quicConn {
   130  				Expect(enable0RTT).To(BeFalse())
   131  				conn := NewMockQUICConn(mockCtrl)
   132  				conn.EXPECT().run().Do(func() { close(run) })
   133  				c := make(chan struct{})
   134  				close(c)
   135  				conn.EXPECT().HandshakeComplete().Return(c)
   136  				return conn
   137  			}
   138  			cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, false)
   139  			Expect(err).ToNot(HaveOccurred())
   140  			cl.packetHandlers = manager
   141  			Expect(cl).ToNot(BeNil())
   142  			Expect(cl.dial(context.Background())).To(Succeed())
   143  			Eventually(run).Should(BeClosed())
   144  		})
   145  
   146  		It("returns early connections", func() {
   147  			manager := NewMockPacketHandlerManager(mockCtrl)
   148  			manager.EXPECT().Add(gomock.Any(), gomock.Any())
   149  			readyChan := make(chan struct{})
   150  			done := make(chan struct{})
   151  			newClientConnection = func(
   152  				_ sendConn,
   153  				runner connRunner,
   154  				_ protocol.ConnectionID,
   155  				_ protocol.ConnectionID,
   156  				_ ConnectionIDGenerator,
   157  				_ *Config,
   158  				_ *tls.Config,
   159  				_ protocol.PacketNumber,
   160  				enable0RTT bool,
   161  				_ bool,
   162  				_ *logging.ConnectionTracer,
   163  				_ uint64,
   164  				_ utils.Logger,
   165  				_ protocol.VersionNumber,
   166  			) quicConn {
   167  				Expect(enable0RTT).To(BeTrue())
   168  				conn := NewMockQUICConn(mockCtrl)
   169  				conn.EXPECT().run().Do(func() { close(done) })
   170  				conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
   171  				conn.EXPECT().earlyConnReady().Return(readyChan)
   172  				return conn
   173  			}
   174  
   175  			cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, true)
   176  			Expect(err).ToNot(HaveOccurred())
   177  			cl.packetHandlers = manager
   178  			Expect(cl).ToNot(BeNil())
   179  			Expect(cl.dial(context.Background())).To(Succeed())
   180  			Eventually(done).Should(BeClosed())
   181  		})
   182  
   183  		It("returns an error that occurs while waiting for the handshake to complete", func() {
   184  			manager := NewMockPacketHandlerManager(mockCtrl)
   185  			manager.EXPECT().Add(gomock.Any(), gomock.Any())
   186  
   187  			testErr := errors.New("early handshake error")
   188  			newClientConnection = func(
   189  				_ sendConn,
   190  				_ connRunner,
   191  				_ protocol.ConnectionID,
   192  				_ protocol.ConnectionID,
   193  				_ ConnectionIDGenerator,
   194  				_ *Config,
   195  				_ *tls.Config,
   196  				_ protocol.PacketNumber,
   197  				_ bool,
   198  				_ bool,
   199  				_ *logging.ConnectionTracer,
   200  				_ uint64,
   201  				_ utils.Logger,
   202  				_ protocol.VersionNumber,
   203  			) quicConn {
   204  				conn := NewMockQUICConn(mockCtrl)
   205  				conn.EXPECT().run().Return(testErr)
   206  				conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
   207  				conn.EXPECT().earlyConnReady().Return(make(chan struct{}))
   208  				return conn
   209  			}
   210  			var closed bool
   211  			cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, func() { closed = true }, true)
   212  			Expect(err).ToNot(HaveOccurred())
   213  			cl.packetHandlers = manager
   214  			Expect(cl).ToNot(BeNil())
   215  			Expect(cl.dial(context.Background())).To(MatchError(testErr))
   216  			Expect(closed).To(BeTrue())
   217  		})
   218  
   219  		Context("quic.Config", func() {
   220  			It("setups with the right values", func() {
   221  				tokenStore := NewLRUTokenStore(10, 4)
   222  				config := &Config{
   223  					HandshakeIdleTimeout:  1337 * time.Minute,
   224  					MaxIdleTimeout:        42 * time.Hour,
   225  					MaxIncomingStreams:    1234,
   226  					MaxIncomingUniStreams: 4321,
   227  					TokenStore:            tokenStore,
   228  					EnableDatagrams:       true,
   229  				}
   230  				c := populateConfig(config)
   231  				Expect(c.HandshakeIdleTimeout).To(Equal(1337 * time.Minute))
   232  				Expect(c.MaxIdleTimeout).To(Equal(42 * time.Hour))
   233  				Expect(c.MaxIncomingStreams).To(BeEquivalentTo(1234))
   234  				Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(4321))
   235  				Expect(c.TokenStore).To(Equal(tokenStore))
   236  				Expect(c.EnableDatagrams).To(BeTrue())
   237  			})
   238  
   239  			It("disables bidirectional streams", func() {
   240  				config := &Config{
   241  					MaxIncomingStreams:    -1,
   242  					MaxIncomingUniStreams: 4321,
   243  				}
   244  				c := populateConfig(config)
   245  				Expect(c.MaxIncomingStreams).To(BeZero())
   246  				Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(4321))
   247  			})
   248  
   249  			It("disables unidirectional streams", func() {
   250  				config := &Config{
   251  					MaxIncomingStreams:    1234,
   252  					MaxIncomingUniStreams: -1,
   253  				}
   254  				c := populateConfig(config)
   255  				Expect(c.MaxIncomingStreams).To(BeEquivalentTo(1234))
   256  				Expect(c.MaxIncomingUniStreams).To(BeZero())
   257  			})
   258  
   259  			It("fills in default values if options are not set in the Config", func() {
   260  				c := populateConfig(&Config{})
   261  				Expect(c.Versions).To(Equal(protocol.SupportedVersions))
   262  				Expect(c.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout))
   263  				Expect(c.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout))
   264  			})
   265  		})
   266  
		// Calls the exported Dial with a mock packet conn and records the
		// version and config that reach the connection constructor, then checks
		// them against the config passed to Dial.
		It("creates new connections with the right parameters", func() {
			// Local config; shadows the suite-level config for this spec.
			config := &Config{Versions: []protocol.VersionNumber{protocol.Version1}}
			// c is closed once the constructor has recorded its arguments.
			c := make(chan struct{})
			var version protocol.VersionNumber
			var conf *Config
			// done is closed at the end of the constructor; it unblocks the mock
			// ReadFrom below.
			done := make(chan struct{})
			newClientConnection = func(
				connP sendConn,
				_ connRunner,
				_ protocol.ConnectionID,
				_ protocol.ConnectionID,
				_ ConnectionIDGenerator,
				configP *Config,
				_ *tls.Config,
				_ protocol.PacketNumber,
				_ bool,
				_ bool,
				_ *logging.ConnectionTracer,
				_ uint64,
				_ utils.Logger,
				versionP protocol.VersionNumber,
			) quicConn {
				// Capture the arguments under test for the assertions at the end.
				version = versionP
				conf = configP
				close(c)
				// TODO: check connection IDs?
				conn := NewMockQUICConn(mockCtrl)
				conn.EXPECT().run()
				// The handshake channel is never closed in this spec.
				conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
				// destroy is allowed, but not required, to be called once.
				conn.EXPECT().destroy(gomock.Any()).MaxTimes(1)
				close(done)
				return conn
			}
			// Local mock packet conn; shadows the suite-level packetConn.
			packetConn := NewMockPacketConn(mockCtrl)
			// The single expected read blocks until the constructor has run and
			// then fails with a "closed" error.
			packetConn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func([]byte) (int, net.Addr, error) {
				<-done
				return 0, nil, errors.New("closed")
			})
			packetConn.EXPECT().LocalAddr()
			packetConn.EXPECT().SetReadDeadline(gomock.Any()).AnyTimes()
			_, err := Dial(context.Background(), packetConn, &net.UDPAddr{}, tlsConf, config)
			Expect(err).ToNot(HaveOccurred())
			Eventually(c).Should(BeClosed())
			Expect(version).To(Equal(config.Versions[0]))
			Expect(conf.Versions).To(Equal(config.Versions))
		})
   313  
		// Expects the connection constructor to be called twice by one DialAddr
		// call. The first stub connection's run returns an errCloseForRecreating;
		// the second construction must then receive the packet number carried by
		// that error and hasNegotiatedVersion == true.
		It("creates a new connections after version negotiation", func() {
			// Number of times the constructor stub has been invoked so far.
			var counter int
			newClientConnection = func(
				_ sendConn,
				runner connRunner,
				_ protocol.ConnectionID,
				connID protocol.ConnectionID,
				_ ConnectionIDGenerator,
				configP *Config,
				_ *tls.Config,
				pn protocol.PacketNumber,
				_ bool,
				hasNegotiatedVersion bool,
				_ *logging.ConnectionTracer,
				_ uint64,
				_ utils.Logger,
				versionP protocol.VersionNumber,
			) quicConn {
				conn := NewMockQUICConn(mockCtrl)
				conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
				if counter == 0 {
					// First construction: initial packet number is zero and no
					// version has been negotiated yet.
					Expect(pn).To(BeZero())
					Expect(hasNegotiatedVersion).To(BeFalse())
					conn.EXPECT().run().DoAndReturn(func() error {
						// connID here is the constructor's parameter, which
						// shadows the suite-level connID.
						runner.Remove(connID)
						return &errCloseForRecreating{
							nextPacketNumber: 109,
							nextVersion:      789,
						}
					})
				} else {
					// Second construction: the packet number must match the
					// nextPacketNumber returned by the first connection's run.
					Expect(pn).To(Equal(protocol.PacketNumber(109)))
					Expect(hasNegotiatedVersion).To(BeTrue())
					conn.EXPECT().run()
					conn.EXPECT().destroy(gomock.Any())
				}
				counter++
				return conn
			}

			// The right-hand side still refers to the suite-level config (the
			// new local is not in scope until after this statement), so the
			// mock tracer callback is reused.
			config := &Config{Tracer: config.Tracer, Versions: []protocol.VersionNumber{protocol.Version1}}
			// StartedConnection is expected exactly once (gomock's default).
			tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
			_, err := DialAddr(context.Background(), "localhost:7890", tlsConf, config)
			Expect(err).ToNot(HaveOccurred())
			Expect(counter).To(Equal(2))
		})
   360  	})
   361  })