github.com/fenixara/go@v0.0.0-20170127160404-96ea0918e670/src/crypto/tls/handshake_messages_test.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"bytes"
     9  	"math/rand"
    10  	"reflect"
    11  	"testing"
    12  	"testing/quick"
    13  )
    14  
    15  var tests = []interface{}{
    16  	&clientHelloMsg{},
    17  	&serverHelloMsg{},
    18  	&finishedMsg{},
    19  
    20  	&certificateMsg{},
    21  	&certificateRequestMsg{},
    22  	&certificateVerifyMsg{},
    23  	&certificateStatusMsg{},
    24  	&clientKeyExchangeMsg{},
    25  	&nextProtoMsg{},
    26  	&newSessionTicketMsg{},
    27  	&sessionState{},
    28  }
    29  
    30  type testMessage interface {
    31  	marshal() []byte
    32  	unmarshal([]byte) bool
    33  	equal(interface{}) bool
    34  }
    35  
    36  func TestMarshalUnmarshal(t *testing.T) {
    37  	rand := rand.New(rand.NewSource(0))
    38  
    39  	for i, iface := range tests {
    40  		ty := reflect.ValueOf(iface).Type()
    41  
    42  		n := 100
    43  		if testing.Short() {
    44  			n = 5
    45  		}
    46  		for j := 0; j < n; j++ {
    47  			v, ok := quick.Value(ty, rand)
    48  			if !ok {
    49  				t.Errorf("#%d: failed to create value", i)
    50  				break
    51  			}
    52  
    53  			m1 := v.Interface().(testMessage)
    54  			marshaled := m1.marshal()
    55  			m2 := iface.(testMessage)
    56  			if !m2.unmarshal(marshaled) {
    57  				t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
    58  				break
    59  			}
    60  			m2.marshal() // to fill any marshal cache in the message
    61  
    62  			if !m1.equal(m2) {
    63  				t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled)
    64  				break
    65  			}
    66  
    67  			if i >= 3 {
    68  				// The first three message types (ClientHello,
    69  				// ServerHello and Finished) are allowed to
    70  				// have parsable prefixes because the extension
    71  				// data is optional and the length of the
    72  				// Finished varies across versions.
    73  				for j := 0; j < len(marshaled); j++ {
    74  					if m2.unmarshal(marshaled[0:j]) {
    75  						t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
    76  						break
    77  					}
    78  				}
    79  			}
    80  		}
    81  	}
    82  }
    83  
    84  func TestFuzz(t *testing.T) {
    85  	rand := rand.New(rand.NewSource(0))
    86  	for _, iface := range tests {
    87  		m := iface.(testMessage)
    88  
    89  		for j := 0; j < 1000; j++ {
    90  			len := rand.Intn(100)
    91  			bytes := randomBytes(len, rand)
    92  			// This just looks for crashes due to bounds errors etc.
    93  			m.unmarshal(bytes)
    94  		}
    95  	}
    96  }
    97  
    98  func randomBytes(n int, rand *rand.Rand) []byte {
    99  	r := make([]byte, n)
   100  	for i := 0; i < n; i++ {
   101  		r[i] = byte(rand.Int31())
   102  	}
   103  	return r
   104  }
   105  
   106  func randomString(n int, rand *rand.Rand) string {
   107  	b := randomBytes(n, rand)
   108  	return string(b)
   109  }
   110  
   111  func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   112  	m := &clientHelloMsg{}
   113  	m.vers = uint16(rand.Intn(65536))
   114  	m.random = randomBytes(32, rand)
   115  	m.sessionId = randomBytes(rand.Intn(32), rand)
   116  	m.cipherSuites = make([]uint16, rand.Intn(63)+1)
   117  	for i := 0; i < len(m.cipherSuites); i++ {
   118  		m.cipherSuites[i] = uint16(rand.Int31())
   119  	}
   120  	m.compressionMethods = randomBytes(rand.Intn(63)+1, rand)
   121  	if rand.Intn(10) > 5 {
   122  		m.nextProtoNeg = true
   123  	}
   124  	if rand.Intn(10) > 5 {
   125  		m.serverName = randomString(rand.Intn(255), rand)
   126  	}
   127  	m.ocspStapling = rand.Intn(10) > 5
   128  	m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
   129  	m.supportedCurves = make([]CurveID, rand.Intn(5)+1)
   130  	for i := range m.supportedCurves {
   131  		m.supportedCurves[i] = CurveID(rand.Intn(30000))
   132  	}
   133  	if rand.Intn(10) > 5 {
   134  		m.ticketSupported = true
   135  		if rand.Intn(10) > 5 {
   136  			m.sessionTicket = randomBytes(rand.Intn(300), rand)
   137  		}
   138  	}
   139  	if rand.Intn(10) > 5 {
   140  		m.signatureAndHashes = supportedSignatureAlgorithms
   141  	}
   142  	m.alpnProtocols = make([]string, rand.Intn(5))
   143  	for i := range m.alpnProtocols {
   144  		m.alpnProtocols[i] = randomString(rand.Intn(20)+1, rand)
   145  	}
   146  	if rand.Intn(10) > 5 {
   147  		m.scts = true
   148  	}
   149  
   150  	return reflect.ValueOf(m)
   151  }
   152  
   153  func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   154  	m := &serverHelloMsg{}
   155  	m.vers = uint16(rand.Intn(65536))
   156  	m.random = randomBytes(32, rand)
   157  	m.sessionId = randomBytes(rand.Intn(32), rand)
   158  	m.cipherSuite = uint16(rand.Int31())
   159  	m.compressionMethod = uint8(rand.Intn(256))
   160  
   161  	if rand.Intn(10) > 5 {
   162  		m.nextProtoNeg = true
   163  
   164  		n := rand.Intn(10)
   165  		m.nextProtos = make([]string, n)
   166  		for i := 0; i < n; i++ {
   167  			m.nextProtos[i] = randomString(20, rand)
   168  		}
   169  	}
   170  
   171  	if rand.Intn(10) > 5 {
   172  		m.ocspStapling = true
   173  	}
   174  	if rand.Intn(10) > 5 {
   175  		m.ticketSupported = true
   176  	}
   177  	m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
   178  
   179  	if rand.Intn(10) > 5 {
   180  		numSCTs := rand.Intn(4)
   181  		m.scts = make([][]byte, numSCTs)
   182  		for i := range m.scts {
   183  			m.scts[i] = randomBytes(rand.Intn(500), rand)
   184  		}
   185  	}
   186  
   187  	return reflect.ValueOf(m)
   188  }
   189  
   190  func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   191  	m := &certificateMsg{}
   192  	numCerts := rand.Intn(20)
   193  	m.certificates = make([][]byte, numCerts)
   194  	for i := 0; i < numCerts; i++ {
   195  		m.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
   196  	}
   197  	return reflect.ValueOf(m)
   198  }
   199  
   200  func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   201  	m := &certificateRequestMsg{}
   202  	m.certificateTypes = randomBytes(rand.Intn(5)+1, rand)
   203  	numCAs := rand.Intn(100)
   204  	m.certificateAuthorities = make([][]byte, numCAs)
   205  	for i := 0; i < numCAs; i++ {
   206  		m.certificateAuthorities[i] = randomBytes(rand.Intn(15)+1, rand)
   207  	}
   208  	return reflect.ValueOf(m)
   209  }
   210  
   211  func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   212  	m := &certificateVerifyMsg{}
   213  	m.signature = randomBytes(rand.Intn(15)+1, rand)
   214  	return reflect.ValueOf(m)
   215  }
   216  
   217  func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   218  	m := &certificateStatusMsg{}
   219  	if rand.Intn(10) > 5 {
   220  		m.statusType = statusTypeOCSP
   221  		m.response = randomBytes(rand.Intn(10)+1, rand)
   222  	} else {
   223  		m.statusType = 42
   224  	}
   225  	return reflect.ValueOf(m)
   226  }
   227  
   228  func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   229  	m := &clientKeyExchangeMsg{}
   230  	m.ciphertext = randomBytes(rand.Intn(1000)+1, rand)
   231  	return reflect.ValueOf(m)
   232  }
   233  
   234  func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   235  	m := &finishedMsg{}
   236  	m.verifyData = randomBytes(12, rand)
   237  	return reflect.ValueOf(m)
   238  }
   239  
   240  func (*nextProtoMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   241  	m := &nextProtoMsg{}
   242  	m.proto = randomString(rand.Intn(255), rand)
   243  	return reflect.ValueOf(m)
   244  }
   245  
   246  func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   247  	m := &newSessionTicketMsg{}
   248  	m.ticket = randomBytes(rand.Intn(4), rand)
   249  	return reflect.ValueOf(m)
   250  }
   251  
   252  func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value {
   253  	s := &sessionState{}
   254  	s.vers = uint16(rand.Intn(10000))
   255  	s.cipherSuite = uint16(rand.Intn(10000))
   256  	s.masterSecret = randomBytes(rand.Intn(100), rand)
   257  	numCerts := rand.Intn(20)
   258  	s.certificates = make([][]byte, numCerts)
   259  	for i := 0; i < numCerts; i++ {
   260  		s.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
   261  	}
   262  	return reflect.ValueOf(s)
   263  }
   264  
   265  func TestRejectEmptySCTList(t *testing.T) {
   266  	// https://tools.ietf.org/html/rfc6962#section-3.3.1 specifies that
   267  	// empty SCT lists are invalid.
   268  
   269  	var random [32]byte
   270  	sct := []byte{0x42, 0x42, 0x42, 0x42}
   271  	serverHello := serverHelloMsg{
   272  		vers:   VersionTLS12,
   273  		random: random[:],
   274  		scts:   [][]byte{sct},
   275  	}
   276  	serverHelloBytes := serverHello.marshal()
   277  
   278  	var serverHelloCopy serverHelloMsg
   279  	if !serverHelloCopy.unmarshal(serverHelloBytes) {
   280  		t.Fatal("Failed to unmarshal initial message")
   281  	}
   282  
   283  	// Change serverHelloBytes so that the SCT list is empty
   284  	i := bytes.Index(serverHelloBytes, sct)
   285  	if i < 0 {
   286  		t.Fatal("Cannot find SCT in ServerHello")
   287  	}
   288  
   289  	var serverHelloEmptySCT []byte
   290  	serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...)
   291  	// Append the extension length and SCT list length for an empty list.
   292  	serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...)
   293  	serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...)
   294  
   295  	// Update the handshake message length.
   296  	serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16)
   297  	serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8)
   298  	serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4)
   299  
   300  	// Update the extensions length
   301  	serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8)
   302  	serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44))
   303  
   304  	if serverHelloCopy.unmarshal(serverHelloEmptySCT) {
   305  		t.Fatal("Unmarshaled ServerHello with empty SCT list")
   306  	}
   307  }
   308  
   309  func TestRejectEmptySCT(t *testing.T) {
   310  	// Not only must the SCT list be non-empty, but the SCT elements must
   311  	// not be zero length.
   312  
   313  	var random [32]byte
   314  	serverHello := serverHelloMsg{
   315  		vers:   VersionTLS12,
   316  		random: random[:],
   317  		scts:   [][]byte{nil},
   318  	}
   319  	serverHelloBytes := serverHello.marshal()
   320  
   321  	var serverHelloCopy serverHelloMsg
   322  	if serverHelloCopy.unmarshal(serverHelloBytes) {
   323  		t.Fatal("Unmarshaled ServerHello with zero-length SCT")
   324  	}
   325  }