github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/gmtls/handshake_messages_test.go (about)

     1  // Copyright 2022 s1ren@github.com/hxx258456.
     2  
     3  /*
     4  gmtls是基于`golang/go`的`tls`包实现的国密改造版本。
     5  对应版权声明: thrid_licenses/github.com/golang/go/LICENSE
     6  */
     7  
     8  package gmtls
     9  
import (
	"bytes"
	"math/rand"
	"reflect"
	"strings"
	"testing"
	"testing/quick"
	"time"
)
    19  
    20  var tests = []interface{}{
    21  	&clientHelloMsg{},
    22  	&serverHelloMsg{},
    23  	&finishedMsg{},
    24  
    25  	&certificateMsg{},
    26  	&certificateRequestMsg{},
    27  	&certificateVerifyMsg{
    28  		hasSignatureAlgorithm: true,
    29  	},
    30  	&certificateStatusMsg{},
    31  	&clientKeyExchangeMsg{},
    32  	&newSessionTicketMsg{},
    33  	&sessionState{},
    34  	&sessionStateTLS13{},
    35  	&encryptedExtensionsMsg{},
    36  	&endOfEarlyDataMsg{},
    37  	&keyUpdateMsg{},
    38  	&newSessionTicketMsgTLS13{},
    39  	&certificateRequestMsgTLS13{},
    40  	&certificateMsgTLS13{},
    41  }
    42  
    43  func TestMarshalUnmarshal(t *testing.T) {
    44  	randNew := rand.New(rand.NewSource(time.Now().UnixNano()))
    45  
    46  	for i, iface := range tests {
    47  		ty := reflect.ValueOf(iface).Type()
    48  
    49  		n := 100
    50  		if testing.Short() {
    51  			n = 5
    52  		}
    53  		for j := 0; j < n; j++ {
    54  			v, ok := quick.Value(ty, randNew)
    55  			if !ok {
    56  				t.Errorf("#%d: failed to create value", i)
    57  				break
    58  			}
    59  
    60  			m1 := v.Interface().(handshakeMessage)
    61  			marshaled := m1.marshal()
    62  			m2 := iface.(handshakeMessage)
    63  			if !m2.unmarshal(marshaled) {
    64  				t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
    65  				break
    66  			}
    67  			m2.marshal() // to fill any marshal cache in the message
    68  
    69  			if !reflect.DeepEqual(m1, m2) {
    70  				t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled)
    71  				break
    72  			}
    73  
    74  			if i >= 3 {
    75  				// The first three message types (ClientHello,
    76  				// ServerHello and Finished) are allowed to
    77  				// have parsable prefixes because the extension
    78  				// data is optional and the length of the
    79  				// Finished varies across versions.
    80  				for j := 0; j < len(marshaled); j++ {
    81  					if m2.unmarshal(marshaled[0:j]) {
    82  						t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
    83  						break
    84  					}
    85  				}
    86  			}
    87  		}
    88  	}
    89  }
    90  
    91  func TestFuzz(t *testing.T) {
    92  	randNew := rand.New(rand.NewSource(0))
    93  	for _, iface := range tests {
    94  		m := iface.(handshakeMessage)
    95  
    96  		for j := 0; j < 1000; j++ {
    97  			length := randNew.Intn(100)
    98  			ranBytes := randomBytes(length, randNew)
    99  			// This just looks for crashes due to bounds errors etc.
   100  			m.unmarshal(ranBytes)
   101  		}
   102  	}
   103  }
   104  func randomBytes(n int, rand *rand.Rand) []byte {
   105  	r := make([]byte, n)
   106  	if _, err := rand.Read(r); err != nil {
   107  		panic("rand.Read failed: " + err.Error())
   108  	}
   109  	return r
   110  }
   111  
   112  func randomString(n int, rand *rand.Rand) string {
   113  	b := randomBytes(n, rand)
   114  	return string(b)
   115  }
   116  
   117  //goland:noinspection GoUnusedParameter
   118  func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   119  	m := &clientHelloMsg{}
   120  	m.vers = uint16(rand.Intn(65536))
   121  	m.random = randomBytes(32, rand)
   122  	m.sessionId = randomBytes(rand.Intn(32), rand)
   123  	m.cipherSuites = make([]uint16, rand.Intn(63)+1)
   124  	for i := 0; i < len(m.cipherSuites); i++ {
   125  		cs := uint16(rand.Int31())
   126  		if cs == scsvRenegotiation {
   127  			cs += 1
   128  		}
   129  		m.cipherSuites[i] = cs
   130  	}
   131  	m.compressionMethods = randomBytes(rand.Intn(63)+1, rand)
   132  	if rand.Intn(10) > 5 {
   133  		m.serverName = randomString(rand.Intn(255), rand)
   134  		for strings.HasSuffix(m.serverName, ".") {
   135  			m.serverName = m.serverName[:len(m.serverName)-1]
   136  		}
   137  	}
   138  	m.ocspStapling = rand.Intn(10) > 5
   139  	m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
   140  	m.supportedCurves = make([]CurveID, rand.Intn(5)+1)
   141  	for i := range m.supportedCurves {
   142  		m.supportedCurves[i] = CurveID(rand.Intn(30000) + 1)
   143  	}
   144  	if rand.Intn(10) > 5 {
   145  		m.ticketSupported = true
   146  		if rand.Intn(10) > 5 {
   147  			m.sessionTicket = randomBytes(rand.Intn(300), rand)
   148  		} else {
   149  			m.sessionTicket = make([]byte, 0)
   150  		}
   151  	}
   152  	if rand.Intn(10) > 5 {
   153  		m.supportedSignatureAlgorithms = supportedSignatureAlgorithms
   154  	}
   155  	if rand.Intn(10) > 5 {
   156  		m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms
   157  	}
   158  	for i := 0; i < rand.Intn(5); i++ {
   159  		m.alpnProtocols = append(m.alpnProtocols, randomString(rand.Intn(20)+1, rand))
   160  	}
   161  	if rand.Intn(10) > 5 {
   162  		m.scts = true
   163  	}
   164  	if rand.Intn(10) > 5 {
   165  		m.secureRenegotiationSupported = true
   166  		m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
   167  	}
   168  	for i := 0; i < rand.Intn(5); i++ {
   169  		m.supportedVersions = append(m.supportedVersions, uint16(rand.Intn(0xffff)+1))
   170  	}
   171  	if rand.Intn(10) > 5 {
   172  		m.cookie = randomBytes(rand.Intn(500)+1, rand)
   173  	}
   174  	for i := 0; i < rand.Intn(5); i++ {
   175  		var ks keyShare
   176  		ks.group = CurveID(rand.Intn(30000) + 1)
   177  		ks.data = randomBytes(rand.Intn(200)+1, rand)
   178  		m.keyShares = append(m.keyShares, ks)
   179  	}
   180  	switch rand.Intn(3) {
   181  	case 1:
   182  		m.pskModes = []uint8{pskModeDHE}
   183  	case 2:
   184  		m.pskModes = []uint8{pskModeDHE, pskModePlain}
   185  	}
   186  	for i := 0; i < rand.Intn(5); i++ {
   187  		var psk pskIdentity
   188  		psk.obfuscatedTicketAge = uint32(rand.Intn(500000))
   189  		psk.label = randomBytes(rand.Intn(500)+1, rand)
   190  		m.pskIdentities = append(m.pskIdentities, psk)
   191  		m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand))
   192  	}
   193  	if rand.Intn(10) > 5 {
   194  		m.earlyData = true
   195  	}
   196  
   197  	return reflect.ValueOf(m)
   198  }
   199  
   200  //goland:noinspection GoUnusedParameter
   201  func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   202  	m := &serverHelloMsg{}
   203  	m.vers = uint16(rand.Intn(65536))
   204  	m.random = randomBytes(32, rand)
   205  	m.sessionId = randomBytes(rand.Intn(32), rand)
   206  	m.cipherSuite = uint16(rand.Int31())
   207  	m.compressionMethod = uint8(rand.Intn(256))
   208  	m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
   209  
   210  	if rand.Intn(10) > 5 {
   211  		m.ocspStapling = true
   212  	}
   213  	if rand.Intn(10) > 5 {
   214  		m.ticketSupported = true
   215  	}
   216  	if rand.Intn(10) > 5 {
   217  		m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
   218  	}
   219  
   220  	for i := 0; i < rand.Intn(4); i++ {
   221  		m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand))
   222  	}
   223  
   224  	if rand.Intn(10) > 5 {
   225  		m.secureRenegotiationSupported = true
   226  		m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
   227  	}
   228  	if rand.Intn(10) > 5 {
   229  		m.supportedVersion = uint16(rand.Intn(0xffff) + 1)
   230  	}
   231  	if rand.Intn(10) > 5 {
   232  		m.cookie = randomBytes(rand.Intn(500)+1, rand)
   233  	}
   234  	if rand.Intn(10) > 5 {
   235  		for i := 0; i < rand.Intn(5); i++ {
   236  			m.serverShare.group = CurveID(rand.Intn(30000) + 1)
   237  			m.serverShare.data = randomBytes(rand.Intn(200)+1, rand)
   238  		}
   239  	} else if rand.Intn(10) > 5 {
   240  		m.selectedGroup = CurveID(rand.Intn(30000) + 1)
   241  	}
   242  	if rand.Intn(10) > 5 {
   243  		m.selectedIdentityPresent = true
   244  		m.selectedIdentity = uint16(rand.Intn(0xffff))
   245  	}
   246  
   247  	return reflect.ValueOf(m)
   248  }
   249  
   250  //goland:noinspection GoUnusedParameter
   251  func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   252  	m := &encryptedExtensionsMsg{}
   253  
   254  	if rand.Intn(10) > 5 {
   255  		m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
   256  	}
   257  
   258  	return reflect.ValueOf(m)
   259  }
   260  
   261  //goland:noinspection GoUnusedParameter
   262  func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   263  	m := &certificateMsg{}
   264  	numCerts := rand.Intn(20)
   265  	m.certificates = make([][]byte, numCerts)
   266  	for i := 0; i < numCerts; i++ {
   267  		m.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
   268  	}
   269  	return reflect.ValueOf(m)
   270  }
   271  
   272  //goland:noinspection GoUnusedParameter
   273  func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   274  	m := &certificateRequestMsg{}
   275  	m.certificateTypes = randomBytes(rand.Intn(5)+1, rand)
   276  	for i := 0; i < rand.Intn(100); i++ {
   277  		m.certificateAuthorities = append(m.certificateAuthorities, randomBytes(rand.Intn(15)+1, rand))
   278  	}
   279  	return reflect.ValueOf(m)
   280  }
   281  
   282  //goland:noinspection GoUnusedParameter
   283  func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   284  	m := &certificateVerifyMsg{}
   285  	m.hasSignatureAlgorithm = true
   286  	m.signatureAlgorithm = SignatureScheme(rand.Intn(30000))
   287  	m.signature = randomBytes(rand.Intn(15)+1, rand)
   288  	return reflect.ValueOf(m)
   289  }
   290  
   291  //goland:noinspection GoUnusedParameter
   292  func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   293  	m := &certificateStatusMsg{}
   294  	m.response = randomBytes(rand.Intn(10)+1, rand)
   295  	return reflect.ValueOf(m)
   296  }
   297  
   298  //goland:noinspection GoUnusedParameter
   299  func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   300  	m := &clientKeyExchangeMsg{}
   301  	m.ciphertext = randomBytes(rand.Intn(1000)+1, rand)
   302  	return reflect.ValueOf(m)
   303  }
   304  
   305  //goland:noinspection GoUnusedParameter
   306  func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   307  	m := &finishedMsg{}
   308  	m.verifyData = randomBytes(12, rand)
   309  	return reflect.ValueOf(m)
   310  }
   311  
   312  //goland:noinspection GoUnusedParameter
   313  func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   314  	m := &newSessionTicketMsg{}
   315  	m.ticket = randomBytes(rand.Intn(4), rand)
   316  	return reflect.ValueOf(m)
   317  }
   318  
   319  //goland:noinspection GoUnusedParameter
   320  func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value {
   321  	s := &sessionState{}
   322  	s.vers = uint16(rand.Intn(10000))
   323  	s.cipherSuite = uint16(rand.Intn(10000))
   324  	s.masterSecret = randomBytes(rand.Intn(100)+1, rand)
   325  	s.createdAt = uint64(rand.Int63())
   326  	for i := 0; i < rand.Intn(20); i++ {
   327  		s.certificates = append(s.certificates, randomBytes(rand.Intn(500)+1, rand))
   328  	}
   329  	return reflect.ValueOf(s)
   330  }
   331  
   332  //goland:noinspection GoUnusedParameter
   333  func (*sessionStateTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
   334  	s := &sessionStateTLS13{}
   335  	s.cipherSuite = uint16(rand.Intn(10000))
   336  	s.resumptionSecret = randomBytes(rand.Intn(100)+1, rand)
   337  	s.createdAt = uint64(rand.Int63())
   338  	for i := 0; i < rand.Intn(2)+1; i++ {
   339  		s.certificate.Certificate = append(
   340  			s.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand))
   341  	}
   342  	if rand.Intn(10) > 5 {
   343  		s.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand)
   344  	}
   345  	if rand.Intn(10) > 5 {
   346  		for i := 0; i < rand.Intn(2)+1; i++ {
   347  			s.certificate.SignedCertificateTimestamps = append(
   348  				s.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand))
   349  		}
   350  	}
   351  	return reflect.ValueOf(s)
   352  }
   353  
   354  //goland:noinspection GoUnusedParameter
   355  func (*endOfEarlyDataMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   356  	m := &endOfEarlyDataMsg{}
   357  	return reflect.ValueOf(m)
   358  }
   359  
   360  //goland:noinspection GoUnusedParameter
   361  func (*keyUpdateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   362  	m := &keyUpdateMsg{}
   363  	m.updateRequested = rand.Intn(10) > 5
   364  	return reflect.ValueOf(m)
   365  }
   366  
   367  //goland:noinspection GoUnusedParameter
   368  func (*newSessionTicketMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
   369  	m := &newSessionTicketMsgTLS13{}
   370  	m.lifetime = uint32(rand.Intn(500000))
   371  	m.ageAdd = uint32(rand.Intn(500000))
   372  	m.nonce = randomBytes(rand.Intn(100), rand)
   373  	m.label = randomBytes(rand.Intn(1000), rand)
   374  	if rand.Intn(10) > 5 {
   375  		m.maxEarlyData = uint32(rand.Intn(500000))
   376  	}
   377  	return reflect.ValueOf(m)
   378  }
   379  
   380  //goland:noinspection GoUnusedParameter
   381  func (*certificateRequestMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
   382  	m := &certificateRequestMsgTLS13{}
   383  	if rand.Intn(10) > 5 {
   384  		m.ocspStapling = true
   385  	}
   386  	if rand.Intn(10) > 5 {
   387  		m.scts = true
   388  	}
   389  	if rand.Intn(10) > 5 {
   390  		m.supportedSignatureAlgorithms = supportedSignatureAlgorithms
   391  	}
   392  	if rand.Intn(10) > 5 {
   393  		m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms
   394  	}
   395  	if rand.Intn(10) > 5 {
   396  		m.certificateAuthorities = make([][]byte, 3)
   397  		for i := 0; i < 3; i++ {
   398  			m.certificateAuthorities[i] = randomBytes(rand.Intn(10)+1, rand)
   399  		}
   400  	}
   401  	return reflect.ValueOf(m)
   402  }
   403  
   404  //goland:noinspection GoUnusedParameter
   405  func (*certificateMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
   406  	m := &certificateMsgTLS13{}
   407  	for i := 0; i < rand.Intn(2)+1; i++ {
   408  		m.certificate.Certificate = append(
   409  			m.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand))
   410  	}
   411  	if rand.Intn(10) > 5 {
   412  		m.ocspStapling = true
   413  		m.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand)
   414  	}
   415  	if rand.Intn(10) > 5 {
   416  		m.scts = true
   417  		for i := 0; i < rand.Intn(2)+1; i++ {
   418  			m.certificate.SignedCertificateTimestamps = append(
   419  				m.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand))
   420  		}
   421  	}
   422  	return reflect.ValueOf(m)
   423  }
   424  
   425  func TestRejectEmptySCTList(t *testing.T) {
   426  	// RFC 6962, Section 3.3.1 specifies that empty SCT lists are invalid.
   427  
   428  	var random [32]byte
   429  	sct := []byte{0x42, 0x42, 0x42, 0x42}
   430  	serverHello := serverHelloMsg{
   431  		vers:   VersionTLS12,
   432  		random: random[:],
   433  		scts:   [][]byte{sct},
   434  	}
   435  	serverHelloBytes := serverHello.marshal()
   436  
   437  	var serverHelloCopy serverHelloMsg
   438  	if !serverHelloCopy.unmarshal(serverHelloBytes) {
   439  		t.Fatal("Failed to unmarshal initial message")
   440  	}
   441  
   442  	// Change serverHelloBytes so that the SCT list is empty
   443  	i := bytes.Index(serverHelloBytes, sct)
   444  	if i < 0 {
   445  		t.Fatal("Cannot find SCT in ServerHello")
   446  	}
   447  
   448  	var serverHelloEmptySCT []byte
   449  	serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...)
   450  	// Append the extension length and SCT list length for an empty list.
   451  	serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...)
   452  	serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...)
   453  
   454  	// Update the handshake message length.
   455  	serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16)
   456  	serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8)
   457  	serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4)
   458  
   459  	// Update the extensions length
   460  	serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8)
   461  	serverHelloEmptySCT[43] = byte(len(serverHelloEmptySCT) - 44)
   462  
   463  	if serverHelloCopy.unmarshal(serverHelloEmptySCT) {
   464  		t.Fatal("Unmarshaled ServerHello with empty SCT list")
   465  	}
   466  }
   467  
   468  func TestRejectEmptySCT(t *testing.T) {
   469  	// Not only must the SCT list be non-empty, but the SCT elements must
   470  	// not be zero length.
   471  
   472  	var random [32]byte
   473  	serverHello := serverHelloMsg{
   474  		vers:   VersionTLS12,
   475  		random: random[:],
   476  		scts:   [][]byte{nil},
   477  	}
   478  	serverHelloBytes := serverHello.marshal()
   479  
   480  	var serverHelloCopy serverHelloMsg
   481  	if serverHelloCopy.unmarshal(serverHelloBytes) {
   482  		t.Fatal("Unmarshaled ServerHello with zero-length SCT")
   483  	}
   484  }