github.com/3andne/restls-client-go@v0.1.6/quic_test.go (about)

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"context"
     9  	"errors"
    10  	"reflect"
    11  	"testing"
    12  )
    13  
    14  type testQUICConn struct {
    15  	t           *testing.T
    16  	conn        *QUICConn
    17  	readSecret  map[QUICEncryptionLevel]suiteSecret
    18  	writeSecret map[QUICEncryptionLevel]suiteSecret
    19  	gotParams   []byte
    20  	complete    bool
    21  }
    22  
    23  func newTestQUICClient(t *testing.T, config *Config) *testQUICConn {
    24  	q := &testQUICConn{t: t}
    25  	q.conn = QUICClient(&QUICConfig{
    26  		TLSConfig: config,
    27  	})
    28  	t.Cleanup(func() {
    29  		q.conn.Close()
    30  	})
    31  	return q
    32  }
    33  
    34  func newTestQUICServer(t *testing.T, config *Config) *testQUICConn {
    35  	q := &testQUICConn{t: t}
    36  	q.conn = QUICServer(&QUICConfig{
    37  		TLSConfig: config,
    38  	})
    39  	t.Cleanup(func() {
    40  		q.conn.Close()
    41  	})
    42  	return q
    43  }
    44  
    45  type suiteSecret struct {
    46  	suite  uint16
    47  	secret []byte
    48  }
    49  
    50  func (q *testQUICConn) setReadSecret(level QUICEncryptionLevel, suite uint16, secret []byte) {
    51  	if _, ok := q.writeSecret[level]; !ok {
    52  		q.t.Errorf("SetReadSecret for level %v called before SetWriteSecret", level)
    53  	}
    54  	if level == QUICEncryptionLevelApplication && !q.complete {
    55  		q.t.Errorf("SetReadSecret for level %v called before HandshakeComplete", level)
    56  	}
    57  	if _, ok := q.readSecret[level]; ok {
    58  		q.t.Errorf("SetReadSecret for level %v called twice", level)
    59  	}
    60  	if q.readSecret == nil {
    61  		q.readSecret = map[QUICEncryptionLevel]suiteSecret{}
    62  	}
    63  	switch level {
    64  	case QUICEncryptionLevelHandshake, QUICEncryptionLevelApplication:
    65  		q.readSecret[level] = suiteSecret{suite, secret}
    66  	default:
    67  		q.t.Errorf("SetReadSecret for unexpected level %v", level)
    68  	}
    69  }
    70  
    71  func (q *testQUICConn) setWriteSecret(level QUICEncryptionLevel, suite uint16, secret []byte) {
    72  	if _, ok := q.writeSecret[level]; ok {
    73  		q.t.Errorf("SetWriteSecret for level %v called twice", level)
    74  	}
    75  	if q.writeSecret == nil {
    76  		q.writeSecret = map[QUICEncryptionLevel]suiteSecret{}
    77  	}
    78  	switch level {
    79  	case QUICEncryptionLevelHandshake, QUICEncryptionLevelApplication:
    80  		q.writeSecret[level] = suiteSecret{suite, secret}
    81  	default:
    82  		q.t.Errorf("SetWriteSecret for unexpected level %v", level)
    83  	}
    84  }
    85  
    86  var errTransportParametersRequired = errors.New("transport parameters required")
    87  
    88  func runTestQUICConnection(ctx context.Context, cli, srv *testQUICConn, onHandleCryptoData func()) error {
    89  	a, b := cli, srv
    90  	for _, c := range []*testQUICConn{a, b} {
    91  		if !c.conn.conn.quic.started {
    92  			if err := c.conn.Start(ctx); err != nil {
    93  				return err
    94  			}
    95  		}
    96  	}
    97  	idleCount := 0
    98  	for {
    99  		e := a.conn.NextEvent()
   100  		switch e.Kind {
   101  		case QUICNoEvent:
   102  			idleCount++
   103  			if idleCount == 2 {
   104  				if !a.complete || !b.complete {
   105  					return errors.New("handshake incomplete")
   106  				}
   107  				return nil
   108  			}
   109  			a, b = b, a
   110  		case QUICSetReadSecret:
   111  			a.setReadSecret(e.Level, e.Suite, e.Data)
   112  		case QUICSetWriteSecret:
   113  			a.setWriteSecret(e.Level, e.Suite, e.Data)
   114  		case QUICWriteData:
   115  			if err := b.conn.HandleData(e.Level, e.Data); err != nil {
   116  				return err
   117  			}
   118  		case QUICTransportParameters:
   119  			a.gotParams = e.Data
   120  			if a.gotParams == nil {
   121  				a.gotParams = []byte{}
   122  			}
   123  		case QUICTransportParametersRequired:
   124  			return errTransportParametersRequired
   125  		case QUICHandshakeDone:
   126  			a.complete = true
   127  			if a == srv {
   128  				opts := QUICSessionTicketOptions{}
   129  				if err := srv.conn.SendSessionTicket(opts); err != nil {
   130  					return err
   131  				}
   132  			}
   133  		}
   134  		if e.Kind != QUICNoEvent {
   135  			idleCount = 0
   136  		}
   137  	}
   138  }
   139  
   140  func TestQUICConnection(t *testing.T) {
   141  	config := testConfig.Clone()
   142  	config.MinVersion = VersionTLS13
   143  
   144  	cli := newTestQUICClient(t, config)
   145  	cli.conn.SetTransportParameters(nil)
   146  
   147  	srv := newTestQUICServer(t, config)
   148  	srv.conn.SetTransportParameters(nil)
   149  
   150  	if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
   151  		t.Fatalf("error during connection handshake: %v", err)
   152  	}
   153  
   154  	if _, ok := cli.readSecret[QUICEncryptionLevelHandshake]; !ok {
   155  		t.Errorf("client has no Handshake secret")
   156  	}
   157  	if _, ok := cli.readSecret[QUICEncryptionLevelApplication]; !ok {
   158  		t.Errorf("client has no Application secret")
   159  	}
   160  	if _, ok := srv.readSecret[QUICEncryptionLevelHandshake]; !ok {
   161  		t.Errorf("server has no Handshake secret")
   162  	}
   163  	if _, ok := srv.readSecret[QUICEncryptionLevelApplication]; !ok {
   164  		t.Errorf("server has no Application secret")
   165  	}
   166  	for _, level := range []QUICEncryptionLevel{QUICEncryptionLevelHandshake, QUICEncryptionLevelApplication} {
   167  		if _, ok := cli.readSecret[level]; !ok {
   168  			t.Errorf("client has no %v read secret", level)
   169  		}
   170  		if _, ok := srv.readSecret[level]; !ok {
   171  			t.Errorf("server has no %v read secret", level)
   172  		}
   173  		if !reflect.DeepEqual(cli.readSecret[level], srv.writeSecret[level]) {
   174  			t.Errorf("client read secret does not match server write secret for level %v", level)
   175  		}
   176  		if !reflect.DeepEqual(cli.writeSecret[level], srv.readSecret[level]) {
   177  			t.Errorf("client write secret does not match server read secret for level %v", level)
   178  		}
   179  	}
   180  }
   181  
   182  func TestQUICSessionResumption(t *testing.T) {
   183  	clientConfig := testConfig.Clone()
   184  	clientConfig.MinVersion = VersionTLS13
   185  	clientConfig.ClientSessionCache = NewLRUClientSessionCache(1)
   186  	clientConfig.ServerName = "example.go.dev"
   187  
   188  	serverConfig := testConfig.Clone()
   189  	serverConfig.MinVersion = VersionTLS13
   190  
   191  	cli := newTestQUICClient(t, clientConfig)
   192  	cli.conn.SetTransportParameters(nil)
   193  	srv := newTestQUICServer(t, serverConfig)
   194  	srv.conn.SetTransportParameters(nil)
   195  	if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
   196  		t.Fatalf("error during first connection handshake: %v", err)
   197  	}
   198  	if cli.conn.ConnectionState().DidResume {
   199  		t.Errorf("first connection unexpectedly used session resumption")
   200  	}
   201  
   202  	cli2 := newTestQUICClient(t, clientConfig)
   203  	cli2.conn.SetTransportParameters(nil)
   204  	srv2 := newTestQUICServer(t, serverConfig)
   205  	srv2.conn.SetTransportParameters(nil)
   206  	if err := runTestQUICConnection(context.Background(), cli2, srv2, nil); err != nil {
   207  		t.Fatalf("error during second connection handshake: %v", err)
   208  	}
   209  	if !cli2.conn.ConnectionState().DidResume {
   210  		t.Errorf("second connection did not use session resumption")
   211  	}
   212  }
   213  
   214  func TestQUICPostHandshakeClientAuthentication(t *testing.T) {
   215  	// RFC 9001, Section 4.4.
   216  	config := testConfig.Clone()
   217  	config.MinVersion = VersionTLS13
   218  	cli := newTestQUICClient(t, config)
   219  	cli.conn.SetTransportParameters(nil)
   220  	srv := newTestQUICServer(t, config)
   221  	srv.conn.SetTransportParameters(nil)
   222  	if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
   223  		t.Fatalf("error during connection handshake: %v", err)
   224  	}
   225  
   226  	certReq := new(certificateRequestMsgTLS13)
   227  	certReq.ocspStapling = true
   228  	certReq.scts = true
   229  	certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
   230  	certReqBytes, err := certReq.marshal()
   231  	if err != nil {
   232  		t.Fatal(err)
   233  	}
   234  	if err := cli.conn.HandleData(QUICEncryptionLevelApplication, append([]byte{
   235  		byte(typeCertificateRequest),
   236  		byte(0), byte(0), byte(len(certReqBytes)),
   237  	}, certReqBytes...)); err == nil {
   238  		t.Fatalf("post-handshake authentication request: got no error, want one")
   239  	}
   240  }
   241  
   242  func TestQUICPostHandshakeKeyUpdate(t *testing.T) {
   243  	// RFC 9001, Section 6.
   244  	config := testConfig.Clone()
   245  	config.MinVersion = VersionTLS13
   246  	cli := newTestQUICClient(t, config)
   247  	cli.conn.SetTransportParameters(nil)
   248  	srv := newTestQUICServer(t, config)
   249  	srv.conn.SetTransportParameters(nil)
   250  	if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
   251  		t.Fatalf("error during connection handshake: %v", err)
   252  	}
   253  
   254  	keyUpdate := new(keyUpdateMsg)
   255  	keyUpdateBytes, err := keyUpdate.marshal()
   256  	if err != nil {
   257  		t.Fatal(err)
   258  	}
   259  	if err := cli.conn.HandleData(QUICEncryptionLevelApplication, append([]byte{
   260  		byte(typeKeyUpdate),
   261  		byte(0), byte(0), byte(len(keyUpdateBytes)),
   262  	}, keyUpdateBytes...)); !errors.Is(err, alertUnexpectedMessage) {
   263  		t.Fatalf("key update request: got error %v, want alertUnexpectedMessage", err)
   264  	}
   265  }
   266  
   267  func TestQUICHandshakeError(t *testing.T) {
   268  	clientConfig := testConfig.Clone()
   269  	clientConfig.MinVersion = VersionTLS13
   270  	clientConfig.InsecureSkipVerify = false
   271  	clientConfig.ServerName = "name"
   272  
   273  	serverConfig := testConfig.Clone()
   274  	serverConfig.MinVersion = VersionTLS13
   275  
   276  	cli := newTestQUICClient(t, clientConfig)
   277  	cli.conn.SetTransportParameters(nil)
   278  	srv := newTestQUICServer(t, serverConfig)
   279  	srv.conn.SetTransportParameters(nil)
   280  	err := runTestQUICConnection(context.Background(), cli, srv, nil)
   281  	if !errors.Is(err, AlertError(alertBadCertificate)) {
   282  		t.Errorf("connection handshake terminated with error %q, want alertBadCertificate", err)
   283  	}
   284  	var e *CertificateVerificationError
   285  	if !errors.As(err, &e) {
   286  		t.Errorf("connection handshake terminated with error %q, want CertificateVerificationError", err)
   287  	}
   288  }
   289  
   290  // Test that QUICConn.ConnectionState can be used during the handshake,
   291  // and that it reports the application protocol as soon as it has been
   292  // negotiated.
   293  func TestQUICConnectionState(t *testing.T) {
   294  	config := testConfig.Clone()
   295  	config.MinVersion = VersionTLS13
   296  	config.NextProtos = []string{"h3"}
   297  	cli := newTestQUICClient(t, config)
   298  	cli.conn.SetTransportParameters(nil)
   299  	srv := newTestQUICServer(t, config)
   300  	srv.conn.SetTransportParameters(nil)
   301  	onHandleCryptoData := func() {
   302  		cliCS := cli.conn.ConnectionState()
   303  		cliWantALPN := ""
   304  		if _, ok := cli.readSecret[QUICEncryptionLevelApplication]; ok {
   305  			cliWantALPN = "h3"
   306  		}
   307  		if want, got := cliCS.NegotiatedProtocol, cliWantALPN; want != got {
   308  			t.Errorf("cli.ConnectionState().NegotiatedProtocol = %q, want %q", want, got)
   309  		}
   310  
   311  		srvCS := srv.conn.ConnectionState()
   312  		srvWantALPN := ""
   313  		if _, ok := srv.readSecret[QUICEncryptionLevelHandshake]; ok {
   314  			srvWantALPN = "h3"
   315  		}
   316  		if want, got := srvCS.NegotiatedProtocol, srvWantALPN; want != got {
   317  			t.Errorf("srv.ConnectionState().NegotiatedProtocol = %q, want %q", want, got)
   318  		}
   319  	}
   320  	if err := runTestQUICConnection(context.Background(), cli, srv, onHandleCryptoData); err != nil {
   321  		t.Fatalf("error during connection handshake: %v", err)
   322  	}
   323  }
   324  
   325  func TestQUICStartContextPropagation(t *testing.T) {
   326  	const key = "key"
   327  	const value = "value"
   328  	ctx := context.WithValue(context.Background(), key, value)
   329  	config := testConfig.Clone()
   330  	config.MinVersion = VersionTLS13
   331  	calls := 0
   332  	config.GetConfigForClient = func(info *ClientHelloInfo) (*Config, error) {
   333  		calls++
   334  		got, _ := info.Context().Value(key).(string)
   335  		if got != value {
   336  			t.Errorf("GetConfigForClient context key %q has value %q, want %q", key, got, value)
   337  		}
   338  		return nil, nil
   339  	}
   340  	cli := newTestQUICClient(t, config)
   341  	cli.conn.SetTransportParameters(nil)
   342  	srv := newTestQUICServer(t, config)
   343  	srv.conn.SetTransportParameters(nil)
   344  	if err := runTestQUICConnection(ctx, cli, srv, nil); err != nil {
   345  		t.Fatalf("error during connection handshake: %v", err)
   346  	}
   347  	if calls != 1 {
   348  		t.Errorf("GetConfigForClient called %v times, want 1", calls)
   349  	}
   350  }
   351  
   352  func TestQUICDelayedTransportParameters(t *testing.T) {
   353  	clientConfig := testConfig.Clone()
   354  	clientConfig.MinVersion = VersionTLS13
   355  	clientConfig.ClientSessionCache = NewLRUClientSessionCache(1)
   356  	clientConfig.ServerName = "example.go.dev"
   357  
   358  	serverConfig := testConfig.Clone()
   359  	serverConfig.MinVersion = VersionTLS13
   360  
   361  	cliParams := "client params"
   362  	srvParams := "server params"
   363  
   364  	cli := newTestQUICClient(t, clientConfig)
   365  	srv := newTestQUICServer(t, serverConfig)
   366  	if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != errTransportParametersRequired {
   367  		t.Fatalf("handshake with no client parameters: %v; want errTransportParametersRequired", err)
   368  	}
   369  	cli.conn.SetTransportParameters([]byte(cliParams))
   370  	if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != errTransportParametersRequired {
   371  		t.Fatalf("handshake with no server parameters: %v; want errTransportParametersRequired", err)
   372  	}
   373  	srv.conn.SetTransportParameters([]byte(srvParams))
   374  	if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
   375  		t.Fatalf("error during connection handshake: %v", err)
   376  	}
   377  
   378  	if got, want := string(cli.gotParams), srvParams; got != want {
   379  		t.Errorf("client got transport params: %q, want %q", got, want)
   380  	}
   381  	if got, want := string(srv.gotParams), cliParams; got != want {
   382  		t.Errorf("server got transport params: %q, want %q", got, want)
   383  	}
   384  }
   385  
   386  func TestQUICEmptyTransportParameters(t *testing.T) {
   387  	config := testConfig.Clone()
   388  	config.MinVersion = VersionTLS13
   389  
   390  	cli := newTestQUICClient(t, config)
   391  	cli.conn.SetTransportParameters(nil)
   392  	srv := newTestQUICServer(t, config)
   393  	srv.conn.SetTransportParameters(nil)
   394  	if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
   395  		t.Fatalf("error during connection handshake: %v", err)
   396  	}
   397  
   398  	if cli.gotParams == nil {
   399  		t.Errorf("client did not get transport params")
   400  	}
   401  	if srv.gotParams == nil {
   402  		t.Errorf("server did not get transport params")
   403  	}
   404  	if len(cli.gotParams) != 0 {
   405  		t.Errorf("client got transport params: %v, want empty", cli.gotParams)
   406  	}
   407  	if len(srv.gotParams) != 0 {
   408  		t.Errorf("server got transport params: %v, want empty", srv.gotParams)
   409  	}
   410  }
   411  
   412  func TestQUICCanceledWaitingForData(t *testing.T) {
   413  	config := testConfig.Clone()
   414  	config.MinVersion = VersionTLS13
   415  	cli := newTestQUICClient(t, config)
   416  	cli.conn.SetTransportParameters(nil)
   417  	cli.conn.Start(context.Background())
   418  	for cli.conn.NextEvent().Kind != QUICNoEvent {
   419  	}
   420  	err := cli.conn.Close()
   421  	if !errors.Is(err, alertCloseNotify) {
   422  		t.Errorf("conn.Close() = %v, want alertCloseNotify", err)
   423  	}
   424  }
   425  
   426  func TestQUICCanceledWaitingForTransportParams(t *testing.T) {
   427  	config := testConfig.Clone()
   428  	config.MinVersion = VersionTLS13
   429  	cli := newTestQUICClient(t, config)
   430  	cli.conn.Start(context.Background())
   431  	for cli.conn.NextEvent().Kind != QUICTransportParametersRequired {
   432  	}
   433  	err := cli.conn.Close()
   434  	if !errors.Is(err, alertCloseNotify) {
   435  		t.Errorf("conn.Close() = %v, want alertCloseNotify", err)
   436  	}
   437  }