github.com/zebozhuang/go@v0.0.0-20200207033046-f8a98f6f5c5d/src/crypto/tls/handshake_messages_test.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"bytes"
     9  	"math/rand"
    10  	"reflect"
    11  	"strings"
    12  	"testing"
    13  	"testing/quick"
    14  )
    15  
// tests holds one zero-valued instance of every handshake message type
// exercised by TestMarshalUnmarshal and TestFuzz.
//
// Order matters: TestMarshalUnmarshal only runs its strict "no parsable
// prefix" check for indices >= 3, so clientHelloMsg, serverHelloMsg and
// finishedMsg must remain the first three entries.
var tests = []interface{}{
	&clientHelloMsg{},
	&serverHelloMsg{},
	&finishedMsg{},

	// Every type from here on (index >= 3) must fail to unmarshal any
	// strict prefix of its marshaled form.
	&certificateMsg{},
	&certificateRequestMsg{},
	&certificateVerifyMsg{},
	&certificateStatusMsg{},
	&clientKeyExchangeMsg{},
	&nextProtoMsg{},
	&newSessionTicketMsg{},
	&sessionState{},
}
    30  
    31  type testMessage interface {
    32  	marshal() []byte
    33  	unmarshal([]byte) bool
    34  	equal(interface{}) bool
    35  }
    36  
    37  func TestMarshalUnmarshal(t *testing.T) {
    38  	rand := rand.New(rand.NewSource(0))
    39  
    40  	for i, iface := range tests {
    41  		ty := reflect.ValueOf(iface).Type()
    42  
    43  		n := 100
    44  		if testing.Short() {
    45  			n = 5
    46  		}
    47  		for j := 0; j < n; j++ {
    48  			v, ok := quick.Value(ty, rand)
    49  			if !ok {
    50  				t.Errorf("#%d: failed to create value", i)
    51  				break
    52  			}
    53  
    54  			m1 := v.Interface().(testMessage)
    55  			marshaled := m1.marshal()
    56  			m2 := iface.(testMessage)
    57  			if !m2.unmarshal(marshaled) {
    58  				t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
    59  				break
    60  			}
    61  			m2.marshal() // to fill any marshal cache in the message
    62  
    63  			if !m1.equal(m2) {
    64  				t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled)
    65  				break
    66  			}
    67  
    68  			if i >= 3 {
    69  				// The first three message types (ClientHello,
    70  				// ServerHello and Finished) are allowed to
    71  				// have parsable prefixes because the extension
    72  				// data is optional and the length of the
    73  				// Finished varies across versions.
    74  				for j := 0; j < len(marshaled); j++ {
    75  					if m2.unmarshal(marshaled[0:j]) {
    76  						t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
    77  						break
    78  					}
    79  				}
    80  			}
    81  		}
    82  	}
    83  }
    84  
    85  func TestFuzz(t *testing.T) {
    86  	rand := rand.New(rand.NewSource(0))
    87  	for _, iface := range tests {
    88  		m := iface.(testMessage)
    89  
    90  		for j := 0; j < 1000; j++ {
    91  			len := rand.Intn(100)
    92  			bytes := randomBytes(len, rand)
    93  			// This just looks for crashes due to bounds errors etc.
    94  			m.unmarshal(bytes)
    95  		}
    96  	}
    97  }
    98  
    99  func randomBytes(n int, rand *rand.Rand) []byte {
   100  	r := make([]byte, n)
   101  	for i := 0; i < n; i++ {
   102  		r[i] = byte(rand.Int31())
   103  	}
   104  	return r
   105  }
   106  
   107  func randomString(n int, rand *rand.Rand) string {
   108  	b := randomBytes(n, rand)
   109  	return string(b)
   110  }
   111  
   112  func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   113  	m := &clientHelloMsg{}
   114  	m.vers = uint16(rand.Intn(65536))
   115  	m.random = randomBytes(32, rand)
   116  	m.sessionId = randomBytes(rand.Intn(32), rand)
   117  	m.cipherSuites = make([]uint16, rand.Intn(63)+1)
   118  	for i := 0; i < len(m.cipherSuites); i++ {
   119  		m.cipherSuites[i] = uint16(rand.Int31())
   120  	}
   121  	m.compressionMethods = randomBytes(rand.Intn(63)+1, rand)
   122  	if rand.Intn(10) > 5 {
   123  		m.nextProtoNeg = true
   124  	}
   125  	if rand.Intn(10) > 5 {
   126  		m.serverName = randomString(rand.Intn(255), rand)
   127  		for strings.HasSuffix(m.serverName, ".") {
   128  			m.serverName = m.serverName[:len(m.serverName)-1]
   129  		}
   130  	}
   131  	m.ocspStapling = rand.Intn(10) > 5
   132  	m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
   133  	m.supportedCurves = make([]CurveID, rand.Intn(5)+1)
   134  	for i := range m.supportedCurves {
   135  		m.supportedCurves[i] = CurveID(rand.Intn(30000))
   136  	}
   137  	if rand.Intn(10) > 5 {
   138  		m.ticketSupported = true
   139  		if rand.Intn(10) > 5 {
   140  			m.sessionTicket = randomBytes(rand.Intn(300), rand)
   141  		}
   142  	}
   143  	if rand.Intn(10) > 5 {
   144  		m.signatureAndHashes = supportedSignatureAlgorithms
   145  	}
   146  	m.alpnProtocols = make([]string, rand.Intn(5))
   147  	for i := range m.alpnProtocols {
   148  		m.alpnProtocols[i] = randomString(rand.Intn(20)+1, rand)
   149  	}
   150  	if rand.Intn(10) > 5 {
   151  		m.scts = true
   152  	}
   153  
   154  	return reflect.ValueOf(m)
   155  }
   156  
   157  func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   158  	m := &serverHelloMsg{}
   159  	m.vers = uint16(rand.Intn(65536))
   160  	m.random = randomBytes(32, rand)
   161  	m.sessionId = randomBytes(rand.Intn(32), rand)
   162  	m.cipherSuite = uint16(rand.Int31())
   163  	m.compressionMethod = uint8(rand.Intn(256))
   164  
   165  	if rand.Intn(10) > 5 {
   166  		m.nextProtoNeg = true
   167  
   168  		n := rand.Intn(10)
   169  		m.nextProtos = make([]string, n)
   170  		for i := 0; i < n; i++ {
   171  			m.nextProtos[i] = randomString(20, rand)
   172  		}
   173  	}
   174  
   175  	if rand.Intn(10) > 5 {
   176  		m.ocspStapling = true
   177  	}
   178  	if rand.Intn(10) > 5 {
   179  		m.ticketSupported = true
   180  	}
   181  	m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
   182  
   183  	if rand.Intn(10) > 5 {
   184  		numSCTs := rand.Intn(4)
   185  		m.scts = make([][]byte, numSCTs)
   186  		for i := range m.scts {
   187  			m.scts[i] = randomBytes(rand.Intn(500), rand)
   188  		}
   189  	}
   190  
   191  	return reflect.ValueOf(m)
   192  }
   193  
   194  func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   195  	m := &certificateMsg{}
   196  	numCerts := rand.Intn(20)
   197  	m.certificates = make([][]byte, numCerts)
   198  	for i := 0; i < numCerts; i++ {
   199  		m.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
   200  	}
   201  	return reflect.ValueOf(m)
   202  }
   203  
   204  func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   205  	m := &certificateRequestMsg{}
   206  	m.certificateTypes = randomBytes(rand.Intn(5)+1, rand)
   207  	numCAs := rand.Intn(100)
   208  	m.certificateAuthorities = make([][]byte, numCAs)
   209  	for i := 0; i < numCAs; i++ {
   210  		m.certificateAuthorities[i] = randomBytes(rand.Intn(15)+1, rand)
   211  	}
   212  	return reflect.ValueOf(m)
   213  }
   214  
   215  func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   216  	m := &certificateVerifyMsg{}
   217  	m.signature = randomBytes(rand.Intn(15)+1, rand)
   218  	return reflect.ValueOf(m)
   219  }
   220  
   221  func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   222  	m := &certificateStatusMsg{}
   223  	if rand.Intn(10) > 5 {
   224  		m.statusType = statusTypeOCSP
   225  		m.response = randomBytes(rand.Intn(10)+1, rand)
   226  	} else {
   227  		m.statusType = 42
   228  	}
   229  	return reflect.ValueOf(m)
   230  }
   231  
   232  func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   233  	m := &clientKeyExchangeMsg{}
   234  	m.ciphertext = randomBytes(rand.Intn(1000)+1, rand)
   235  	return reflect.ValueOf(m)
   236  }
   237  
   238  func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   239  	m := &finishedMsg{}
   240  	m.verifyData = randomBytes(12, rand)
   241  	return reflect.ValueOf(m)
   242  }
   243  
   244  func (*nextProtoMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   245  	m := &nextProtoMsg{}
   246  	m.proto = randomString(rand.Intn(255), rand)
   247  	return reflect.ValueOf(m)
   248  }
   249  
   250  func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   251  	m := &newSessionTicketMsg{}
   252  	m.ticket = randomBytes(rand.Intn(4), rand)
   253  	return reflect.ValueOf(m)
   254  }
   255  
   256  func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value {
   257  	s := &sessionState{}
   258  	s.vers = uint16(rand.Intn(10000))
   259  	s.cipherSuite = uint16(rand.Intn(10000))
   260  	s.masterSecret = randomBytes(rand.Intn(100), rand)
   261  	numCerts := rand.Intn(20)
   262  	s.certificates = make([][]byte, numCerts)
   263  	for i := 0; i < numCerts; i++ {
   264  		s.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
   265  	}
   266  	return reflect.ValueOf(s)
   267  }
   268  
