github.com/goproxy0/go@v0.0.0-20171111080102-49cc0c489d2c/src/crypto/tls/handshake_messages_test.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"bytes"
     9  	"math/rand"
    10  	"reflect"
    11  	"strings"
    12  	"testing"
    13  	"testing/quick"
    14  )
    15  
// tests lists one zero-valued instance of every message type exercised by
// TestMarshalUnmarshal and TestFuzz.
//
// The order is significant: TestMarshalUnmarshal skips its "no parsable
// prefix" check for the entries at indices 0-2 (clientHelloMsg,
// serverHelloMsg and finishedMsg), so those three must stay first.
var tests = []interface{}{
	&clientHelloMsg{},
	&serverHelloMsg{},
	&finishedMsg{},

	&certificateMsg{},
	&certificateRequestMsg{},
	&certificateVerifyMsg{},
	&certificateStatusMsg{},
	&clientKeyExchangeMsg{},
	&nextProtoMsg{},
	&newSessionTicketMsg{},
	&sessionState{},
	&serverHelloMsg13{},
	&encryptedExtensionsMsg{},
	&certificateMsg13{},
	&newSessionTicketMsg13{},
	&sessionState13{},
}
    35  
// testMessage is the set of methods the tests in this file call on every
// entry of tests.
type testMessage interface {
	// marshal returns the serialized form of the message.
	marshal() []byte
	// unmarshal parses a serialized message into the receiver; the tests
	// treat a return value of alertSuccess as a successful parse.
	unmarshal([]byte) alert
	// equal is used by TestMarshalUnmarshal to compare a generated message
	// with the one obtained by round-tripping it.
	equal(interface{}) bool
}
    41  
    42  func TestMarshalUnmarshal(t *testing.T) {
    43  	rand := rand.New(rand.NewSource(0))
    44  
    45  	for i, iface := range tests {
    46  		ty := reflect.ValueOf(iface).Type()
    47  
    48  		n := 100
    49  		if testing.Short() {
    50  			n = 5
    51  		}
    52  		for j := 0; j < n; j++ {
    53  			v, ok := quick.Value(ty, rand)
    54  			if !ok {
    55  				t.Errorf("#%d: failed to create value", i)
    56  				break
    57  			}
    58  
    59  			m1 := v.Interface().(testMessage)
    60  			marshaled := m1.marshal()
    61  			m2 := iface.(testMessage)
    62  			if m2.unmarshal(marshaled) != alertSuccess {
    63  				t.Errorf("#%d.%d failed to unmarshal %#v %x", i, j, m1, marshaled)
    64  				break
    65  			}
    66  			m2.marshal() // to fill any marshal cache in the message
    67  
    68  			if !m1.equal(m2) {
    69  				t.Errorf("#%d.%d got:%#v want:%#v %x", i, j, m2, m1, marshaled)
    70  				break
    71  			}
    72  
    73  			if i >= 3 {
    74  				// The first three message types (ClientHello,
    75  				// ServerHello and Finished) are allowed to
    76  				// have parsable prefixes because the extension
    77  				// data is optional and the length of the
    78  				// Finished varies across versions.
    79  				for j := 0; j < len(marshaled); j++ {
    80  					if m2.unmarshal(marshaled[0:j]) == alertSuccess {
    81  						t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
    82  						break
    83  					}
    84  				}
    85  			}
    86  		}
    87  	}
    88  }
    89  
    90  func TestFuzz(t *testing.T) {
    91  	rand := rand.New(rand.NewSource(0))
    92  	for _, iface := range tests {
    93  		m := iface.(testMessage)
    94  
    95  		for j := 0; j < 1000; j++ {
    96  			len := rand.Intn(100)
    97  			bytes := randomBytes(len, rand)
    98  			// This just looks for crashes due to bounds errors etc.
    99  			m.unmarshal(bytes)
   100  		}
   101  	}
   102  }
   103  
   104  func randomBytes(n int, rand *rand.Rand) []byte {
   105  	r := make([]byte, n)
   106  	if _, err := rand.Read(r); err != nil {
   107  		panic("rand.Read failed: " + err.Error())
   108  	}
   109  	return r
   110  }
   111  
   112  func randomString(n int, rand *rand.Rand) string {
   113  	b := randomBytes(n, rand)
   114  	return string(b)
   115  }
   116  
   117  func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   118  	m := &clientHelloMsg{}
   119  	m.vers = uint16(rand.Intn(65536))
   120  	m.random = randomBytes(32, rand)
   121  	m.sessionId = randomBytes(rand.Intn(32), rand)
   122  	m.cipherSuites = make([]uint16, rand.Intn(63)+1)
   123  	for i := 0; i < len(m.cipherSuites); i++ {
   124  		cs := uint16(rand.Int31())
   125  		if cs == scsvRenegotiation {
   126  			cs += 1
   127  		}
   128  		m.cipherSuites[i] = cs
   129  	}
   130  	m.compressionMethods = randomBytes(rand.Intn(63)+1, rand)
   131  	if rand.Intn(10) > 5 {
   132  		m.nextProtoNeg = true
   133  	}
   134  	if rand.Intn(10) > 5 {
   135  		m.serverName = randomString(rand.Intn(255), rand)
   136  		for strings.HasSuffix(m.serverName, ".") {
   137  			m.serverName = m.serverName[:len(m.serverName)-1]
   138  		}
   139  	}
   140  	m.ocspStapling = rand.Intn(10) > 5
   141  	m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
   142  	m.supportedCurves = make([]CurveID, rand.Intn(5)+1)
   143  	for i := range m.supportedCurves {
   144  		m.supportedCurves[i] = CurveID(rand.Intn(30000))
   145  	}
   146  	if rand.Intn(10) > 5 {
   147  		m.ticketSupported = true
   148  		if rand.Intn(10) > 5 {
   149  			m.sessionTicket = randomBytes(rand.Intn(300), rand)
   150  		}
   151  	}
   152  	if rand.Intn(10) > 5 {
   153  		m.supportedSignatureAlgorithms = supportedSignatureAlgorithms
   154  	}
   155  	m.alpnProtocols = make([]string, rand.Intn(5))
   156  	for i := range m.alpnProtocols {
   157  		m.alpnProtocols[i] = randomString(rand.Intn(20)+1, rand)
   158  	}
   159  	if rand.Intn(10) > 5 {
   160  		m.scts = true
   161  	}
   162  	m.keyShares = make([]keyShare, rand.Intn(4))
   163  	for i := range m.keyShares {
   164  		m.keyShares[i].group = CurveID(rand.Intn(30000))
   165  		m.keyShares[i].data = randomBytes(rand.Intn(300), rand)
   166  	}
   167  	m.supportedVersions = make([]uint16, rand.Intn(5))
   168  	for i := range m.supportedVersions {
   169  		m.supportedVersions[i] = uint16(rand.Intn(30000))
   170  	}
   171  	if rand.Intn(10) > 5 {
   172  		m.earlyData = true
   173  	}
   174  
   175  	return reflect.ValueOf(m)
   176  }
   177  
   178  func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   179  	m := &serverHelloMsg{}
   180  	m.vers = uint16(rand.Intn(65536))
   181  	m.random = randomBytes(32, rand)
   182  	m.sessionId = randomBytes(rand.Intn(32), rand)
   183  	m.cipherSuite = uint16(rand.Int31())
   184  	m.compressionMethod = uint8(rand.Intn(256))
   185  
   186  	if rand.Intn(10) > 5 {
   187  		m.nextProtoNeg = true
   188  
   189  		n := rand.Intn(10)
   190  		m.nextProtos = make([]string, n)
   191  		for i := 0; i < n; i++ {
   192  			m.nextProtos[i] = randomString(20, rand)
   193  		}
   194  	}
   195  
   196  	if rand.Intn(10) > 5 {
   197  		m.ocspStapling = true
   198  	}
   199  	if rand.Intn(10) > 5 {
   200  		m.ticketSupported = true
   201  	}
   202  	m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
   203  
   204  	if rand.Intn(10) > 5 {
   205  		numSCTs := rand.Intn(4)
   206  		m.scts = make([][]byte, numSCTs)
   207  		for i := range m.scts {
   208  			m.scts[i] = randomBytes(rand.Intn(500), rand)
   209  		}
   210  	}
   211  
   212  	return reflect.ValueOf(m)
   213  }
   214  
   215  func (*serverHelloMsg13) Generate(rand *rand.Rand, size int) reflect.Value {
   216  	m := &serverHelloMsg13{}
   217  	m.vers = uint16(rand.Intn(65536))
   218  	m.random = randomBytes(32, rand)
   219  	m.cipherSuite = uint16(rand.Int31())
   220  	m.keyShare.group = CurveID(rand.Intn(30000))
   221  	m.keyShare.data = randomBytes(rand.Intn(300), rand)
   222  	if rand.Intn(10) > 5 {
   223  		m.psk = true
   224  		m.pskIdentity = uint16(rand.Int31())
   225  	}
   226  
   227  	return reflect.ValueOf(m)
   228  }
   229  
   230  func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   231  	m := &encryptedExtensionsMsg{}
   232  	if rand.Intn(10) > 5 {
   233  		m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
   234  	}
   235  	if rand.Intn(10) > 5 {
   236  		m.earlyData = true
   237  	}
   238  
   239  	return reflect.ValueOf(m)
   240  }
   241  
   242  func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   243  	m := &certificateMsg{}
   244  	numCerts := rand.Intn(20)
   245  	m.certificates = make([][]byte, numCerts)
   246  	for i := 0; i < numCerts; i++ {
   247  		m.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
   248  	}
   249  	return reflect.ValueOf(m)
   250  }
   251  
   252  func (*certificateMsg13) Generate(rand *rand.Rand, size int) reflect.Value {
   253  	m := &certificateMsg13{}
   254  	numCerts := rand.Intn(20)
   255  	m.certificates = make([]certificateEntry, numCerts)
   256  	for i := 0; i < numCerts; i++ {
   257  		m.certificates[i].data = randomBytes(rand.Intn(10)+1, rand)
   258  		if rand.Intn(2) == 1 {
   259  			m.certificates[i].ocspStaple = randomBytes(rand.Intn(10)+1, rand)
   260  		}
   261  
   262  		numScts := rand.Intn(3)
   263  		for j := 0; j < numScts; j++ {
   264  			m.certificates[i].sctList = append(m.certificates[i].sctList, randomBytes(rand.Intn(10)+1, rand))
   265  		}
   266  	}
   267  	m.requestContext = randomBytes(rand.Intn(5), rand)
   268  	return reflect.ValueOf(m)
   269  }
   270  
   271  func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   272  	m := &certificateRequestMsg{}
   273  	m.certificateTypes = randomBytes(rand.Intn(5)+1, rand)
   274  	numCAs := rand.Intn(100)
   275  	m.certificateAuthorities = make([][]byte, numCAs)
   276  	for i := 0; i < numCAs; i++ {
   277  		m.certificateAuthorities[i] = randomBytes(rand.Intn(15)+1, rand)
   278  	}
   279  	return reflect.ValueOf(m)
   280  }
   281  
   282  func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   283  	m := &certificateVerifyMsg{}
   284  	m.signature = randomBytes(rand.Intn(15)+1, rand)
   285  	return reflect.ValueOf(m)
   286  }
   287  
   288  func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   289  	m := &certificateStatusMsg{}
   290  	if rand.Intn(10) > 5 {
   291  		m.statusType = statusTypeOCSP
   292  		m.response = randomBytes(rand.Intn(10)+1, rand)
   293  	} else {
   294  		m.statusType = 42
   295  	}
   296  	return reflect.ValueOf(m)
   297  }
   298  
   299  func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   300  	m := &clientKeyExchangeMsg{}
   301  	m.ciphertext = randomBytes(rand.Intn(1000)+1, rand)
   302  	return reflect.ValueOf(m)
   303  }
   304  
   305  func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   306  	m := &finishedMsg{}
   307  	m.verifyData = randomBytes(12, rand)
   308  	return reflect.ValueOf(m)
   309  }
   310  
   311  func (*nextProtoMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   312  	m := &nextProtoMsg{}
   313  	m.proto = randomString(rand.Intn(255), rand)
   314  	return reflect.ValueOf(m)
   315  }
   316  
   317  func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   318  	m := &newSessionTicketMsg{}
   319  	m.ticket = randomBytes(rand.Intn(4), rand)
   320  	return reflect.ValueOf(m)
   321  }
   322  
   323  func (*newSessionTicketMsg13) Generate(rand *rand.Rand, size int) reflect.Value {
   324  	m := &newSessionTicketMsg13{}
   325  	m.ageAdd = uint32(rand.Intn(0xffffffff))
   326  	m.lifetime = uint32(rand.Intn(0xffffffff))
   327  	m.ticket = randomBytes(rand.Intn(40), rand)
   328  	if rand.Intn(10) > 5 {
   329  		m.withEarlyDataInfo = true
   330  		m.maxEarlyDataLength = uint32(rand.Intn(0xffffffff))
   331  	}
   332  	return reflect.ValueOf(m)
   333  }
   334  
   335  func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value {
   336  	s := &sessionState{}
   337  	s.vers = uint16(rand.Intn(10000))
   338  	s.cipherSuite = uint16(rand.Intn(10000))
   339  	s.masterSecret = randomBytes(rand.Intn(100), rand)
   340  	numCerts := rand.Intn(20)
   341  	s.certificates = make([][]byte, numCerts)
   342  	for i := 0; i < numCerts; i++ {
   343  		s.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
   344  	}
   345  	return reflect.ValueOf(s)
   346  }
   347  
   348  func (*sessionState13) Generate(rand *rand.Rand, size int) reflect.Value {
   349  	s := &sessionState13{}
   350  	s.vers = uint16(rand.Intn(10000))
   351  	s.suite = uint16(rand.Intn(10000))
   352  	s.ageAdd = uint32(rand.Intn(0xffffffff))
   353  	s.maxEarlyDataLen = uint32(rand.Intn(0xffffffff))
   354  	s.createdAt = uint64(rand.Int63n(0xfffffffffffffff))
   355  	s.resumptionSecret = randomBytes(rand.Intn(100), rand)
   356  	s.alpnProtocol = randomString(rand.Intn(100), rand)
   357  	s.SNI = randomString(rand.Intn(100), rand)
   358  	return reflect.ValueOf(s)
   359  }
   360  
   361  func TestRejectEmptySCTList(t *testing.T) {
   362  	// https://tools.ietf.org/html/rfc6962#section-3.3.1 specifies that
   363  	// empty SCT lists are invalid.
   364  
   365  	var random [32]byte
   366  	sct := []byte{0x42, 0x42, 0x42, 0x42}
   367  	serverHello := serverHelloMsg{
   368  		vers:   VersionTLS12,
   369  		random: random[:],
   370  		scts:   [][]byte{sct},
   371  	}
   372  	serverHelloBytes := serverHello.marshal()
   373  
   374  	var serverHelloCopy serverHelloMsg
   375  	if serverHelloCopy.unmarshal(serverHelloBytes) != alertSuccess {
   376  		t.Fatal("Failed to unmarshal initial message")
   377  	}
   378  
   379  	// Change serverHelloBytes so that the SCT list is empty
   380  	i := bytes.Index(serverHelloBytes, sct)
   381  	if i < 0 {
   382  		t.Fatal("Cannot find SCT in ServerHello")
   383  	}
   384  
   385  	var serverHelloEmptySCT []byte
   386  	serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...)
   387  	// Append the extension length and SCT list length for an empty list.
   388  	serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...)
   389  	serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...)
   390  
   391  	// Update the handshake message length.
   392  	serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16)
   393  	serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8)
   394  	serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4)
   395  
   396  	// Update the extensions length
   397  	serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8)
   398  	serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44))
   399  
   400  	if serverHelloCopy.unmarshal(serverHelloEmptySCT) == alertSuccess {
   401  		t.Fatal("Unmarshaled ServerHello with empty SCT list")
   402  	}
   403  }
   404  
   405  func TestRejectEmptySCT(t *testing.T) {
   406  	// Not only must the SCT list be non-empty, but the SCT elements must
   407  	// not be zero length.
   408  
   409  	var random [32]byte
   410  	serverHello := serverHelloMsg{
   411  		vers:   VersionTLS12,
   412  		random: random[:],
   413  		scts:   [][]byte{nil},
   414  	}
   415  	serverHelloBytes := serverHello.marshal()
   416  
   417  	var serverHelloCopy serverHelloMsg
   418  	if serverHelloCopy.unmarshal(serverHelloBytes) == alertSuccess {
   419  		t.Fatal("Unmarshaled ServerHello with zero-length SCT")
   420  	}
   421  }