github.com/JimmyHuang454/JLS-go@v0.0.0-20230831150107-90d536585ba0/tls/handshake_messages_test.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/x509"
    10  	"encoding/hex"
    11  	"math"
    12  	"math/rand"
    13  	"reflect"
    14  	"strings"
    15  	"testing"
    16  	"testing/quick"
    17  	"time"
    18  )
    19  
    20  var tests = []handshakeMessage{
    21  	&clientHelloMsg{},
    22  	&serverHelloMsg{},
    23  	&finishedMsg{},
    24  
    25  	&certificateMsg{},
    26  	&certificateRequestMsg{},
    27  	&certificateVerifyMsg{
    28  		hasSignatureAlgorithm: true,
    29  	},
    30  	&certificateStatusMsg{},
    31  	&clientKeyExchangeMsg{},
    32  	&newSessionTicketMsg{},
    33  	&encryptedExtensionsMsg{},
    34  	&endOfEarlyDataMsg{},
    35  	&keyUpdateMsg{},
    36  	&newSessionTicketMsgTLS13{},
    37  	&certificateRequestMsgTLS13{},
    38  	&certificateMsgTLS13{},
    39  	&SessionState{},
    40  }
    41  
    42  func mustMarshal(t *testing.T, msg handshakeMessage) []byte {
    43  	t.Helper()
    44  	b, err := msg.marshal()
    45  	if err != nil {
    46  		t.Fatal(err)
    47  	}
    48  	return b
    49  }
    50  
    51  func TestMarshalUnmarshal(t *testing.T) {
    52  	rand := rand.New(rand.NewSource(time.Now().UnixNano()))
    53  
    54  	for i, m := range tests {
    55  		ty := reflect.ValueOf(m).Type()
    56  
    57  		n := 100
    58  		if testing.Short() {
    59  			n = 5
    60  		}
    61  		for j := 0; j < n; j++ {
    62  			v, ok := quick.Value(ty, rand)
    63  			if !ok {
    64  				t.Errorf("#%d: failed to create value", i)
    65  				break
    66  			}
    67  
    68  			m1 := v.Interface().(handshakeMessage)
    69  			marshaled := mustMarshal(t, m1)
    70  			if !m.unmarshal(marshaled) {
    71  				t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
    72  				break
    73  			}
    74  			m.marshal() // to fill any marshal cache in the message
    75  
    76  			if m, ok := m.(*SessionState); ok {
    77  				m.activeCertHandles = nil
    78  			}
    79  
    80  			if !reflect.DeepEqual(m1, m) {
    81  				t.Errorf("#%d got:%#v want:%#v %x", i, m, m1, marshaled)
    82  				break
    83  			}
    84  
    85  			if i >= 3 {
    86  				// The first three message types (ClientHello,
    87  				// ServerHello and Finished) are allowed to
    88  				// have parsable prefixes because the extension
    89  				// data is optional and the length of the
    90  				// Finished varies across versions.
    91  				for j := 0; j < len(marshaled); j++ {
    92  					if m.unmarshal(marshaled[0:j]) {
    93  						t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
    94  						break
    95  					}
    96  				}
    97  			}
    98  		}
    99  	}
   100  }
   101  
   102  func TestFuzz(t *testing.T) {
   103  	rand := rand.New(rand.NewSource(0))
   104  	for _, m := range tests {
   105  		for j := 0; j < 1000; j++ {
   106  			len := rand.Intn(1000)
   107  			bytes := randomBytes(len, rand)
   108  			// This just looks for crashes due to bounds errors etc.
   109  			m.unmarshal(bytes)
   110  		}
   111  	}
   112  }
   113  
   114  func randomBytes(n int, rand *rand.Rand) []byte {
   115  	r := make([]byte, n)
   116  	if _, err := rand.Read(r); err != nil {
   117  		panic("rand.Read failed: " + err.Error())
   118  	}
   119  	return r
   120  }
   121  
   122  func randomString(n int, rand *rand.Rand) string {
   123  	b := randomBytes(n, rand)
   124  	return string(b)
   125  }
   126  
   127  func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   128  	m := &clientHelloMsg{}
   129  	m.vers = uint16(rand.Intn(65536))
   130  	m.random = randomBytes(32, rand)
   131  	m.sessionId = randomBytes(rand.Intn(32), rand)
   132  	m.cipherSuites = make([]uint16, rand.Intn(63)+1)
   133  	for i := 0; i < len(m.cipherSuites); i++ {
   134  		cs := uint16(rand.Int31())
   135  		if cs == scsvRenegotiation {
   136  			cs += 1
   137  		}
   138  		m.cipherSuites[i] = cs
   139  	}
   140  	m.compressionMethods = randomBytes(rand.Intn(63)+1, rand)
   141  	if rand.Intn(10) > 5 {
   142  		m.serverName = randomString(rand.Intn(255), rand)
   143  		for strings.HasSuffix(m.serverName, ".") {
   144  			m.serverName = m.serverName[:len(m.serverName)-1]
   145  		}
   146  	}
   147  	m.ocspStapling = rand.Intn(10) > 5
   148  	m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
   149  	m.supportedCurves = make([]CurveID, rand.Intn(5)+1)
   150  	for i := range m.supportedCurves {
   151  		m.supportedCurves[i] = CurveID(rand.Intn(30000) + 1)
   152  	}
   153  	if rand.Intn(10) > 5 {
   154  		m.ticketSupported = true
   155  		if rand.Intn(10) > 5 {
   156  			m.sessionTicket = randomBytes(rand.Intn(300), rand)
   157  		} else {
   158  			m.sessionTicket = make([]byte, 0)
   159  		}
   160  	}
   161  	if rand.Intn(10) > 5 {
   162  		m.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
   163  	}
   164  	if rand.Intn(10) > 5 {
   165  		m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms()
   166  	}
   167  	for i := 0; i < rand.Intn(5); i++ {
   168  		m.alpnProtocols = append(m.alpnProtocols, randomString(rand.Intn(20)+1, rand))
   169  	}
   170  	if rand.Intn(10) > 5 {
   171  		m.scts = true
   172  	}
   173  	if rand.Intn(10) > 5 {
   174  		m.secureRenegotiationSupported = true
   175  		m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
   176  	}
   177  	if rand.Intn(10) > 5 {
   178  		m.extendedMasterSecret = true
   179  	}
   180  	for i := 0; i < rand.Intn(5); i++ {
   181  		m.supportedVersions = append(m.supportedVersions, uint16(rand.Intn(0xffff)+1))
   182  	}
   183  	if rand.Intn(10) > 5 {
   184  		m.cookie = randomBytes(rand.Intn(500)+1, rand)
   185  	}
   186  	for i := 0; i < rand.Intn(5); i++ {
   187  		var ks keyShare
   188  		ks.group = CurveID(rand.Intn(30000) + 1)
   189  		ks.data = randomBytes(rand.Intn(200)+1, rand)
   190  		m.keyShares = append(m.keyShares, ks)
   191  	}
   192  	switch rand.Intn(3) {
   193  	case 1:
   194  		m.pskModes = []uint8{pskModeDHE}
   195  	case 2:
   196  		m.pskModes = []uint8{pskModeDHE, pskModePlain}
   197  	}
   198  	for i := 0; i < rand.Intn(5); i++ {
   199  		var psk pskIdentity
   200  		psk.obfuscatedTicketAge = uint32(rand.Intn(500000))
   201  		psk.label = randomBytes(rand.Intn(500)+1, rand)
   202  		m.pskIdentities = append(m.pskIdentities, psk)
   203  		m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand))
   204  	}
   205  	if rand.Intn(10) > 5 {
   206  		m.quicTransportParameters = randomBytes(rand.Intn(500), rand)
   207  	}
   208  	if rand.Intn(10) > 5 {
   209  		m.earlyData = true
   210  	}
   211  
   212  	return reflect.ValueOf(m)
   213  }
   214  
   215  func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   216  	m := &serverHelloMsg{}
   217  	m.vers = uint16(rand.Intn(65536))
   218  	m.random = randomBytes(32, rand)
   219  	m.sessionId = randomBytes(rand.Intn(32), rand)
   220  	m.cipherSuite = uint16(rand.Int31())
   221  	m.compressionMethod = uint8(rand.Intn(256))
   222  	m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
   223  
   224  	if rand.Intn(10) > 5 {
   225  		m.ocspStapling = true
   226  	}
   227  	if rand.Intn(10) > 5 {
   228  		m.ticketSupported = true
   229  	}
   230  	if rand.Intn(10) > 5 {
   231  		m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
   232  	}
   233  
   234  	for i := 0; i < rand.Intn(4); i++ {
   235  		m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand))
   236  	}
   237  
   238  	if rand.Intn(10) > 5 {
   239  		m.secureRenegotiationSupported = true
   240  		m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
   241  	}
   242  	if rand.Intn(10) > 5 {
   243  		m.extendedMasterSecret = true
   244  	}
   245  	if rand.Intn(10) > 5 {
   246  		m.supportedVersion = uint16(rand.Intn(0xffff) + 1)
   247  	}
   248  	if rand.Intn(10) > 5 {
   249  		m.cookie = randomBytes(rand.Intn(500)+1, rand)
   250  	}
   251  	if rand.Intn(10) > 5 {
   252  		for i := 0; i < rand.Intn(5); i++ {
   253  			m.serverShare.group = CurveID(rand.Intn(30000) + 1)
   254  			m.serverShare.data = randomBytes(rand.Intn(200)+1, rand)
   255  		}
   256  	} else if rand.Intn(10) > 5 {
   257  		m.selectedGroup = CurveID(rand.Intn(30000) + 1)
   258  	}
   259  	if rand.Intn(10) > 5 {
   260  		m.selectedIdentityPresent = true
   261  		m.selectedIdentity = uint16(rand.Intn(0xffff))
   262  	}
   263  
   264  	return reflect.ValueOf(m)
   265  }
   266  
   267  func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   268  	m := &encryptedExtensionsMsg{}
   269  
   270  	if rand.Intn(10) > 5 {
   271  		m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
   272  	}
   273  	if rand.Intn(10) > 5 {
   274  		m.earlyData = true
   275  	}
   276  
   277  	return reflect.ValueOf(m)
   278  }
   279  
   280  func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   281  	m := &certificateMsg{}
   282  	numCerts := rand.Intn(20)
   283  	m.certificates = make([][]byte, numCerts)
   284  	for i := 0; i < numCerts; i++ {
   285  		m.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
   286  	}
   287  	return reflect.ValueOf(m)
   288  }
   289  
   290  func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   291  	m := &certificateRequestMsg{}
   292  	m.certificateTypes = randomBytes(rand.Intn(5)+1, rand)
   293  	for i := 0; i < rand.Intn(100); i++ {
   294  		m.certificateAuthorities = append(m.certificateAuthorities, randomBytes(rand.Intn(15)+1, rand))
   295  	}
   296  	return reflect.ValueOf(m)
   297  }
   298  
   299  func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   300  	m := &certificateVerifyMsg{}
   301  	m.hasSignatureAlgorithm = true
   302  	m.signatureAlgorithm = SignatureScheme(rand.Intn(30000))
   303  	m.signature = randomBytes(rand.Intn(15)+1, rand)
   304  	return reflect.ValueOf(m)
   305  }
   306  
   307  func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   308  	m := &certificateStatusMsg{}
   309  	m.response = randomBytes(rand.Intn(10)+1, rand)
   310  	return reflect.ValueOf(m)
   311  }
   312  
   313  func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   314  	m := &clientKeyExchangeMsg{}
   315  	m.ciphertext = randomBytes(rand.Intn(1000)+1, rand)
   316  	return reflect.ValueOf(m)
   317  }
   318  
   319  func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   320  	m := &finishedMsg{}
   321  	m.verifyData = randomBytes(12, rand)
   322  	return reflect.ValueOf(m)
   323  }
   324  
   325  func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   326  	m := &newSessionTicketMsg{}
   327  	m.ticket = randomBytes(rand.Intn(4), rand)
   328  	return reflect.ValueOf(m)
   329  }
   330  
   331  var sessionTestCerts []*x509.Certificate
   332  
   333  func init() {
   334  	cert, err := x509.ParseCertificate(testRSACertificate)
   335  	if err != nil {
   336  		panic(err)
   337  	}
   338  	sessionTestCerts = append(sessionTestCerts, cert)
   339  	cert, err = x509.ParseCertificate(testRSACertificateIssuer)
   340  	if err != nil {
   341  		panic(err)
   342  	}
   343  	sessionTestCerts = append(sessionTestCerts, cert)
   344  }
   345  
   346  func (*SessionState) Generate(rand *rand.Rand, size int) reflect.Value {
   347  	s := &SessionState{}
   348  	isTLS13 := rand.Intn(10) > 5
   349  	if isTLS13 {
   350  		s.version = VersionTLS13
   351  	} else {
   352  		s.version = uint16(rand.Intn(VersionTLS13))
   353  	}
   354  	s.isClient = rand.Intn(10) > 5
   355  	s.cipherSuite = uint16(rand.Intn(math.MaxUint16))
   356  	s.createdAt = uint64(rand.Int63())
   357  	s.secret = randomBytes(rand.Intn(100)+1, rand)
   358  	for n, i := rand.Intn(3), 0; i < n; i++ {
   359  		s.Extra = append(s.Extra, randomBytes(rand.Intn(100), rand))
   360  	}
   361  	if rand.Intn(10) > 5 {
   362  		s.EarlyData = true
   363  	}
   364  	if rand.Intn(10) > 5 {
   365  		s.extMasterSecret = true
   366  	}
   367  	if s.isClient || rand.Intn(10) > 5 {
   368  		if rand.Intn(10) > 5 {
   369  			s.peerCertificates = sessionTestCerts
   370  		} else {
   371  			s.peerCertificates = sessionTestCerts[:1]
   372  		}
   373  	}
   374  	if rand.Intn(10) > 5 && s.peerCertificates != nil {
   375  		s.ocspResponse = randomBytes(rand.Intn(100)+1, rand)
   376  	}
   377  	if rand.Intn(10) > 5 && s.peerCertificates != nil {
   378  		for i := 0; i < rand.Intn(2)+1; i++ {
   379  			s.scts = append(s.scts, randomBytes(rand.Intn(500)+1, rand))
   380  		}
   381  	}
   382  	if len(s.peerCertificates) > 0 {
   383  		for i := 0; i < rand.Intn(3); i++ {
   384  			if rand.Intn(10) > 5 {
   385  				s.verifiedChains = append(s.verifiedChains, s.peerCertificates)
   386  			} else {
   387  				s.verifiedChains = append(s.verifiedChains, s.peerCertificates[:1])
   388  			}
   389  		}
   390  	}
   391  	if rand.Intn(10) > 5 && s.EarlyData {
   392  		s.alpnProtocol = string(randomBytes(rand.Intn(10), rand))
   393  	}
   394  	if s.isClient {
   395  		if isTLS13 {
   396  			s.useBy = uint64(rand.Int63())
   397  			s.ageAdd = uint32(rand.Int63() & math.MaxUint32)
   398  		}
   399  	}
   400  	return reflect.ValueOf(s)
   401  }
   402  
   403  func (s *SessionState) marshal() ([]byte, error) { return s.Bytes() }
   404  func (s *SessionState) unmarshal(b []byte) bool {
   405  	ss, err := ParseSessionState(b)
   406  	if err != nil {
   407  		return false
   408  	}
   409  	*s = *ss
   410  	return true
   411  }
   412  
   413  func (*endOfEarlyDataMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   414  	m := &endOfEarlyDataMsg{}
   415  	return reflect.ValueOf(m)
   416  }
   417  
   418  func (*keyUpdateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   419  	m := &keyUpdateMsg{}
   420  	m.updateRequested = rand.Intn(10) > 5
   421  	return reflect.ValueOf(m)
   422  }
   423  
   424  func (*newSessionTicketMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
   425  	m := &newSessionTicketMsgTLS13{}
   426  	m.lifetime = uint32(rand.Intn(500000))
   427  	m.ageAdd = uint32(rand.Intn(500000))
   428  	m.nonce = randomBytes(rand.Intn(100), rand)
   429  	m.label = randomBytes(rand.Intn(1000), rand)
   430  	if rand.Intn(10) > 5 {
   431  		m.maxEarlyData = uint32(rand.Intn(500000))
   432  	}
   433  	return reflect.ValueOf(m)
   434  }
   435  
   436  func (*certificateRequestMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
   437  	m := &certificateRequestMsgTLS13{}
   438  	if rand.Intn(10) > 5 {
   439  		m.ocspStapling = true
   440  	}
   441  	if rand.Intn(10) > 5 {
   442  		m.scts = true
   443  	}
   444  	if rand.Intn(10) > 5 {
   445  		m.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
   446  	}
   447  	if rand.Intn(10) > 5 {
   448  		m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms()
   449  	}
   450  	if rand.Intn(10) > 5 {
   451  		m.certificateAuthorities = make([][]byte, 3)
   452  		for i := 0; i < 3; i++ {
   453  			m.certificateAuthorities[i] = randomBytes(rand.Intn(10)+1, rand)
   454  		}
   455  	}
   456  	return reflect.ValueOf(m)
   457  }
   458  
   459  func (*certificateMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
   460  	m := &certificateMsgTLS13{}
   461  	for i := 0; i < rand.Intn(2)+1; i++ {
   462  		m.certificate.Certificate = append(
   463  			m.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand))
   464  	}
   465  	if rand.Intn(10) > 5 {
   466  		m.ocspStapling = true
   467  		m.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand)
   468  	}
   469  	if rand.Intn(10) > 5 {
   470  		m.scts = true
   471  		for i := 0; i < rand.Intn(2)+1; i++ {
   472  			m.certificate.SignedCertificateTimestamps = append(
   473  				m.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand))
   474  		}
   475  	}
   476  	return reflect.ValueOf(m)
   477  }
   478  
   479  func TestRejectEmptySCTList(t *testing.T) {
   480  	// RFC 6962, Section 3.3.1 specifies that empty SCT lists are invalid.
   481  
   482  	var random [32]byte
   483  	sct := []byte{0x42, 0x42, 0x42, 0x42}
   484  	serverHello := &serverHelloMsg{
   485  		vers:   VersionTLS12,
   486  		random: random[:],
   487  		scts:   [][]byte{sct},
   488  	}
   489  	serverHelloBytes := mustMarshal(t, serverHello)
   490  
   491  	var serverHelloCopy serverHelloMsg
   492  	if !serverHelloCopy.unmarshal(serverHelloBytes) {
   493  		t.Fatal("Failed to unmarshal initial message")
   494  	}
   495  
   496  	// Change serverHelloBytes so that the SCT list is empty
   497  	i := bytes.Index(serverHelloBytes, sct)
   498  	if i < 0 {
   499  		t.Fatal("Cannot find SCT in ServerHello")
   500  	}
   501  
   502  	var serverHelloEmptySCT []byte
   503  	serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...)
   504  	// Append the extension length and SCT list length for an empty list.
   505  	serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...)
   506  	serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...)
   507  
   508  	// Update the handshake message length.
   509  	serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16)
   510  	serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8)
   511  	serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4)
   512  
   513  	// Update the extensions length
   514  	serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8)
   515  	serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44))
   516  
   517  	if serverHelloCopy.unmarshal(serverHelloEmptySCT) {
   518  		t.Fatal("Unmarshaled ServerHello with empty SCT list")
   519  	}
   520  }
   521  
   522  func TestRejectEmptySCT(t *testing.T) {
   523  	// Not only must the SCT list be non-empty, but the SCT elements must
   524  	// not be zero length.
   525  
   526  	var random [32]byte
   527  	serverHello := &serverHelloMsg{
   528  		vers:   VersionTLS12,
   529  		random: random[:],
   530  		scts:   [][]byte{nil},
   531  	}
   532  	serverHelloBytes := mustMarshal(t, serverHello)
   533  
   534  	var serverHelloCopy serverHelloMsg
   535  	if serverHelloCopy.unmarshal(serverHelloBytes) {
   536  		t.Fatal("Unmarshaled ServerHello with zero-length SCT")
   537  	}
   538  }
   539  
   540  func TestRejectDuplicateExtensions(t *testing.T) {
   541  	clientHelloBytes, err := hex.DecodeString("010000440303000000000000000000000000000000000000000000000000000000000000000000000000001c0000000a000800000568656c6c6f0000000a000800000568656c6c6f")
   542  	if err != nil {
   543  		t.Fatalf("failed to decode test ClientHello: %s", err)
   544  	}
   545  	var clientHelloCopy clientHelloMsg
   546  	if clientHelloCopy.unmarshal(clientHelloBytes) {
   547  		t.Error("Unmarshaled ClientHello with duplicate extensions")
   548  	}
   549  
   550  	serverHelloBytes, err := hex.DecodeString("02000030030300000000000000000000000000000000000000000000000000000000000000000000000000080005000000050000")
   551  	if err != nil {
   552  		t.Fatalf("failed to decode test ServerHello: %s", err)
   553  	}
   554  	var serverHelloCopy serverHelloMsg
   555  	if serverHelloCopy.unmarshal(serverHelloBytes) {
   556  		t.Fatal("Unmarshaled ServerHello with duplicate extensions")
   557  	}
   558  }