github.com/mtsmfm/go/src@v0.0.0-20221020090648-44bdcb9f8fde/crypto/tls/handshake_messages_test.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"bytes"
     9  	"encoding/hex"
    10  	"math/rand"
    11  	"reflect"
    12  	"strings"
    13  	"testing"
    14  	"testing/quick"
    15  	"time"
    16  )
    17  
    18  var tests = []any{
    19  	&clientHelloMsg{},
    20  	&serverHelloMsg{},
    21  	&finishedMsg{},
    22  
    23  	&certificateMsg{},
    24  	&certificateRequestMsg{},
    25  	&certificateVerifyMsg{
    26  		hasSignatureAlgorithm: true,
    27  	},
    28  	&certificateStatusMsg{},
    29  	&clientKeyExchangeMsg{},
    30  	&newSessionTicketMsg{},
    31  	&sessionState{},
    32  	&sessionStateTLS13{},
    33  	&encryptedExtensionsMsg{},
    34  	&endOfEarlyDataMsg{},
    35  	&keyUpdateMsg{},
    36  	&newSessionTicketMsgTLS13{},
    37  	&certificateRequestMsgTLS13{},
    38  	&certificateMsgTLS13{},
    39  }
    40  
    41  func TestMarshalUnmarshal(t *testing.T) {
    42  	rand := rand.New(rand.NewSource(time.Now().UnixNano()))
    43  
    44  	for i, iface := range tests {
    45  		ty := reflect.ValueOf(iface).Type()
    46  
    47  		n := 100
    48  		if testing.Short() {
    49  			n = 5
    50  		}
    51  		for j := 0; j < n; j++ {
    52  			v, ok := quick.Value(ty, rand)
    53  			if !ok {
    54  				t.Errorf("#%d: failed to create value", i)
    55  				break
    56  			}
    57  
    58  			m1 := v.Interface().(handshakeMessage)
    59  			marshaled := m1.marshal()
    60  			m2 := iface.(handshakeMessage)
    61  			if !m2.unmarshal(marshaled) {
    62  				t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
    63  				break
    64  			}
    65  			m2.marshal() // to fill any marshal cache in the message
    66  
    67  			if !reflect.DeepEqual(m1, m2) {
    68  				t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled)
    69  				break
    70  			}
    71  
    72  			if i >= 3 {
    73  				// The first three message types (ClientHello,
    74  				// ServerHello and Finished) are allowed to
    75  				// have parsable prefixes because the extension
    76  				// data is optional and the length of the
    77  				// Finished varies across versions.
    78  				for j := 0; j < len(marshaled); j++ {
    79  					if m2.unmarshal(marshaled[0:j]) {
    80  						t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
    81  						break
    82  					}
    83  				}
    84  			}
    85  		}
    86  	}
    87  }
    88  
    89  func TestFuzz(t *testing.T) {
    90  	rand := rand.New(rand.NewSource(0))
    91  	for _, iface := range tests {
    92  		m := iface.(handshakeMessage)
    93  
    94  		for j := 0; j < 1000; j++ {
    95  			len := rand.Intn(100)
    96  			bytes := randomBytes(len, rand)
    97  			// This just looks for crashes due to bounds errors etc.
    98  			m.unmarshal(bytes)
    99  		}
   100  	}
   101  }
   102  
   103  func randomBytes(n int, rand *rand.Rand) []byte {
   104  	r := make([]byte, n)
   105  	if _, err := rand.Read(r); err != nil {
   106  		panic("rand.Read failed: " + err.Error())
   107  	}
   108  	return r
   109  }
   110  
   111  func randomString(n int, rand *rand.Rand) string {
   112  	b := randomBytes(n, rand)
   113  	return string(b)
   114  }
   115  
   116  func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   117  	m := &clientHelloMsg{}
   118  	m.vers = uint16(rand.Intn(65536))
   119  	m.random = randomBytes(32, rand)
   120  	m.sessionId = randomBytes(rand.Intn(32), rand)
   121  	m.cipherSuites = make([]uint16, rand.Intn(63)+1)
   122  	for i := 0; i < len(m.cipherSuites); i++ {
   123  		cs := uint16(rand.Int31())
   124  		if cs == scsvRenegotiation {
   125  			cs += 1
   126  		}
   127  		m.cipherSuites[i] = cs
   128  	}
   129  	m.compressionMethods = randomBytes(rand.Intn(63)+1, rand)
   130  	if rand.Intn(10) > 5 {
   131  		m.serverName = randomString(rand.Intn(255), rand)
   132  		for strings.HasSuffix(m.serverName, ".") {
   133  			m.serverName = m.serverName[:len(m.serverName)-1]
   134  		}
   135  	}
   136  	m.ocspStapling = rand.Intn(10) > 5
   137  	m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
   138  	m.supportedCurves = make([]CurveID, rand.Intn(5)+1)
   139  	for i := range m.supportedCurves {
   140  		m.supportedCurves[i] = CurveID(rand.Intn(30000) + 1)
   141  	}
   142  	if rand.Intn(10) > 5 {
   143  		m.ticketSupported = true
   144  		if rand.Intn(10) > 5 {
   145  			m.sessionTicket = randomBytes(rand.Intn(300), rand)
   146  		} else {
   147  			m.sessionTicket = make([]byte, 0)
   148  		}
   149  	}
   150  	if rand.Intn(10) > 5 {
   151  		m.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
   152  	}
   153  	if rand.Intn(10) > 5 {
   154  		m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms()
   155  	}
   156  	for i := 0; i < rand.Intn(5); i++ {
   157  		m.alpnProtocols = append(m.alpnProtocols, randomString(rand.Intn(20)+1, rand))
   158  	}
   159  	if rand.Intn(10) > 5 {
   160  		m.scts = true
   161  	}
   162  	if rand.Intn(10) > 5 {
   163  		m.secureRenegotiationSupported = true
   164  		m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
   165  	}
   166  	for i := 0; i < rand.Intn(5); i++ {
   167  		m.supportedVersions = append(m.supportedVersions, uint16(rand.Intn(0xffff)+1))
   168  	}
   169  	if rand.Intn(10) > 5 {
   170  		m.cookie = randomBytes(rand.Intn(500)+1, rand)
   171  	}
   172  	for i := 0; i < rand.Intn(5); i++ {
   173  		var ks keyShare
   174  		ks.group = CurveID(rand.Intn(30000) + 1)
   175  		ks.data = randomBytes(rand.Intn(200)+1, rand)
   176  		m.keyShares = append(m.keyShares, ks)
   177  	}
   178  	switch rand.Intn(3) {
   179  	case 1:
   180  		m.pskModes = []uint8{pskModeDHE}
   181  	case 2:
   182  		m.pskModes = []uint8{pskModeDHE, pskModePlain}
   183  	}
   184  	for i := 0; i < rand.Intn(5); i++ {
   185  		var psk pskIdentity
   186  		psk.obfuscatedTicketAge = uint32(rand.Intn(500000))
   187  		psk.label = randomBytes(rand.Intn(500)+1, rand)
   188  		m.pskIdentities = append(m.pskIdentities, psk)
   189  		m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand))
   190  	}
   191  	if rand.Intn(10) > 5 {
   192  		m.earlyData = true
   193  	}
   194  
   195  	return reflect.ValueOf(m)
   196  }
   197  
   198  func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   199  	m := &serverHelloMsg{}
   200  	m.vers = uint16(rand.Intn(65536))
   201  	m.random = randomBytes(32, rand)
   202  	m.sessionId = randomBytes(rand.Intn(32), rand)
   203  	m.cipherSuite = uint16(rand.Int31())
   204  	m.compressionMethod = uint8(rand.Intn(256))
   205  	m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
   206  
   207  	if rand.Intn(10) > 5 {
   208  		m.ocspStapling = true
   209  	}
   210  	if rand.Intn(10) > 5 {
   211  		m.ticketSupported = true
   212  	}
   213  	if rand.Intn(10) > 5 {
   214  		m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
   215  	}
   216  
   217  	for i := 0; i < rand.Intn(4); i++ {
   218  		m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand))
   219  	}
   220  
   221  	if rand.Intn(10) > 5 {
   222  		m.secureRenegotiationSupported = true
   223  		m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
   224  	}
   225  	if rand.Intn(10) > 5 {
   226  		m.supportedVersion = uint16(rand.Intn(0xffff) + 1)
   227  	}
   228  	if rand.Intn(10) > 5 {
   229  		m.cookie = randomBytes(rand.Intn(500)+1, rand)
   230  	}
   231  	if rand.Intn(10) > 5 {
   232  		for i := 0; i < rand.Intn(5); i++ {
   233  			m.serverShare.group = CurveID(rand.Intn(30000) + 1)
   234  			m.serverShare.data = randomBytes(rand.Intn(200)+1, rand)
   235  		}
   236  	} else if rand.Intn(10) > 5 {
   237  		m.selectedGroup = CurveID(rand.Intn(30000) + 1)
   238  	}
   239  	if rand.Intn(10) > 5 {
   240  		m.selectedIdentityPresent = true
   241  		m.selectedIdentity = uint16(rand.Intn(0xffff))
   242  	}
   243  
   244  	return reflect.ValueOf(m)
   245  }
   246  
   247  func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   248  	m := &encryptedExtensionsMsg{}
   249  
   250  	if rand.Intn(10) > 5 {
   251  		m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
   252  	}
   253  
   254  	return reflect.ValueOf(m)
   255  }
   256  
   257  func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   258  	m := &certificateMsg{}
   259  	numCerts := rand.Intn(20)
   260  	m.certificates = make([][]byte, numCerts)
   261  	for i := 0; i < numCerts; i++ {
   262  		m.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
   263  	}
   264  	return reflect.ValueOf(m)
   265  }
   266  
   267  func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   268  	m := &certificateRequestMsg{}
   269  	m.certificateTypes = randomBytes(rand.Intn(5)+1, rand)
   270  	for i := 0; i < rand.Intn(100); i++ {
   271  		m.certificateAuthorities = append(m.certificateAuthorities, randomBytes(rand.Intn(15)+1, rand))
   272  	}
   273  	return reflect.ValueOf(m)
   274  }
   275  
   276  func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   277  	m := &certificateVerifyMsg{}
   278  	m.hasSignatureAlgorithm = true
   279  	m.signatureAlgorithm = SignatureScheme(rand.Intn(30000))
   280  	m.signature = randomBytes(rand.Intn(15)+1, rand)
   281  	return reflect.ValueOf(m)
   282  }
   283  
   284  func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   285  	m := &certificateStatusMsg{}
   286  	m.response = randomBytes(rand.Intn(10)+1, rand)
   287  	return reflect.ValueOf(m)
   288  }
   289  
   290  func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   291  	m := &clientKeyExchangeMsg{}
   292  	m.ciphertext = randomBytes(rand.Intn(1000)+1, rand)
   293  	return reflect.ValueOf(m)
   294  }
   295  
   296  func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   297  	m := &finishedMsg{}
   298  	m.verifyData = randomBytes(12, rand)
   299  	return reflect.ValueOf(m)
   300  }
   301  
   302  func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   303  	m := &newSessionTicketMsg{}
   304  	m.ticket = randomBytes(rand.Intn(4), rand)
   305  	return reflect.ValueOf(m)
   306  }
   307  
   308  func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value {
   309  	s := &sessionState{}
   310  	s.vers = uint16(rand.Intn(10000))
   311  	s.cipherSuite = uint16(rand.Intn(10000))
   312  	s.masterSecret = randomBytes(rand.Intn(100)+1, rand)
   313  	s.createdAt = uint64(rand.Int63())
   314  	for i := 0; i < rand.Intn(20); i++ {
   315  		s.certificates = append(s.certificates, randomBytes(rand.Intn(500)+1, rand))
   316  	}
   317  	return reflect.ValueOf(s)
   318  }
   319  
   320  func (*sessionStateTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
   321  	s := &sessionStateTLS13{}
   322  	s.cipherSuite = uint16(rand.Intn(10000))
   323  	s.resumptionSecret = randomBytes(rand.Intn(100)+1, rand)
   324  	s.createdAt = uint64(rand.Int63())
   325  	for i := 0; i < rand.Intn(2)+1; i++ {
   326  		s.certificate.Certificate = append(
   327  			s.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand))
   328  	}
   329  	if rand.Intn(10) > 5 {
   330  		s.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand)
   331  	}
   332  	if rand.Intn(10) > 5 {
   333  		for i := 0; i < rand.Intn(2)+1; i++ {
   334  			s.certificate.SignedCertificateTimestamps = append(
   335  				s.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand))
   336  		}
   337  	}
   338  	return reflect.ValueOf(s)
   339  }
   340  
   341  func (*endOfEarlyDataMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   342  	m := &endOfEarlyDataMsg{}
   343  	return reflect.ValueOf(m)
   344  }
   345  
   346  func (*keyUpdateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   347  	m := &keyUpdateMsg{}
   348  	m.updateRequested = rand.Intn(10) > 5
   349  	return reflect.ValueOf(m)
   350  }
   351  
   352  func (*newSessionTicketMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
   353  	m := &newSessionTicketMsgTLS13{}
   354  	m.lifetime = uint32(rand.Intn(500000))
   355  	m.ageAdd = uint32(rand.Intn(500000))
   356  	m.nonce = randomBytes(rand.Intn(100), rand)
   357  	m.label = randomBytes(rand.Intn(1000), rand)
   358  	if rand.Intn(10) > 5 {
   359  		m.maxEarlyData = uint32(rand.Intn(500000))
   360  	}
   361  	return reflect.ValueOf(m)
   362  }
   363  
   364  func (*certificateRequestMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
   365  	m := &certificateRequestMsgTLS13{}
   366  	if rand.Intn(10) > 5 {
   367  		m.ocspStapling = true
   368  	}
   369  	if rand.Intn(10) > 5 {
   370  		m.scts = true
   371  	}
   372  	if rand.Intn(10) > 5 {
   373  		m.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
   374  	}
   375  	if rand.Intn(10) > 5 {
   376  		m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms()
   377  	}
   378  	if rand.Intn(10) > 5 {
   379  		m.certificateAuthorities = make([][]byte, 3)
   380  		for i := 0; i < 3; i++ {
   381  			m.certificateAuthorities[i] = randomBytes(rand.Intn(10)+1, rand)
   382  		}
   383  	}
   384  	return reflect.ValueOf(m)
   385  }
   386  
   387  func (*certificateMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
   388  	m := &certificateMsgTLS13{}
   389  	for i := 0; i < rand.Intn(2)+1; i++ {
   390  		m.certificate.Certificate = append(
   391  			m.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand))
   392  	}
   393  	if rand.Intn(10) > 5 {
   394  		m.ocspStapling = true
   395  		m.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand)
   396  	}
   397  	if rand.Intn(10) > 5 {
   398  		m.scts = true
   399  		for i := 0; i < rand.Intn(2)+1; i++ {
   400  			m.certificate.SignedCertificateTimestamps = append(
   401  				m.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand))
   402  		}
   403  	}
   404  	return reflect.ValueOf(m)
   405  }
   406  
   407  func TestRejectEmptySCTList(t *testing.T) {
   408  	// RFC 6962, Section 3.3.1 specifies that empty SCT lists are invalid.
   409  
   410  	var random [32]byte
   411  	sct := []byte{0x42, 0x42, 0x42, 0x42}
   412  	serverHello := serverHelloMsg{
   413  		vers:   VersionTLS12,
   414  		random: random[:],
   415  		scts:   [][]byte{sct},
   416  	}
   417  	serverHelloBytes := serverHello.marshal()
   418  
   419  	var serverHelloCopy serverHelloMsg
   420  	if !serverHelloCopy.unmarshal(serverHelloBytes) {
   421  		t.Fatal("Failed to unmarshal initial message")
   422  	}
   423  
   424  	// Change serverHelloBytes so that the SCT list is empty
   425  	i := bytes.Index(serverHelloBytes, sct)
   426  	if i < 0 {
   427  		t.Fatal("Cannot find SCT in ServerHello")
   428  	}
   429  
   430  	var serverHelloEmptySCT []byte
   431  	serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...)
   432  	// Append the extension length and SCT list length for an empty list.
   433  	serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...)
   434  	serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...)
   435  
   436  	// Update the handshake message length.
   437  	serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16)
   438  	serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8)
   439  	serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4)
   440  
   441  	// Update the extensions length
   442  	serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8)
   443  	serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44))
   444  
   445  	if serverHelloCopy.unmarshal(serverHelloEmptySCT) {
   446  		t.Fatal("Unmarshaled ServerHello with empty SCT list")
   447  	}
   448  }
   449  
   450  func TestRejectEmptySCT(t *testing.T) {
   451  	// Not only must the SCT list be non-empty, but the SCT elements must
   452  	// not be zero length.
   453  
   454  	var random [32]byte
   455  	serverHello := serverHelloMsg{
   456  		vers:   VersionTLS12,
   457  		random: random[:],
   458  		scts:   [][]byte{nil},
   459  	}
   460  	serverHelloBytes := serverHello.marshal()
   461  
   462  	var serverHelloCopy serverHelloMsg
   463  	if serverHelloCopy.unmarshal(serverHelloBytes) {
   464  		t.Fatal("Unmarshaled ServerHello with zero-length SCT")
   465  	}
   466  }
   467  
   468  func TestRejectDuplicateExtensions(t *testing.T) {
   469  	clientHelloBytes, err := hex.DecodeString("010000440303000000000000000000000000000000000000000000000000000000000000000000000000001c0000000a000800000568656c6c6f0000000a000800000568656c6c6f")
   470  	if err != nil {
   471  		t.Fatalf("failed to decode test ClientHello: %s", err)
   472  	}
   473  	var clientHelloCopy clientHelloMsg
   474  	if clientHelloCopy.unmarshal(clientHelloBytes) {
   475  		t.Error("Unmarshaled ClientHello with duplicate extensions")
   476  	}
   477  
   478  	serverHelloBytes, err := hex.DecodeString("02000030030300000000000000000000000000000000000000000000000000000000000000000000000000080005000000050000")
   479  	if err != nil {
   480  		t.Fatalf("failed to decode test ServerHello: %s", err)
   481  	}
   482  	var serverHelloCopy serverHelloMsg
   483  	if serverHelloCopy.unmarshal(serverHelloBytes) {
   484  		t.Fatal("Unmarshaled ServerHello with duplicate extensions")
   485  	}
   486  }