gitee.com/ks-custle/core-gm@v0.0.0-20230922171213-b83bdd97b62c/gmtls/handshake_messages_test.go (about)

     1  // Copyright (c) 2022 zhaochun
     2  // core-gm is licensed under Mulan PSL v2.
     3  // You can use this software according to the terms and conditions of the Mulan PSL v2.
     4  // You may obtain a copy of Mulan PSL v2 at:
     5  //          http://license.coscl.org.cn/MulanPSL2
     6  // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
     7  // See the Mulan PSL v2 for more details.
     8  
     9  /*
    10  gmtls是基于`golang/go`的`tls`包实现的国密改造版本。
    11  对应版权声明: thrid_licenses/github.com/golang/go/LICENSE
    12  */
    13  
    14  package gmtls
    15  
    16  import (
    17  	"bytes"
    18  	"math/rand"
    19  	"reflect"
    20  	"strings"
    21  	"testing"
    22  	"testing/quick"
    23  	"time"
    24  )
    25  
// tests holds one instance of every message type exercised by
// TestMarshalUnmarshal and TestFuzz. Both tests type-assert each entry to
// handshakeMessage.
//
// The order is significant: TestMarshalUnmarshal skips its "no prefix may
// parse" check for the entries at indices 0-2, so clientHelloMsg,
// serverHelloMsg and finishedMsg must remain the first three elements.
var tests = []interface{}{
	&clientHelloMsg{},
	&serverHelloMsg{},
	&finishedMsg{},

	&certificateMsg{},
	&certificateRequestMsg{},
	&certificateVerifyMsg{
		hasSignatureAlgorithm: true,
	},
	&certificateStatusMsg{},
	&clientKeyExchangeMsg{},
	&newSessionTicketMsg{},
	&sessionState{},
	&sessionStateTLS13{},
	&encryptedExtensionsMsg{},
	&endOfEarlyDataMsg{},
	&keyUpdateMsg{},
	&newSessionTicketMsgTLS13{},
	&certificateRequestMsgTLS13{},
	&certificateMsgTLS13{},
}
    48  
    49  func TestMarshalUnmarshal(t *testing.T) {
    50  	randNew := rand.New(rand.NewSource(time.Now().UnixNano()))
    51  
    52  	for i, iface := range tests {
    53  		ty := reflect.ValueOf(iface).Type()
    54  
    55  		n := 100
    56  		if testing.Short() {
    57  			n = 5
    58  		}
    59  		for j := 0; j < n; j++ {
    60  			v, ok := quick.Value(ty, randNew)
    61  			if !ok {
    62  				t.Errorf("#%d: failed to create value", i)
    63  				break
    64  			}
    65  
    66  			m1 := v.Interface().(handshakeMessage)
    67  			marshaled := m1.marshal()
    68  			m2 := iface.(handshakeMessage)
    69  			if !m2.unmarshal(marshaled) {
    70  				t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
    71  				break
    72  			}
    73  			m2.marshal() // to fill any marshal cache in the message
    74  
    75  			if !reflect.DeepEqual(m1, m2) {
    76  				t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled)
    77  				break
    78  			}
    79  
    80  			if i >= 3 {
    81  				// The first three message types (ClientHello,
    82  				// ServerHello and Finished) are allowed to
    83  				// have parsable prefixes because the extension
    84  				// data is optional and the length of the
    85  				// Finished varies across versions.
    86  				for j := 0; j < len(marshaled); j++ {
    87  					if m2.unmarshal(marshaled[0:j]) {
    88  						t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
    89  						break
    90  					}
    91  				}
    92  			}
    93  		}
    94  	}
    95  }
    96  
    97  func TestFuzz(t *testing.T) {
    98  	randNew := rand.New(rand.NewSource(0))
    99  	for _, iface := range tests {
   100  		m := iface.(handshakeMessage)
   101  
   102  		for j := 0; j < 1000; j++ {
   103  			length := randNew.Intn(100)
   104  			ranBytes := randomBytes(length, randNew)
   105  			// This just looks for crashes due to bounds errors etc.
   106  			m.unmarshal(ranBytes)
   107  		}
   108  	}
   109  }
   110  
   111  func randomBytes(n int, rand *rand.Rand) []byte {
   112  	r := make([]byte, n)
   113  	if _, err := rand.Read(r); err != nil {
   114  		panic("rand.Read failed: " + err.Error())
   115  	}
   116  	return r
   117  }
   118  
   119  func randomString(n int, rand *rand.Rand) string {
   120  	b := randomBytes(n, rand)
   121  	return string(b)
   122  }
   123  
   124  //goland:noinspection GoUnusedParameter
   125  func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   126  	m := &clientHelloMsg{}
   127  	m.vers = uint16(rand.Intn(65536))
   128  	m.random = randomBytes(32, rand)
   129  	m.sessionId = randomBytes(rand.Intn(32), rand)
   130  	m.cipherSuites = make([]uint16, rand.Intn(63)+1)
   131  	for i := 0; i < len(m.cipherSuites); i++ {
   132  		cs := uint16(rand.Int31())
   133  		if cs == scsvRenegotiation {
   134  			cs += 1
   135  		}
   136  		m.cipherSuites[i] = cs
   137  	}
   138  	m.compressionMethods = randomBytes(rand.Intn(63)+1, rand)
   139  	if rand.Intn(10) > 5 {
   140  		m.serverName = randomString(rand.Intn(255), rand)
   141  		for strings.HasSuffix(m.serverName, ".") {
   142  			m.serverName = m.serverName[:len(m.serverName)-1]
   143  		}
   144  	}
   145  	m.ocspStapling = rand.Intn(10) > 5
   146  	m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
   147  	m.supportedCurves = make([]CurveID, rand.Intn(5)+1)
   148  	for i := range m.supportedCurves {
   149  		m.supportedCurves[i] = CurveID(rand.Intn(30000) + 1)
   150  	}
   151  	if rand.Intn(10) > 5 {
   152  		m.ticketSupported = true
   153  		if rand.Intn(10) > 5 {
   154  			m.sessionTicket = randomBytes(rand.Intn(300), rand)
   155  		} else {
   156  			m.sessionTicket = make([]byte, 0)
   157  		}
   158  	}
   159  	if rand.Intn(10) > 5 {
   160  		m.supportedSignatureAlgorithms = supportedSignatureAlgorithms
   161  	}
   162  	if rand.Intn(10) > 5 {
   163  		m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms
   164  	}
   165  	for i := 0; i < rand.Intn(5); i++ {
   166  		m.alpnProtocols = append(m.alpnProtocols, randomString(rand.Intn(20)+1, rand))
   167  	}
   168  	if rand.Intn(10) > 5 {
   169  		m.scts = true
   170  	}
   171  	if rand.Intn(10) > 5 {
   172  		m.secureRenegotiationSupported = true
   173  		m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
   174  	}
   175  	for i := 0; i < rand.Intn(5); i++ {
   176  		m.supportedVersions = append(m.supportedVersions, uint16(rand.Intn(0xffff)+1))
   177  	}
   178  	if rand.Intn(10) > 5 {
   179  		m.cookie = randomBytes(rand.Intn(500)+1, rand)
   180  	}
   181  	for i := 0; i < rand.Intn(5); i++ {
   182  		var ks keyShare
   183  		ks.group = CurveID(rand.Intn(30000) + 1)
   184  		ks.data = randomBytes(rand.Intn(200)+1, rand)
   185  		m.keyShares = append(m.keyShares, ks)
   186  	}
   187  	switch rand.Intn(3) {
   188  	case 1:
   189  		m.pskModes = []uint8{pskModeDHE}
   190  	case 2:
   191  		m.pskModes = []uint8{pskModeDHE, pskModePlain}
   192  	}
   193  	for i := 0; i < rand.Intn(5); i++ {
   194  		var psk pskIdentity
   195  		psk.obfuscatedTicketAge = uint32(rand.Intn(500000))
   196  		psk.label = randomBytes(rand.Intn(500)+1, rand)
   197  		m.pskIdentities = append(m.pskIdentities, psk)
   198  		m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand))
   199  	}
   200  	if rand.Intn(10) > 5 {
   201  		m.earlyData = true
   202  	}
   203  
   204  	return reflect.ValueOf(m)
   205  }
   206  
   207  //goland:noinspection GoUnusedParameter
   208  func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   209  	m := &serverHelloMsg{}
   210  	m.vers = uint16(rand.Intn(65536))
   211  	m.random = randomBytes(32, rand)
   212  	m.sessionId = randomBytes(rand.Intn(32), rand)
   213  	m.cipherSuite = uint16(rand.Int31())
   214  	m.compressionMethod = uint8(rand.Intn(256))
   215  	m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
   216  
   217  	if rand.Intn(10) > 5 {
   218  		m.ocspStapling = true
   219  	}
   220  	if rand.Intn(10) > 5 {
   221  		m.ticketSupported = true
   222  	}
   223  	if rand.Intn(10) > 5 {
   224  		m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
   225  	}
   226  
   227  	for i := 0; i < rand.Intn(4); i++ {
   228  		m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand))
   229  	}
   230  
   231  	if rand.Intn(10) > 5 {
   232  		m.secureRenegotiationSupported = true
   233  		m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
   234  	}
   235  	if rand.Intn(10) > 5 {
   236  		m.supportedVersion = uint16(rand.Intn(0xffff) + 1)
   237  	}
   238  	if rand.Intn(10) > 5 {
   239  		m.cookie = randomBytes(rand.Intn(500)+1, rand)
   240  	}
   241  	if rand.Intn(10) > 5 {
   242  		for i := 0; i < rand.Intn(5); i++ {
   243  			m.serverShare.group = CurveID(rand.Intn(30000) + 1)
   244  			m.serverShare.data = randomBytes(rand.Intn(200)+1, rand)
   245  		}
   246  	} else if rand.Intn(10) > 5 {
   247  		m.selectedGroup = CurveID(rand.Intn(30000) + 1)
   248  	}
   249  	if rand.Intn(10) > 5 {
   250  		m.selectedIdentityPresent = true
   251  		m.selectedIdentity = uint16(rand.Intn(0xffff))
   252  	}
   253  
   254  	return reflect.ValueOf(m)
   255  }
   256  
   257  //goland:noinspection GoUnusedParameter
   258  func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   259  	m := &encryptedExtensionsMsg{}
   260  
   261  	if rand.Intn(10) > 5 {
   262  		m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
   263  	}
   264  
   265  	return reflect.ValueOf(m)
   266  }
   267  
   268  //goland:noinspection GoUnusedParameter
   269  func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   270  	m := &certificateMsg{}
   271  	numCerts := rand.Intn(20)
   272  	m.certificates = make([][]byte, numCerts)
   273  	for i := 0; i < numCerts; i++ {
   274  		m.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
   275  	}
   276  	return reflect.ValueOf(m)
   277  }
   278  
   279  //goland:noinspection GoUnusedParameter
   280  func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   281  	m := &certificateRequestMsg{}
   282  	m.certificateTypes = randomBytes(rand.Intn(5)+1, rand)
   283  	for i := 0; i < rand.Intn(100); i++ {
   284  		m.certificateAuthorities = append(m.certificateAuthorities, randomBytes(rand.Intn(15)+1, rand))
   285  	}
   286  	return reflect.ValueOf(m)
   287  }
   288  
   289  //goland:noinspection GoUnusedParameter
   290  func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   291  	m := &certificateVerifyMsg{}
   292  	m.hasSignatureAlgorithm = true
   293  	m.signatureAlgorithm = SignatureScheme(rand.Intn(30000))
   294  	m.signature = randomBytes(rand.Intn(15)+1, rand)
   295  	return reflect.ValueOf(m)
   296  }
   297  
   298  //goland:noinspection GoUnusedParameter
   299  func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   300  	m := &certificateStatusMsg{}
   301  	m.response = randomBytes(rand.Intn(10)+1, rand)
   302  	return reflect.ValueOf(m)
   303  }
   304  
   305  //goland:noinspection GoUnusedParameter
   306  func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   307  	m := &clientKeyExchangeMsg{}
   308  	m.ciphertext = randomBytes(rand.Intn(1000)+1, rand)
   309  	return reflect.ValueOf(m)
   310  }
   311  
   312  //goland:noinspection GoUnusedParameter
   313  func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   314  	m := &finishedMsg{}
   315  	m.verifyData = randomBytes(12, rand)
   316  	return reflect.ValueOf(m)
   317  }
   318  
   319  //goland:noinspection GoUnusedParameter
   320  func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   321  	m := &newSessionTicketMsg{}
   322  	m.ticket = randomBytes(rand.Intn(4), rand)
   323  	return reflect.ValueOf(m)
   324  }
   325  
   326  //goland:noinspection GoUnusedParameter
   327  func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value {
   328  	s := &sessionState{}
   329  	s.vers = uint16(rand.Intn(10000))
   330  	s.cipherSuite = uint16(rand.Intn(10000))
   331  	s.masterSecret = randomBytes(rand.Intn(100)+1, rand)
   332  	s.createdAt = uint64(rand.Int63())
   333  	for i := 0; i < rand.Intn(20); i++ {
   334  		s.certificates = append(s.certificates, randomBytes(rand.Intn(500)+1, rand))
   335  	}
   336  	return reflect.ValueOf(s)
   337  }
   338  
   339  //goland:noinspection GoUnusedParameter
   340  func (*sessionStateTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
   341  	s := &sessionStateTLS13{}
   342  	s.cipherSuite = uint16(rand.Intn(10000))
   343  	s.resumptionSecret = randomBytes(rand.Intn(100)+1, rand)
   344  	s.createdAt = uint64(rand.Int63())
   345  	for i := 0; i < rand.Intn(2)+1; i++ {
   346  		s.certificate.Certificate = append(
   347  			s.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand))
   348  	}
   349  	if rand.Intn(10) > 5 {
   350  		s.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand)
   351  	}
   352  	if rand.Intn(10) > 5 {
   353  		for i := 0; i < rand.Intn(2)+1; i++ {
   354  			s.certificate.SignedCertificateTimestamps = append(
   355  				s.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand))
   356  		}
   357  	}
   358  	return reflect.ValueOf(s)
   359  }
   360  
   361  //goland:noinspection GoUnusedParameter
   362  func (*endOfEarlyDataMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   363  	m := &endOfEarlyDataMsg{}
   364  	return reflect.ValueOf(m)
   365  }
   366  
   367  //goland:noinspection GoUnusedParameter
   368  func (*keyUpdateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   369  	m := &keyUpdateMsg{}
   370  	m.updateRequested = rand.Intn(10) > 5
   371  	return reflect.ValueOf(m)
   372  }
   373  
   374  //goland:noinspection GoUnusedParameter
   375  func (*newSessionTicketMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
   376  	m := &newSessionTicketMsgTLS13{}
   377  	m.lifetime = uint32(rand.Intn(500000))
   378  	m.ageAdd = uint32(rand.Intn(500000))
   379  	m.nonce = randomBytes(rand.Intn(100), rand)
   380  	m.label = randomBytes(rand.Intn(1000), rand)
   381  	if rand.Intn(10) > 5 {
   382  		m.maxEarlyData = uint32(rand.Intn(500000))
   383  	}
   384  	return reflect.ValueOf(m)
   385  }
   386  
   387  //goland:noinspection GoUnusedParameter
   388  func (*certificateRequestMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
   389  	m := &certificateRequestMsgTLS13{}
   390  	if rand.Intn(10) > 5 {
   391  		m.ocspStapling = true
   392  	}
   393  	if rand.Intn(10) > 5 {
   394  		m.scts = true
   395  	}
   396  	if rand.Intn(10) > 5 {
   397  		m.supportedSignatureAlgorithms = supportedSignatureAlgorithms
   398  	}
   399  	if rand.Intn(10) > 5 {
   400  		m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms
   401  	}
   402  	if rand.Intn(10) > 5 {
   403  		m.certificateAuthorities = make([][]byte, 3)
   404  		for i := 0; i < 3; i++ {
   405  			m.certificateAuthorities[i] = randomBytes(rand.Intn(10)+1, rand)
   406  		}
   407  	}
   408  	return reflect.ValueOf(m)
   409  }
   410  
   411  //goland:noinspection GoUnusedParameter
   412  func (*certificateMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
   413  	m := &certificateMsgTLS13{}
   414  	for i := 0; i < rand.Intn(2)+1; i++ {
   415  		m.certificate.Certificate = append(
   416  			m.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand))
   417  	}
   418  	if rand.Intn(10) > 5 {
   419  		m.ocspStapling = true
   420  		m.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand)
   421  	}
   422  	if rand.Intn(10) > 5 {
   423  		m.scts = true
   424  		for i := 0; i < rand.Intn(2)+1; i++ {
   425  			m.certificate.SignedCertificateTimestamps = append(
   426  				m.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand))
   427  		}
   428  	}
   429  	return reflect.ValueOf(m)
   430  }
   431  
   432  func TestRejectEmptySCTList(t *testing.T) {
   433  	// RFC 6962, Section 3.3.1 specifies that empty SCT lists are invalid.
   434  
   435  	var random [32]byte
   436  	sct := []byte{0x42, 0x42, 0x42, 0x42}
   437  	serverHello := serverHelloMsg{
   438  		vers:   VersionTLS12,
   439  		random: random[:],
   440  		scts:   [][]byte{sct},
   441  	}
   442  	serverHelloBytes := serverHello.marshal()
   443  
   444  	var serverHelloCopy serverHelloMsg
   445  	if !serverHelloCopy.unmarshal(serverHelloBytes) {
   446  		t.Fatal("Failed to unmarshal initial message")
   447  	}
   448  
   449  	// Change serverHelloBytes so that the SCT list is empty
   450  	i := bytes.Index(serverHelloBytes, sct)
   451  	if i < 0 {
   452  		t.Fatal("Cannot find SCT in ServerHello")
   453  	}
   454  
   455  	var serverHelloEmptySCT []byte
   456  	serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...)
   457  	// Append the extension length and SCT list length for an empty list.
   458  	serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...)
   459  	serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...)
   460  
   461  	// Update the handshake message length.
   462  	serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16)
   463  	serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8)
   464  	serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4)
   465  
   466  	// Update the extensions length
   467  	serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8)
   468  	serverHelloEmptySCT[43] = byte(len(serverHelloEmptySCT) - 44)
   469  
   470  	if serverHelloCopy.unmarshal(serverHelloEmptySCT) {
   471  		t.Fatal("Unmarshaled ServerHello with empty SCT list")
   472  	}
   473  }
   474  
   475  func TestRejectEmptySCT(t *testing.T) {
   476  	// Not only must the SCT list be non-empty, but the SCT elements must
   477  	// not be zero length.
   478  
   479  	var random [32]byte
   480  	serverHello := serverHelloMsg{
   481  		vers:   VersionTLS12,
   482  		random: random[:],
   483  		scts:   [][]byte{nil},
   484  	}
   485  	serverHelloBytes := serverHello.marshal()
   486  
   487  	var serverHelloCopy serverHelloMsg
   488  	if serverHelloCopy.unmarshal(serverHelloBytes) {
   489  		t.Fatal("Unmarshaled ServerHello with zero-length SCT")
   490  	}
   491  }