github.com/lovishpuri/go-40569/src@v0.0.0-20230519171745-f8623e7c56cf/crypto/tls/handshake_messages_test.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"bytes"
     9  	"encoding/hex"
    10  	"math/rand"
    11  	"reflect"
    12  	"strings"
    13  	"testing"
    14  	"testing/quick"
    15  	"time"
    16  )
    17  
    18  var tests = []any{
    19  	&clientHelloMsg{},
    20  	&serverHelloMsg{},
    21  	&finishedMsg{},
    22  
    23  	&certificateMsg{},
    24  	&certificateRequestMsg{},
    25  	&certificateVerifyMsg{
    26  		hasSignatureAlgorithm: true,
    27  	},
    28  	&certificateStatusMsg{},
    29  	&clientKeyExchangeMsg{},
    30  	&newSessionTicketMsg{},
    31  	&sessionState{},
    32  	&sessionStateTLS13{},
    33  	&encryptedExtensionsMsg{},
    34  	&endOfEarlyDataMsg{},
    35  	&keyUpdateMsg{},
    36  	&newSessionTicketMsgTLS13{},
    37  	&certificateRequestMsgTLS13{},
    38  	&certificateMsgTLS13{},
    39  }
    40  
    41  func mustMarshal(t *testing.T, msg handshakeMessage) []byte {
    42  	t.Helper()
    43  	b, err := msg.marshal()
    44  	if err != nil {
    45  		t.Fatal(err)
    46  	}
    47  	return b
    48  }
    49  
    50  func TestMarshalUnmarshal(t *testing.T) {
    51  	rand := rand.New(rand.NewSource(time.Now().UnixNano()))
    52  
    53  	for i, iface := range tests {
    54  		ty := reflect.ValueOf(iface).Type()
    55  
    56  		n := 100
    57  		if testing.Short() {
    58  			n = 5
    59  		}
    60  		for j := 0; j < n; j++ {
    61  			v, ok := quick.Value(ty, rand)
    62  			if !ok {
    63  				t.Errorf("#%d: failed to create value", i)
    64  				break
    65  			}
    66  
    67  			m1 := v.Interface().(handshakeMessage)
    68  			marshaled := mustMarshal(t, m1)
    69  			m2 := iface.(handshakeMessage)
    70  			if !m2.unmarshal(marshaled) {
    71  				t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
    72  				break
    73  			}
    74  			m2.marshal() // to fill any marshal cache in the message
    75  
    76  			if !reflect.DeepEqual(m1, m2) {
    77  				t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled)
    78  				break
    79  			}
    80  
    81  			if i >= 3 {
    82  				// The first three message types (ClientHello,
    83  				// ServerHello and Finished) are allowed to
    84  				// have parsable prefixes because the extension
    85  				// data is optional and the length of the
    86  				// Finished varies across versions.
    87  				for j := 0; j < len(marshaled); j++ {
    88  					if m2.unmarshal(marshaled[0:j]) {
    89  						t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
    90  						break
    91  					}
    92  				}
    93  			}
    94  		}
    95  	}
    96  }
    97  
    98  func TestFuzz(t *testing.T) {
    99  	rand := rand.New(rand.NewSource(0))
   100  	for _, iface := range tests {
   101  		m := iface.(handshakeMessage)
   102  
   103  		for j := 0; j < 1000; j++ {
   104  			len := rand.Intn(100)
   105  			bytes := randomBytes(len, rand)
   106  			// This just looks for crashes due to bounds errors etc.
   107  			m.unmarshal(bytes)
   108  		}
   109  	}
   110  }
   111  
   112  func randomBytes(n int, rand *rand.Rand) []byte {
   113  	r := make([]byte, n)
   114  	if _, err := rand.Read(r); err != nil {
   115  		panic("rand.Read failed: " + err.Error())
   116  	}
   117  	return r
   118  }
   119  
   120  func randomString(n int, rand *rand.Rand) string {
   121  	b := randomBytes(n, rand)
   122  	return string(b)
   123  }
   124  
   125  func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   126  	m := &clientHelloMsg{}
   127  	m.vers = uint16(rand.Intn(65536))
   128  	m.random = randomBytes(32, rand)
   129  	m.sessionId = randomBytes(rand.Intn(32), rand)
   130  	m.cipherSuites = make([]uint16, rand.Intn(63)+1)
   131  	for i := 0; i < len(m.cipherSuites); i++ {
   132  		cs := uint16(rand.Int31())
   133  		if cs == scsvRenegotiation {
   134  			cs += 1
   135  		}
   136  		m.cipherSuites[i] = cs
   137  	}
   138  	m.compressionMethods = randomBytes(rand.Intn(63)+1, rand)
   139  	if rand.Intn(10) > 5 {
   140  		m.serverName = randomString(rand.Intn(255), rand)
   141  		for strings.HasSuffix(m.serverName, ".") {
   142  			m.serverName = m.serverName[:len(m.serverName)-1]
   143  		}
   144  	}
   145  	m.ocspStapling = rand.Intn(10) > 5
   146  	m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
   147  	m.supportedCurves = make([]CurveID, rand.Intn(5)+1)
   148  	for i := range m.supportedCurves {
   149  		m.supportedCurves[i] = CurveID(rand.Intn(30000) + 1)
   150  	}
   151  	if rand.Intn(10) > 5 {
   152  		m.ticketSupported = true
   153  		if rand.Intn(10) > 5 {
   154  			m.sessionTicket = randomBytes(rand.Intn(300), rand)
   155  		} else {
   156  			m.sessionTicket = make([]byte, 0)
   157  		}
   158  	}
   159  	if rand.Intn(10) > 5 {
   160  		m.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
   161  	}
   162  	if rand.Intn(10) > 5 {
   163  		m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms()
   164  	}
   165  	for i := 0; i < rand.Intn(5); i++ {
   166  		m.alpnProtocols = append(m.alpnProtocols, randomString(rand.Intn(20)+1, rand))
   167  	}
   168  	if rand.Intn(10) > 5 {
   169  		m.scts = true
   170  	}
   171  	if rand.Intn(10) > 5 {
   172  		m.secureRenegotiationSupported = true
   173  		m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
   174  	}
   175  	for i := 0; i < rand.Intn(5); i++ {
   176  		m.supportedVersions = append(m.supportedVersions, uint16(rand.Intn(0xffff)+1))
   177  	}
   178  	if rand.Intn(10) > 5 {
   179  		m.cookie = randomBytes(rand.Intn(500)+1, rand)
   180  	}
   181  	for i := 0; i < rand.Intn(5); i++ {
   182  		var ks keyShare
   183  		ks.group = CurveID(rand.Intn(30000) + 1)
   184  		ks.data = randomBytes(rand.Intn(200)+1, rand)
   185  		m.keyShares = append(m.keyShares, ks)
   186  	}
   187  	switch rand.Intn(3) {
   188  	case 1:
   189  		m.pskModes = []uint8{pskModeDHE}
   190  	case 2:
   191  		m.pskModes = []uint8{pskModeDHE, pskModePlain}
   192  	}
   193  	for i := 0; i < rand.Intn(5); i++ {
   194  		var psk pskIdentity
   195  		psk.obfuscatedTicketAge = uint32(rand.Intn(500000))
   196  		psk.label = randomBytes(rand.Intn(500)+1, rand)
   197  		m.pskIdentities = append(m.pskIdentities, psk)
   198  		m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand))
   199  	}
   200  	if rand.Intn(10) > 5 {
   201  		m.earlyData = true
   202  	}
   203  
   204  	return reflect.ValueOf(m)
   205  }
   206  
   207  func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   208  	m := &serverHelloMsg{}
   209  	m.vers = uint16(rand.Intn(65536))
   210  	m.random = randomBytes(32, rand)
   211  	m.sessionId = randomBytes(rand.Intn(32), rand)
   212  	m.cipherSuite = uint16(rand.Int31())
   213  	m.compressionMethod = uint8(rand.Intn(256))
   214  	m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
   215  
   216  	if rand.Intn(10) > 5 {
   217  		m.ocspStapling = true
   218  	}
   219  	if rand.Intn(10) > 5 {
   220  		m.ticketSupported = true
   221  	}
   222  	if rand.Intn(10) > 5 {
   223  		m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
   224  	}
   225  
   226  	for i := 0; i < rand.Intn(4); i++ {
   227  		m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand))
   228  	}
   229  
   230  	if rand.Intn(10) > 5 {
   231  		m.secureRenegotiationSupported = true
   232  		m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
   233  	}
   234  	if rand.Intn(10) > 5 {
   235  		m.supportedVersion = uint16(rand.Intn(0xffff) + 1)
   236  	}
   237  	if rand.Intn(10) > 5 {
   238  		m.cookie = randomBytes(rand.Intn(500)+1, rand)
   239  	}
   240  	if rand.Intn(10) > 5 {
   241  		for i := 0; i < rand.Intn(5); i++ {
   242  			m.serverShare.group = CurveID(rand.Intn(30000) + 1)
   243  			m.serverShare.data = randomBytes(rand.Intn(200)+1, rand)
   244  		}
   245  	} else if rand.Intn(10) > 5 {
   246  		m.selectedGroup = CurveID(rand.Intn(30000) + 1)
   247  	}
   248  	if rand.Intn(10) > 5 {
   249  		m.selectedIdentityPresent = true
   250  		m.selectedIdentity = uint16(rand.Intn(0xffff))
   251  	}
   252  
   253  	return reflect.ValueOf(m)
   254  }
   255  
   256  func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   257  	m := &encryptedExtensionsMsg{}
   258  
   259  	if rand.Intn(10) > 5 {
   260  		m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
   261  	}
   262  
   263  	return reflect.ValueOf(m)
   264  }
   265  
   266  func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   267  	m := &certificateMsg{}
   268  	numCerts := rand.Intn(20)
   269  	m.certificates = make([][]byte, numCerts)
   270  	for i := 0; i < numCerts; i++ {
   271  		m.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
   272  	}
   273  	return reflect.ValueOf(m)
   274  }
   275  
   276  func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   277  	m := &certificateRequestMsg{}
   278  	m.certificateTypes = randomBytes(rand.Intn(5)+1, rand)
   279  	for i := 0; i < rand.Intn(100); i++ {
   280  		m.certificateAuthorities = append(m.certificateAuthorities, randomBytes(rand.Intn(15)+1, rand))
   281  	}
   282  	return reflect.ValueOf(m)
   283  }
   284  
   285  func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   286  	m := &certificateVerifyMsg{}
   287  	m.hasSignatureAlgorithm = true
   288  	m.signatureAlgorithm = SignatureScheme(rand.Intn(30000))
   289  	m.signature = randomBytes(rand.Intn(15)+1, rand)
   290  	return reflect.ValueOf(m)
   291  }
   292  
   293  func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   294  	m := &certificateStatusMsg{}
   295  	m.response = randomBytes(rand.Intn(10)+1, rand)
   296  	return reflect.ValueOf(m)
   297  }
   298  
   299  func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   300  	m := &clientKeyExchangeMsg{}
   301  	m.ciphertext = randomBytes(rand.Intn(1000)+1, rand)
   302  	return reflect.ValueOf(m)
   303  }
   304  
   305  func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   306  	m := &finishedMsg{}
   307  	m.verifyData = randomBytes(12, rand)
   308  	return reflect.ValueOf(m)
   309  }
   310  
   311  func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   312  	m := &newSessionTicketMsg{}
   313  	m.ticket = randomBytes(rand.Intn(4), rand)
   314  	return reflect.ValueOf(m)
   315  }
   316  
   317  func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value {
   318  	s := &sessionState{}
   319  	s.vers = uint16(rand.Intn(10000))
   320  	s.cipherSuite = uint16(rand.Intn(10000))
   321  	s.masterSecret = randomBytes(rand.Intn(100)+1, rand)
   322  	s.createdAt = uint64(rand.Int63())
   323  	for i := 0; i < rand.Intn(20); i++ {
   324  		s.certificates = append(s.certificates, randomBytes(rand.Intn(500)+1, rand))
   325  	}
   326  	return reflect.ValueOf(s)
   327  }
   328  
   329  func (*sessionStateTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
   330  	s := &sessionStateTLS13{}
   331  	s.cipherSuite = uint16(rand.Intn(10000))
   332  	s.resumptionSecret = randomBytes(rand.Intn(100)+1, rand)
   333  	s.createdAt = uint64(rand.Int63())
   334  	for i := 0; i < rand.Intn(2)+1; i++ {
   335  		s.certificate.Certificate = append(
   336  			s.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand))
   337  	}
   338  	if rand.Intn(10) > 5 {
   339  		s.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand)
   340  	}
   341  	if rand.Intn(10) > 5 {
   342  		for i := 0; i < rand.Intn(2)+1; i++ {
   343  			s.certificate.SignedCertificateTimestamps = append(
   344  				s.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand))
   345  		}
   346  	}
   347  	return reflect.ValueOf(s)
   348  }
   349  
   350  func (*endOfEarlyDataMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   351  	m := &endOfEarlyDataMsg{}
   352  	return reflect.ValueOf(m)
   353  }
   354  
   355  func (*keyUpdateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   356  	m := &keyUpdateMsg{}
   357  	m.updateRequested = rand.Intn(10) > 5
   358  	return reflect.ValueOf(m)
   359  }
   360  
   361  func (*newSessionTicketMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
   362  	m := &newSessionTicketMsgTLS13{}
   363  	m.lifetime = uint32(rand.Intn(500000))
   364  	m.ageAdd = uint32(rand.Intn(500000))
   365  	m.nonce = randomBytes(rand.Intn(100), rand)
   366  	m.label = randomBytes(rand.Intn(1000), rand)
   367  	if rand.Intn(10) > 5 {
   368  		m.maxEarlyData = uint32(rand.Intn(500000))
   369  	}
   370  	return reflect.ValueOf(m)
   371  }
   372  
   373  func (*certificateRequestMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
   374  	m := &certificateRequestMsgTLS13{}
   375  	if rand.Intn(10) > 5 {
   376  		m.ocspStapling = true
   377  	}
   378  	if rand.Intn(10) > 5 {
   379  		m.scts = true
   380  	}
   381  	if rand.Intn(10) > 5 {
   382  		m.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
   383  	}
   384  	if rand.Intn(10) > 5 {
   385  		m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms()
   386  	}
   387  	if rand.Intn(10) > 5 {
   388  		m.certificateAuthorities = make([][]byte, 3)
   389  		for i := 0; i < 3; i++ {
   390  			m.certificateAuthorities[i] = randomBytes(rand.Intn(10)+1, rand)
   391  		}
   392  	}
   393  	return reflect.ValueOf(m)
   394  }
   395  
   396  func (*certificateMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
   397  	m := &certificateMsgTLS13{}
   398  	for i := 0; i < rand.Intn(2)+1; i++ {
   399  		m.certificate.Certificate = append(
   400  			m.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand))
   401  	}
   402  	if rand.Intn(10) > 5 {
   403  		m.ocspStapling = true
   404  		m.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand)
   405  	}
   406  	if rand.Intn(10) > 5 {
   407  		m.scts = true
   408  		for i := 0; i < rand.Intn(2)+1; i++ {
   409  			m.certificate.SignedCertificateTimestamps = append(
   410  				m.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand))
   411  		}
   412  	}
   413  	return reflect.ValueOf(m)
   414  }
   415  
   416  func TestRejectEmptySCTList(t *testing.T) {
   417  	// RFC 6962, Section 3.3.1 specifies that empty SCT lists are invalid.
   418  
   419  	var random [32]byte
   420  	sct := []byte{0x42, 0x42, 0x42, 0x42}
   421  	serverHello := &serverHelloMsg{
   422  		vers:   VersionTLS12,
   423  		random: random[:],
   424  		scts:   [][]byte{sct},
   425  	}
   426  	serverHelloBytes := mustMarshal(t, serverHello)
   427  
   428  	var serverHelloCopy serverHelloMsg
   429  	if !serverHelloCopy.unmarshal(serverHelloBytes) {
   430  		t.Fatal("Failed to unmarshal initial message")
   431  	}
   432  
   433  	// Change serverHelloBytes so that the SCT list is empty
   434  	i := bytes.Index(serverHelloBytes, sct)
   435  	if i < 0 {
   436  		t.Fatal("Cannot find SCT in ServerHello")
   437  	}
   438  
   439  	var serverHelloEmptySCT []byte
   440  	serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...)
   441  	// Append the extension length and SCT list length for an empty list.
   442  	serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...)
   443  	serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...)
   444  
   445  	// Update the handshake message length.
   446  	serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16)
   447  	serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8)
   448  	serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4)
   449  
   450  	// Update the extensions length
   451  	serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8)
   452  	serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44))
   453  
   454  	if serverHelloCopy.unmarshal(serverHelloEmptySCT) {
   455  		t.Fatal("Unmarshaled ServerHello with empty SCT list")
   456  	}
   457  }
   458  
   459  func TestRejectEmptySCT(t *testing.T) {
   460  	// Not only must the SCT list be non-empty, but the SCT elements must
   461  	// not be zero length.
   462  
   463  	var random [32]byte
   464  	serverHello := &serverHelloMsg{
   465  		vers:   VersionTLS12,
   466  		random: random[:],
   467  		scts:   [][]byte{nil},
   468  	}
   469  	serverHelloBytes := mustMarshal(t, serverHello)
   470  
   471  	var serverHelloCopy serverHelloMsg
   472  	if serverHelloCopy.unmarshal(serverHelloBytes) {
   473  		t.Fatal("Unmarshaled ServerHello with zero-length SCT")
   474  	}
   475  }
   476  
   477  func TestRejectDuplicateExtensions(t *testing.T) {
   478  	clientHelloBytes, err := hex.DecodeString("010000440303000000000000000000000000000000000000000000000000000000000000000000000000001c0000000a000800000568656c6c6f0000000a000800000568656c6c6f")
   479  	if err != nil {
   480  		t.Fatalf("failed to decode test ClientHello: %s", err)
   481  	}
   482  	var clientHelloCopy clientHelloMsg
   483  	if clientHelloCopy.unmarshal(clientHelloBytes) {
   484  		t.Error("Unmarshaled ClientHello with duplicate extensions")
   485  	}
   486  
   487  	serverHelloBytes, err := hex.DecodeString("02000030030300000000000000000000000000000000000000000000000000000000000000000000000000080005000000050000")
   488  	if err != nil {
   489  		t.Fatalf("failed to decode test ServerHello: %s", err)
   490  	}
   491  	var serverHelloCopy serverHelloMsg
   492  	if serverHelloCopy.unmarshal(serverHelloBytes) {
   493  		t.Fatal("Unmarshaled ServerHello with duplicate extensions")
   494  	}
   495  }