github.com/dorkamotorka/go/src@v0.0.0-20230614113921-187095f0e316/crypto/tls/quic_test.go (about)

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"context"
     9  	"errors"
    10  	"reflect"
    11  	"testing"
    12  )
    13  
    14  type testQUICConn struct {
    15  	t           *testing.T
    16  	conn        *QUICConn
    17  	readSecret  map[QUICEncryptionLevel]suiteSecret
    18  	writeSecret map[QUICEncryptionLevel]suiteSecret
    19  	gotParams   []byte
    20  	complete    bool
    21  }
    22  
    23  func newTestQUICClient(t *testing.T, config *Config) *testQUICConn {
    24  	q := &testQUICConn{t: t}
    25  	q.conn = QUICClient(&QUICConfig{
    26  		TLSConfig: config,
    27  	})
    28  	t.Cleanup(func() {
    29  		q.conn.Close()
    30  	})
    31  	return q
    32  }
    33  
    34  func newTestQUICServer(t *testing.T, config *Config) *testQUICConn {
    35  	q := &testQUICConn{t: t}
    36  	q.conn = QUICServer(&QUICConfig{
    37  		TLSConfig: config,
    38  	})
    39  	t.Cleanup(func() {
    40  		q.conn.Close()
    41  	})
    42  	return q
    43  }
    44  
    45  type suiteSecret struct {
    46  	suite  uint16
    47  	secret []byte
    48  }
    49  
    50  func (q *testQUICConn) setReadSecret(level QUICEncryptionLevel, suite uint16, secret []byte) {
    51  	if _, ok := q.writeSecret[level]; !ok {
    52  		q.t.Errorf("SetReadSecret for level %v called before SetWriteSecret", level)
    53  	}
    54  	if level == QUICEncryptionLevelApplication && !q.complete {
    55  		q.t.Errorf("SetReadSecret for level %v called before HandshakeComplete", level)
    56  	}
    57  	if _, ok := q.readSecret[level]; ok {
    58  		q.t.Errorf("SetReadSecret for level %v called twice", level)
    59  	}
    60  	if q.readSecret == nil {
    61  		q.readSecret = map[QUICEncryptionLevel]suiteSecret{}
    62  	}
    63  	switch level {
    64  	case QUICEncryptionLevelHandshake, QUICEncryptionLevelApplication:
    65  		q.readSecret[level] = suiteSecret{suite, secret}
    66  	default:
    67  		q.t.Errorf("SetReadSecret for unexpected level %v", level)
    68  	}
    69  }
    70  
    71  func (q *testQUICConn) setWriteSecret(level QUICEncryptionLevel, suite uint16, secret []byte) {
    72  	if _, ok := q.writeSecret[level]; ok {
    73  		q.t.Errorf("SetWriteSecret for level %v called twice", level)
    74  	}
    75  	if q.writeSecret == nil {
    76  		q.writeSecret = map[QUICEncryptionLevel]suiteSecret{}
    77  	}
    78  	switch level {
    79  	case QUICEncryptionLevelHandshake, QUICEncryptionLevelApplication:
    80  		q.writeSecret[level] = suiteSecret{suite, secret}
    81  	default:
    82  		q.t.Errorf("SetWriteSecret for unexpected level %v", level)
    83  	}
    84  }
    85  
    86  var errTransportParametersRequired = errors.New("transport parameters required")
    87  
    88  func runTestQUICConnection(ctx context.Context, cli, srv *testQUICConn, onHandleCryptoData func()) error {
    89  	a, b := cli, srv
    90  	for _, c := range []*testQUICConn{a, b} {
    91  		if !c.conn.conn.quic.started {
    92  			if err := c.conn.Start(ctx); err != nil {
    93  				return err
    94  			}
    95  		}
    96  	}
    97  	idleCount := 0
    98  	for {
    99  		e := a.conn.NextEvent()
   100  		switch e.Kind {
   101  		case QUICNoEvent:
   102  			idleCount++
   103  			if idleCount == 2 {
   104  				if !a.complete || !b.complete {
   105  					return errors.New("handshake incomplete")
   106  				}
   107  				return nil
   108  			}
   109  			a, b = b, a
   110  		case QUICSetReadSecret:
   111  			a.setReadSecret(e.Level, e.Suite, e.Data)
   112  		case QUICSetWriteSecret:
   113  			a.setWriteSecret(e.Level, e.Suite, e.Data)
   114  		case QUICWriteData:
   115  			if err := b.conn.HandleData(e.Level, e.Data); err != nil {
   116  				return err
   117  			}
   118  		case QUICTransportParameters:
   119  			a.gotParams = e.Data
   120  			if a.gotParams == nil {
   121  				a.gotParams = []byte{}
   122  			}
   123  		case QUICTransportParametersRequired:
   124  			return errTransportParametersRequired
   125  		case QUICHandshakeDone:
   126  			a.complete = true
   127  			if a == srv {
   128  				if err := srv.conn.SendSessionTicket(false); err != nil {
   129  					return err
   130  				}
   131  			}
   132  		}
   133  		if e.Kind != QUICNoEvent {
   134  			idleCount = 0
   135  		}
   136  	}
   137  }
   138  
   139  func TestQUICConnection(t *testing.T) {
   140  	config := testConfig.Clone()
   141  	config.MinVersion = VersionTLS13
   142  
   143  	cli := newTestQUICClient(t, config)
   144  	cli.conn.SetTransportParameters(nil)
   145  
   146  	srv := newTestQUICServer(t, config)
   147  	srv.conn.SetTransportParameters(nil)
   148  
   149  	if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
   150  		t.Fatalf("error during connection handshake: %v", err)
   151  	}
   152  
   153  	if _, ok := cli.readSecret[QUICEncryptionLevelHandshake]; !ok {
   154  		t.Errorf("client has no Handshake secret")
   155  	}
   156  	if _, ok := cli.readSecret[QUICEncryptionLevelApplication]; !ok {
   157  		t.Errorf("client has no Application secret")
   158  	}
   159  	if _, ok := srv.readSecret[QUICEncryptionLevelHandshake]; !ok {
   160  		t.Errorf("server has no Handshake secret")
   161  	}
   162  	if _, ok := srv.readSecret[QUICEncryptionLevelApplication]; !ok {
   163  		t.Errorf("server has no Application secret")
   164  	}
   165  	for _, level := range []QUICEncryptionLevel{QUICEncryptionLevelHandshake, QUICEncryptionLevelApplication} {
   166  		if _, ok := cli.readSecret[level]; !ok {
   167  			t.Errorf("client has no %v read secret", level)
   168  		}
   169  		if _, ok := srv.readSecret[level]; !ok {
   170  			t.Errorf("server has no %v read secret", level)
   171  		}
   172  		if !reflect.DeepEqual(cli.readSecret[level], srv.writeSecret[level]) {
   173  			t.Errorf("client read secret does not match server write secret for level %v", level)
   174  		}
   175  		if !reflect.DeepEqual(cli.writeSecret[level], srv.readSecret[level]) {
   176  			t.Errorf("client write secret does not match server read secret for level %v", level)
   177  		}
   178  	}
   179  }
   180  
   181  func TestQUICSessionResumption(t *testing.T) {
   182  	clientConfig := testConfig.Clone()
   183  	clientConfig.MinVersion = VersionTLS13
   184  	clientConfig.ClientSessionCache = NewLRUClientSessionCache(1)
   185  	clientConfig.ServerName = "example.go.dev"
   186  
   187  	serverConfig := testConfig.Clone()
   188  	serverConfig.MinVersion = VersionTLS13
   189  
   190  	cli := newTestQUICClient(t, clientConfig)
   191  	cli.conn.SetTransportParameters(nil)
   192  	srv := newTestQUICServer(t, serverConfig)
   193  	srv.conn.SetTransportParameters(nil)
   194  	if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
   195  		t.Fatalf("error during first connection handshake: %v", err)
   196  	}
   197  	if cli.conn.ConnectionState().DidResume {
   198  		t.Errorf("first connection unexpectedly used session resumption")
   199  	}
   200  
   201  	cli2 := newTestQUICClient(t, clientConfig)
   202  	cli2.conn.SetTransportParameters(nil)
   203  	srv2 := newTestQUICServer(t, serverConfig)
   204  	srv2.conn.SetTransportParameters(nil)
   205  	if err := runTestQUICConnection(context.Background(), cli2, srv2, nil); err != nil {
   206  		t.Fatalf("error during second connection handshake: %v", err)
   207  	}
   208  	if !cli2.conn.ConnectionState().DidResume {
   209  		t.Errorf("second connection did not use session resumption")
   210  	}
   211  }
   212  
   213  func TestQUICPostHandshakeClientAuthentication(t *testing.T) {
   214  	// RFC 9001, Section 4.4.
   215  	config := testConfig.Clone()
   216  	config.MinVersion = VersionTLS13
   217  	cli := newTestQUICClient(t, config)
   218  	cli.conn.SetTransportParameters(nil)
   219  	srv := newTestQUICServer(t, config)
   220  	srv.conn.SetTransportParameters(nil)
   221  	if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
   222  		t.Fatalf("error during connection handshake: %v", err)
   223  	}
   224  
   225  	certReq := new(certificateRequestMsgTLS13)
   226  	certReq.ocspStapling = true
   227  	certReq.scts = true
   228  	certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
   229  	certReqBytes, err := certReq.marshal()
   230  	if err != nil {
   231  		t.Fatal(err)
   232  	}
   233  	if err := cli.conn.HandleData(QUICEncryptionLevelApplication, append([]byte{
   234  		byte(typeCertificateRequest),
   235  		byte(0), byte(0), byte(len(certReqBytes)),
   236  	}, certReqBytes...)); err == nil {
   237  		t.Fatalf("post-handshake authentication request: got no error, want one")
   238  	}
   239  }
   240  
   241  func TestQUICPostHandshakeKeyUpdate(t *testing.T) {
   242  	// RFC 9001, Section 6.
   243  	config := testConfig.Clone()
   244  	config.MinVersion = VersionTLS13
   245  	cli := newTestQUICClient(t, config)
   246  	cli.conn.SetTransportParameters(nil)
   247  	srv := newTestQUICServer(t, config)
   248  	srv.conn.SetTransportParameters(nil)
   249  	if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
   250  		t.Fatalf("error during connection handshake: %v", err)
   251  	}
   252  
   253  	keyUpdate := new(keyUpdateMsg)
   254  	keyUpdateBytes, err := keyUpdate.marshal()
   255  	if err != nil {
   256  		t.Fatal(err)
   257  	}
   258  	if err := cli.conn.HandleData(QUICEncryptionLevelApplication, append([]byte{
   259  		byte(typeKeyUpdate),
   260  		byte(0), byte(0), byte(len(keyUpdateBytes)),
   261  	}, keyUpdateBytes...)); !errors.Is(err, alertUnexpectedMessage) {
   262  		t.Fatalf("key update request: got error %v, want alertUnexpectedMessage", err)
   263  	}
   264  }
   265  
   266  func TestQUICHandshakeError(t *testing.T) {
   267  	clientConfig := testConfig.Clone()
   268  	clientConfig.MinVersion = VersionTLS13
   269  	clientConfig.InsecureSkipVerify = false
   270  	clientConfig.ServerName = "name"
   271  
   272  	serverConfig := testConfig.Clone()
   273  	serverConfig.MinVersion = VersionTLS13
   274  
   275  	cli := newTestQUICClient(t, clientConfig)
   276  	cli.conn.SetTransportParameters(nil)
   277  	srv := newTestQUICServer(t, serverConfig)
   278  	srv.conn.SetTransportParameters(nil)
   279  	err := runTestQUICConnection(context.Background(), cli, srv, nil)
   280  	if !errors.Is(err, AlertError(alertBadCertificate)) {
   281  		t.Errorf("connection handshake terminated with error %q, want alertBadCertificate", err)
   282  	}
   283  	var e *CertificateVerificationError
   284  	if !errors.As(err, &e) {
   285  		t.Errorf("connection handshake terminated with error %q, want CertificateVerificationError", err)
   286  	}
   287  }
   288  
   289  // Test that QUICConn.ConnectionState can be used during the handshake,
   290  // and that it reports the application protocol as soon as it has been
   291  // negotiated.
   292  func TestQUICConnectionState(t *testing.T) {
   293  	config := testConfig.Clone()
   294  	config.MinVersion = VersionTLS13
   295  	config.NextProtos = []string{"h3"}
   296  	cli := newTestQUICClient(t, config)
   297  	cli.conn.SetTransportParameters(nil)
   298  	srv := newTestQUICServer(t, config)
   299  	srv.conn.SetTransportParameters(nil)
   300  	onHandleCryptoData := func() {
   301  		cliCS := cli.conn.ConnectionState()
   302  		cliWantALPN := ""
   303  		if _, ok := cli.readSecret[QUICEncryptionLevelApplication]; ok {
   304  			cliWantALPN = "h3"
   305  		}
   306  		if want, got := cliCS.NegotiatedProtocol, cliWantALPN; want != got {
   307  			t.Errorf("cli.ConnectionState().NegotiatedProtocol = %q, want %q", want, got)
   308  		}
   309  
   310  		srvCS := srv.conn.ConnectionState()
   311  		srvWantALPN := ""
   312  		if _, ok := srv.readSecret[QUICEncryptionLevelHandshake]; ok {
   313  			srvWantALPN = "h3"
   314  		}
   315  		if want, got := srvCS.NegotiatedProtocol, srvWantALPN; want != got {
   316  			t.Errorf("srv.ConnectionState().NegotiatedProtocol = %q, want %q", want, got)
   317  		}
   318  	}
   319  	if err := runTestQUICConnection(context.Background(), cli, srv, onHandleCryptoData); err != nil {
   320  		t.Fatalf("error during connection handshake: %v", err)
   321  	}
   322  }
   323  
   324  func TestQUICStartContextPropagation(t *testing.T) {
   325  	const key = "key"
   326  	const value = "value"
   327  	ctx := context.WithValue(context.Background(), key, value)
   328  	config := testConfig.Clone()
   329  	config.MinVersion = VersionTLS13
   330  	calls := 0
   331  	config.GetConfigForClient = func(info *ClientHelloInfo) (*Config, error) {
   332  		calls++
   333  		got, _ := info.Context().Value(key).(string)
   334  		if got != value {
   335  			t.Errorf("GetConfigForClient context key %q has value %q, want %q", key, got, value)
   336  		}
   337  		return nil, nil
   338  	}
   339  	cli := newTestQUICClient(t, config)
   340  	cli.conn.SetTransportParameters(nil)
   341  	srv := newTestQUICServer(t, config)
   342  	srv.conn.SetTransportParameters(nil)
   343  	if err := runTestQUICConnection(ctx, cli, srv, nil); err != nil {
   344  		t.Fatalf("error during connection handshake: %v", err)
   345  	}
   346  	if calls != 1 {
   347  		t.Errorf("GetConfigForClient called %v times, want 1", calls)
   348  	}
   349  }
   350  
   351  func TestQUICDelayedTransportParameters(t *testing.T) {
   352  	clientConfig := testConfig.Clone()
   353  	clientConfig.MinVersion = VersionTLS13
   354  	clientConfig.ClientSessionCache = NewLRUClientSessionCache(1)
   355  	clientConfig.ServerName = "example.go.dev"
   356  
   357  	serverConfig := testConfig.Clone()
   358  	serverConfig.MinVersion = VersionTLS13
   359  
   360  	cliParams := "client params"
   361  	srvParams := "server params"
   362  
   363  	cli := newTestQUICClient(t, clientConfig)
   364  	srv := newTestQUICServer(t, serverConfig)
   365  	if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != errTransportParametersRequired {
   366  		t.Fatalf("handshake with no client parameters: %v; want errTransportParametersRequired", err)
   367  	}
   368  	cli.conn.SetTransportParameters([]byte(cliParams))
   369  	if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != errTransportParametersRequired {
   370  		t.Fatalf("handshake with no server parameters: %v; want errTransportParametersRequired", err)
   371  	}
   372  	srv.conn.SetTransportParameters([]byte(srvParams))
   373  	if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
   374  		t.Fatalf("error during connection handshake: %v", err)
   375  	}
   376  
   377  	if got, want := string(cli.gotParams), srvParams; got != want {
   378  		t.Errorf("client got transport params: %q, want %q", got, want)
   379  	}
   380  	if got, want := string(srv.gotParams), cliParams; got != want {
   381  		t.Errorf("server got transport params: %q, want %q", got, want)
   382  	}
   383  }
   384  
   385  func TestQUICEmptyTransportParameters(t *testing.T) {
   386  	config := testConfig.Clone()
   387  	config.MinVersion = VersionTLS13
   388  
   389  	cli := newTestQUICClient(t, config)
   390  	cli.conn.SetTransportParameters(nil)
   391  	srv := newTestQUICServer(t, config)
   392  	srv.conn.SetTransportParameters(nil)
   393  	if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
   394  		t.Fatalf("error during connection handshake: %v", err)
   395  	}
   396  
   397  	if cli.gotParams == nil {
   398  		t.Errorf("client did not get transport params")
   399  	}
   400  	if srv.gotParams == nil {
   401  		t.Errorf("server did not get transport params")
   402  	}
   403  	if len(cli.gotParams) != 0 {
   404  		t.Errorf("client got transport params: %v, want empty", cli.gotParams)
   405  	}
   406  	if len(srv.gotParams) != 0 {
   407  		t.Errorf("server got transport params: %v, want empty", srv.gotParams)
   408  	}
   409  }
   410  
   411  func TestQUICCanceledWaitingForData(t *testing.T) {
   412  	config := testConfig.Clone()
   413  	config.MinVersion = VersionTLS13
   414  	cli := newTestQUICClient(t, config)
   415  	cli.conn.SetTransportParameters(nil)
   416  	cli.conn.Start(context.Background())
   417  	for cli.conn.NextEvent().Kind != QUICNoEvent {
   418  	}
   419  	err := cli.conn.Close()
   420  	if !errors.Is(err, alertCloseNotify) {
   421  		t.Errorf("conn.Close() = %v, want alertCloseNotify", err)
   422  	}
   423  }
   424  
   425  func TestQUICCanceledWaitingForTransportParams(t *testing.T) {
   426  	config := testConfig.Clone()
   427  	config.MinVersion = VersionTLS13
   428  	cli := newTestQUICClient(t, config)
   429  	cli.conn.Start(context.Background())
   430  	for cli.conn.NextEvent().Kind != QUICTransportParametersRequired {
   431  	}
   432  	err := cli.conn.Close()
   433  	if !errors.Is(err, alertCloseNotify) {
   434  		t.Errorf("conn.Close() = %v, want alertCloseNotify", err)
   435  	}
   436  }