// TestRejectEmptySCTList marshals a ServerHello carrying one SCT, rewrites
// the encoded bytes so the SCT list is empty, and checks that unmarshal
// rejects the result.
func TestRejectEmptySCTList(t *testing.T) {
	// https://tools.ietf.org/html/rfc6962#section-3.3.1 specifies that
	// empty SCT lists are invalid.

	var random [32]byte
	sct := []byte{0x42, 0x42, 0x42, 0x42}
	serverHello := serverHelloMsg{
		vers:   VersionTLS12,
		random: random[:],
		scts:   [][]byte{sct},
	}
	serverHelloBytes := serverHello.marshal()

	// Sanity check: the unmodified message must parse.
	var serverHelloCopy serverHelloMsg
	if !serverHelloCopy.unmarshal(serverHelloBytes) {
		t.Fatal("Failed to unmarshal initial message")
	}

	// Change serverHelloBytes so that the SCT list is empty
	i := bytes.Index(serverHelloBytes, sct)
	if i < 0 {
		t.Fatal("Cannot find SCT in ServerHello")
	}

	// Keep everything before the 6 bytes that precede the SCT payload.
	// NOTE(review): these 6 bytes appear to be three 2-byte lengths
	// (extension data, SCT list, this SCT), as implied by the 4-byte
	// replacement below — confirm against serverHelloMsg.marshal.
	var serverHelloEmptySCT []byte
	serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...)
	// Append the extension length and SCT list length for an empty list.
	serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...)
	// Skip the 4 SCT bytes and keep whatever followed them.
	serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...)

	// Update the handshake message length.
	// Bytes 1-3 hold a big-endian 24-bit length that excludes the 4-byte
	// header (message type byte plus these three length bytes).
	serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16)
	serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8)
	serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4)

	// Update the extensions length
	// NOTE(review): offset 42 presumably follows the 4-byte header, 2-byte
	// version, 32-byte random, 1-byte (empty) session ID length, 2-byte
	// cipher suite and 1-byte compression method (4+2+32+1+2+1 = 42). The
	// length excludes everything up to and including these two bytes (44).
	serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8)
	serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44))

	if serverHelloCopy.unmarshal(serverHelloEmptySCT) {
		t.Fatal("Unmarshaled ServerHello with empty SCT list")
	}
}
   312  
   313  func TestRejectEmptySCT(t *testing.T) {
   314  	// Not only must the SCT list be non-empty, but the SCT elements must
   315  	// not be zero length.
   316  
   317  	var random [32]byte
   318  	serverHello := serverHelloMsg{
   319  		vers:   VersionTLS12,
   320  		random: random[:],
   321  		scts:   [][]byte{nil},
   322  	}
   323  	serverHelloBytes := serverHello.marshal()
   324  
   325  	var serverHelloCopy serverHelloMsg
   326  	if serverHelloCopy.unmarshal(serverHelloBytes) {
   327  		t.Fatal("Unmarshaled ServerHello with zero-length SCT")
   328  	}
   329  }