github.com/apernet/quic-go@v0.43.1-0.20240515053213-5e9e635fd9f0/integrationtests/versionnegotiation/handshake_test.go (about)

     1  package versionnegotiation
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"errors"
     7  	"fmt"
     8  	"net"
     9  	"time"
    10  
    11  	"github.com/apernet/quic-go"
    12  	"github.com/apernet/quic-go/integrationtests/tools/israce"
    13  	"github.com/apernet/quic-go/internal/protocol"
    14  	"github.com/apernet/quic-go/logging"
    15  
    16  	. "github.com/onsi/ginkgo/v2"
    17  	. "github.com/onsi/gomega"
    18  )
    19  
    20  type versioner interface {
    21  	GetVersion() protocol.Version
    22  }
    23  
    24  type result struct {
    25  	loggedVersions                 bool
    26  	receivedVersionNegotiation     bool
    27  	chosen                         logging.VersionNumber
    28  	clientVersions, serverVersions []logging.VersionNumber
    29  }
    30  
    31  func newVersionNegotiationTracer() (*result, *logging.ConnectionTracer) {
    32  	r := &result{}
    33  	return r, &logging.ConnectionTracer{
    34  		NegotiatedVersion: func(chosen logging.VersionNumber, clientVersions, serverVersions []logging.VersionNumber) {
    35  			if r.loggedVersions {
    36  				Fail("only expected one call to NegotiatedVersions")
    37  			}
    38  			r.loggedVersions = true
    39  			r.chosen = chosen
    40  			r.clientVersions = clientVersions
    41  			r.serverVersions = serverVersions
    42  		},
    43  		ReceivedVersionNegotiationPacket: func(dest, src logging.ArbitraryLenConnectionID, _ []logging.VersionNumber) {
    44  			r.receivedVersionNegotiation = true
    45  		},
    46  	}
    47  }
    48  
    49  var _ = Describe("Handshake tests", func() {
    50  	startServer := func(tlsConf *tls.Config, conf *quic.Config) (*quic.Listener, func()) {
    51  		server, err := quic.ListenAddr("localhost:0", tlsConf, conf)
    52  		Expect(err).ToNot(HaveOccurred())
    53  
    54  		acceptStopped := make(chan struct{})
    55  		go func() {
    56  			defer GinkgoRecover()
    57  			defer close(acceptStopped)
    58  			for {
    59  				if _, err := server.Accept(context.Background()); err != nil {
    60  					return
    61  				}
    62  			}
    63  		}()
    64  
    65  		return server, func() {
    66  			server.Close()
    67  			<-acceptStopped
    68  		}
    69  	}
    70  
    71  	var supportedVersions []protocol.Version
    72  
    73  	BeforeEach(func() {
    74  		supportedVersions = append([]quic.Version{}, protocol.SupportedVersions...)
    75  		protocol.SupportedVersions = append(protocol.SupportedVersions, []protocol.Version{7, 8, 9, 10}...)
    76  	})
    77  
    78  	AfterEach(func() {
    79  		protocol.SupportedVersions = supportedVersions
    80  	})
    81  
    82  	if !israce.Enabled {
    83  		It("when the server supports more versions than the client", func() {
    84  			expectedVersion := protocol.SupportedVersions[0]
    85  			// the server doesn't support the highest supported version, which is the first one the client will try
    86  			// but it supports a bunch of versions that the client doesn't speak
    87  			serverConfig := &quic.Config{}
    88  			serverConfig.Versions = []protocol.Version{7, 8, protocol.SupportedVersions[0], 9}
    89  			serverResult, serverTracer := newVersionNegotiationTracer()
    90  			serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
    91  				return serverTracer
    92  			}
    93  			server, cl := startServer(getTLSConfig(), serverConfig)
    94  			defer cl()
    95  			clientResult, clientTracer := newVersionNegotiationTracer()
    96  			conn, err := quic.DialAddr(
    97  				context.Background(),
    98  				fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
    99  				getTLSClientConfig(),
   100  				maybeAddQLOGTracer(&quic.Config{Tracer: func(ctx context.Context, perspective logging.Perspective, id quic.ConnectionID) *logging.ConnectionTracer {
   101  					return clientTracer
   102  				}}),
   103  			)
   104  			Expect(err).ToNot(HaveOccurred())
   105  			Expect(conn.(versioner).GetVersion()).To(Equal(expectedVersion))
   106  			Expect(conn.CloseWithError(0, "")).To(Succeed())
   107  			Expect(clientResult.chosen).To(Equal(expectedVersion))
   108  			Expect(clientResult.receivedVersionNegotiation).To(BeFalse())
   109  			Expect(clientResult.clientVersions).To(Equal(protocol.SupportedVersions))
   110  			Expect(clientResult.serverVersions).To(BeEmpty())
   111  			Expect(serverResult.chosen).To(Equal(expectedVersion))
   112  			Expect(serverResult.serverVersions).To(Equal(serverConfig.Versions))
   113  			Expect(serverResult.clientVersions).To(BeEmpty())
   114  		})
   115  
   116  		It("when the client supports more versions than the server supports", func() {
   117  			expectedVersion := protocol.SupportedVersions[0]
   118  			// The server doesn't support the highest supported version, which is the first one the client will try,
   119  			// but it supports a bunch of versions that the client doesn't speak
   120  			serverResult, serverTracer := newVersionNegotiationTracer()
   121  			serverConfig := &quic.Config{}
   122  			serverConfig.Versions = supportedVersions
   123  			serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
   124  				return serverTracer
   125  			}
   126  			server, cl := startServer(getTLSConfig(), serverConfig)
   127  			defer cl()
   128  			clientVersions := []protocol.Version{7, 8, 9, protocol.SupportedVersions[0], 10}
   129  			clientResult, clientTracer := newVersionNegotiationTracer()
   130  			conn, err := quic.DialAddr(
   131  				context.Background(),
   132  				fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
   133  				getTLSClientConfig(),
   134  				maybeAddQLOGTracer(&quic.Config{
   135  					Versions: clientVersions,
   136  					Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
   137  						return clientTracer
   138  					},
   139  				}),
   140  			)
   141  			Expect(err).ToNot(HaveOccurred())
   142  			Expect(conn.(versioner).GetVersion()).To(Equal(protocol.SupportedVersions[0]))
   143  			Expect(conn.CloseWithError(0, "")).To(Succeed())
   144  			Expect(clientResult.chosen).To(Equal(expectedVersion))
   145  			Expect(clientResult.receivedVersionNegotiation).To(BeTrue())
   146  			Expect(clientResult.clientVersions).To(Equal(clientVersions))
   147  			Expect(clientResult.serverVersions).To(ContainElements(supportedVersions)) // may contain greased versions
   148  			Expect(serverResult.chosen).To(Equal(expectedVersion))
   149  			Expect(serverResult.serverVersions).To(Equal(serverConfig.Versions))
   150  			Expect(serverResult.clientVersions).To(BeEmpty())
   151  		})
   152  
   153  		It("fails if the server disables version negotiation", func() {
   154  			// The server doesn't support the highest supported version, which is the first one the client will try,
   155  			// but it supports a bunch of versions that the client doesn't speak
   156  			_, serverTracer := newVersionNegotiationTracer()
   157  			serverConfig := &quic.Config{}
   158  			serverConfig.Versions = supportedVersions
   159  			serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
   160  				return serverTracer
   161  			}
   162  			conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
   163  			Expect(err).ToNot(HaveOccurred())
   164  			tr := &quic.Transport{
   165  				Conn:                             conn,
   166  				DisableVersionNegotiationPackets: true,
   167  			}
   168  			ln, err := tr.Listen(getTLSConfig(), serverConfig)
   169  			Expect(err).ToNot(HaveOccurred())
   170  			defer ln.Close()
   171  
   172  			clientVersions := []protocol.Version{7, 8, 9, protocol.SupportedVersions[0], 10}
   173  			clientResult, clientTracer := newVersionNegotiationTracer()
   174  			_, err = quic.DialAddr(
   175  				context.Background(),
   176  				fmt.Sprintf("localhost:%d", conn.LocalAddr().(*net.UDPAddr).Port),
   177  				getTLSClientConfig(),
   178  				maybeAddQLOGTracer(&quic.Config{
   179  					Versions: clientVersions,
   180  					Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
   181  						return clientTracer
   182  					},
   183  					HandshakeIdleTimeout: 100 * time.Millisecond,
   184  				}),
   185  			)
   186  			Expect(err).To(HaveOccurred())
   187  			var nerr net.Error
   188  			Expect(errors.As(err, &nerr)).To(BeTrue())
   189  			Expect(nerr.Timeout()).To(BeTrue())
   190  			Expect(clientResult.receivedVersionNegotiation).To(BeFalse())
   191  		})
   192  	}
   193  })