github.com/tumi8/quic-go@v0.37.4-tum/integrationtests/versionnegotiation/handshake_test.go (about)

     1  package versionnegotiation
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"fmt"
     7  	"net"
     8  
     9  	"github.com/tumi8/quic-go"
    10  	"github.com/tumi8/quic-go/integrationtests/tools/israce"
    11  	"github.com/tumi8/quic-go/noninternal/protocol"
    12  	"github.com/tumi8/quic-go/logging"
    13  
    14  	. "github.com/onsi/ginkgo/v2"
    15  	. "github.com/onsi/gomega"
    16  )
    17  
    18  type versioner interface {
    19  	GetVersion() protocol.VersionNumber
    20  }
    21  
    22  type versionNegotiationTracer struct {
    23  	logging.NullConnectionTracer
    24  
    25  	loggedVersions                 bool
    26  	receivedVersionNegotiation     bool
    27  	chosen                         logging.VersionNumber
    28  	clientVersions, serverVersions []logging.VersionNumber
    29  }
    30  
    31  var _ logging.ConnectionTracer = &versionNegotiationTracer{}
    32  
    33  func (t *versionNegotiationTracer) NegotiatedVersion(chosen logging.VersionNumber, clientVersions, serverVersions []logging.VersionNumber) {
    34  	if t.loggedVersions {
    35  		Fail("only expected one call to NegotiatedVersions")
    36  	}
    37  	t.loggedVersions = true
    38  	t.chosen = chosen
    39  	t.clientVersions = clientVersions
    40  	t.serverVersions = serverVersions
    41  }
    42  
    43  func (t *versionNegotiationTracer) ReceivedVersionNegotiationPacket(dest, src logging.ArbitraryLenConnectionID, _ []logging.VersionNumber) {
    44  	t.receivedVersionNegotiation = true
    45  }
    46  
    47  var _ = Describe("Handshake tests", func() {
    48  	startServer := func(tlsConf *tls.Config, conf *quic.Config) (*quic.Listener, func()) {
    49  		server, err := quic.ListenAddr("localhost:0", tlsConf, conf)
    50  		Expect(err).ToNot(HaveOccurred())
    51  
    52  		acceptStopped := make(chan struct{})
    53  		go func() {
    54  			defer GinkgoRecover()
    55  			defer close(acceptStopped)
    56  			for {
    57  				if _, err := server.Accept(context.Background()); err != nil {
    58  					return
    59  				}
    60  			}
    61  		}()
    62  
    63  		return server, func() {
    64  			server.Close()
    65  			<-acceptStopped
    66  		}
    67  	}
    68  
    69  	var supportedVersions []protocol.VersionNumber
    70  
    71  	BeforeEach(func() {
    72  		supportedVersions = append([]quic.VersionNumber{}, protocol.SupportedVersions...)
    73  		protocol.SupportedVersions = append(protocol.SupportedVersions, []protocol.VersionNumber{7, 8, 9, 10}...)
    74  	})
    75  
    76  	AfterEach(func() {
    77  		protocol.SupportedVersions = supportedVersions
    78  	})
    79  
    80  	if !israce.Enabled {
    81  		It("when the server supports more versions than the client", func() {
    82  			expectedVersion := protocol.SupportedVersions[0]
    83  			// the server doesn't support the highest supported version, which is the first one the client will try
    84  			// but it supports a bunch of versions that the client doesn't speak
    85  			serverConfig := &quic.Config{}
    86  			serverConfig.Versions = []protocol.VersionNumber{7, 8, protocol.SupportedVersions[0], 9}
    87  			serverTracer := &versionNegotiationTracer{}
    88  			serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
    89  				return serverTracer
    90  			}
    91  			server, cl := startServer(getTLSConfig(), serverConfig)
    92  			defer cl()
    93  			clientTracer := &versionNegotiationTracer{}
    94  			conn, err := quic.DialAddr(
    95  				context.Background(),
    96  				fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
    97  				getTLSClientConfig(),
    98  				maybeAddQLOGTracer(&quic.Config{Tracer: func(ctx context.Context, perspective logging.Perspective, id quic.ConnectionID) logging.ConnectionTracer {
    99  					return clientTracer
   100  				}}),
   101  			)
   102  			Expect(err).ToNot(HaveOccurred())
   103  			Expect(conn.(versioner).GetVersion()).To(Equal(expectedVersion))
   104  			Expect(conn.CloseWithError(0, "")).To(Succeed())
   105  			Expect(clientTracer.chosen).To(Equal(expectedVersion))
   106  			Expect(clientTracer.receivedVersionNegotiation).To(BeFalse())
   107  			Expect(clientTracer.clientVersions).To(Equal(protocol.SupportedVersions))
   108  			Expect(clientTracer.serverVersions).To(BeEmpty())
   109  			Expect(serverTracer.chosen).To(Equal(expectedVersion))
   110  			Expect(serverTracer.serverVersions).To(Equal(serverConfig.Versions))
   111  			Expect(serverTracer.clientVersions).To(BeEmpty())
   112  		})
   113  
   114  		It("when the client supports more versions than the server supports", func() {
   115  			expectedVersion := protocol.SupportedVersions[0]
   116  			// the server doesn't support the highest supported version, which is the first one the client will try
   117  			// but it supports a bunch of versions that the client doesn't speak
   118  			serverTracer := &versionNegotiationTracer{}
   119  			serverConfig := &quic.Config{}
   120  			serverConfig.Versions = supportedVersions
   121  			serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
   122  				return serverTracer
   123  			}
   124  			server, cl := startServer(getTLSConfig(), serverConfig)
   125  			defer cl()
   126  			clientVersions := []protocol.VersionNumber{7, 8, 9, protocol.SupportedVersions[0], 10}
   127  			clientTracer := &versionNegotiationTracer{}
   128  			conn, err := quic.DialAddr(
   129  				context.Background(),
   130  				fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
   131  				getTLSClientConfig(),
   132  				maybeAddQLOGTracer(&quic.Config{
   133  					Versions: clientVersions,
   134  					Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
   135  						return clientTracer
   136  					},
   137  				}),
   138  			)
   139  			Expect(err).ToNot(HaveOccurred())
   140  			Expect(conn.(versioner).GetVersion()).To(Equal(protocol.SupportedVersions[0]))
   141  			Expect(conn.CloseWithError(0, "")).To(Succeed())
   142  			Expect(clientTracer.chosen).To(Equal(expectedVersion))
   143  			Expect(clientTracer.receivedVersionNegotiation).To(BeTrue())
   144  			Expect(clientTracer.clientVersions).To(Equal(clientVersions))
   145  			Expect(clientTracer.serverVersions).To(ContainElements(supportedVersions)) // may contain greased versions
   146  			Expect(serverTracer.chosen).To(Equal(expectedVersion))
   147  			Expect(serverTracer.serverVersions).To(Equal(serverConfig.Versions))
   148  			Expect(serverTracer.clientVersions).To(BeEmpty())
   149  		})
   150  	}
   151  })