github.com/kdevb0x/go@v0.0.0-20180115030120-39687051e9e7/src/crypto/tls/handshake_messages_test.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"bytes"
     9  	"math/rand"
    10  	"reflect"
    11  	"strings"
    12  	"testing"
    13  	"testing/quick"
    14  )
    15  
// tests holds one instance of every message type exercised by
// TestMarshalUnmarshal and TestFuzz. Each entry must implement testMessage.
// TestMarshalUnmarshal passes each entry's dynamic type to quick.Value, which
// builds random values through the Generate methods in this file. Both tests
// also reuse these same instances as unmarshal targets, so the entries are
// mutated while the tests run.
//
// Order matters: TestMarshalUnmarshal skips its "no strict prefix may
// unmarshal" check for indices 0-2, which must remain clientHelloMsg,
// serverHelloMsg and finishedMsg.
var tests = []interface{}{
	&clientHelloMsg{},
	&serverHelloMsg{},
	&finishedMsg{},

	&certificateMsg{},
	&certificateRequestMsg{},
	&certificateVerifyMsg{},
	&certificateStatusMsg{},
	&clientKeyExchangeMsg{},
	&nextProtoMsg{},
	&newSessionTicketMsg{},
	&sessionState{},
}
    30  
    31  type testMessage interface {
    32  	marshal() []byte
    33  	unmarshal([]byte) bool
    34  	equal(interface{}) bool
    35  }
    36  
    37  func TestMarshalUnmarshal(t *testing.T) {
    38  	rand := rand.New(rand.NewSource(0))
    39  
    40  	for i, iface := range tests {
    41  		ty := reflect.ValueOf(iface).Type()
    42  
    43  		n := 100
    44  		if testing.Short() {
    45  			n = 5
    46  		}
    47  		for j := 0; j < n; j++ {
    48  			v, ok := quick.Value(ty, rand)
    49  			if !ok {
    50  				t.Errorf("#%d: failed to create value", i)
    51  				break
    52  			}
    53  
    54  			m1 := v.Interface().(testMessage)
    55  			marshaled := m1.marshal()
    56  			m2 := iface.(testMessage)
    57  			if !m2.unmarshal(marshaled) {
    58  				t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
    59  				break
    60  			}
    61  			m2.marshal() // to fill any marshal cache in the message
    62  
    63  			if !m1.equal(m2) {
    64  				t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled)
    65  				break
    66  			}
    67  
    68  			if i >= 3 {
    69  				// The first three message types (ClientHello,
    70  				// ServerHello and Finished) are allowed to
    71  				// have parsable prefixes because the extension
    72  				// data is optional and the length of the
    73  				// Finished varies across versions.
    74  				for j := 0; j < len(marshaled); j++ {
    75  					if m2.unmarshal(marshaled[0:j]) {
    76  						t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
    77  						break
    78  					}
    79  				}
    80  			}
    81  		}
    82  	}
    83  }
    84  
    85  func TestFuzz(t *testing.T) {
    86  	rand := rand.New(rand.NewSource(0))
    87  	for _, iface := range tests {
    88  		m := iface.(testMessage)
    89  
    90  		for j := 0; j < 1000; j++ {
    91  			len := rand.Intn(100)
    92  			bytes := randomBytes(len, rand)
    93  			// This just looks for crashes due to bounds errors etc.
    94  			m.unmarshal(bytes)
    95  		}
    96  	}
    97  }
    98  
    99  func randomBytes(n int, rand *rand.Rand) []byte {
   100  	r := make([]byte, n)
   101  	if _, err := rand.Read(r); err != nil {
   102  		panic("rand.Read failed: " + err.Error())
   103  	}
   104  	return r
   105  }
   106  
   107  func randomString(n int, rand *rand.Rand) string {
   108  	b := randomBytes(n, rand)
   109  	return string(b)
   110  }
   111  
   112  func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   113  	m := &clientHelloMsg{}
   114  	m.vers = uint16(rand.Intn(65536))
   115  	m.random = randomBytes(32, rand)
   116  	m.sessionId = randomBytes(rand.Intn(32), rand)
   117  	m.cipherSuites = make([]uint16, rand.Intn(63)+1)
   118  	for i := 0; i < len(m.cipherSuites); i++ {
   119  		cs := uint16(rand.Int31())
   120  		if cs == scsvRenegotiation {
   121  			cs += 1
   122  		}
   123  		m.cipherSuites[i] = cs
   124  	}
   125  	m.compressionMethods = randomBytes(rand.Intn(63)+1, rand)
   126  	if rand.Intn(10) > 5 {
   127  		m.nextProtoNeg = true
   128  	}
   129  	if rand.Intn(10) > 5 {
   130  		m.serverName = randomString(rand.Intn(255), rand)
   131  		for strings.HasSuffix(m.serverName, ".") {
   132  			m.serverName = m.serverName[:len(m.serverName)-1]
   133  		}
   134  	}
   135  	m.ocspStapling = rand.Intn(10) > 5
   136  	m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
   137  	m.supportedCurves = make([]CurveID, rand.Intn(5)+1)
   138  	for i := range m.supportedCurves {
   139  		m.supportedCurves[i] = CurveID(rand.Intn(30000))
   140  	}
   141  	if rand.Intn(10) > 5 {
   142  		m.ticketSupported = true
   143  		if rand.Intn(10) > 5 {
   144  			m.sessionTicket = randomBytes(rand.Intn(300), rand)
   145  		}
   146  	}
   147  	if rand.Intn(10) > 5 {
   148  		m.supportedSignatureAlgorithms = supportedSignatureAlgorithms
   149  	}
   150  	m.alpnProtocols = make([]string, rand.Intn(5))
   151  	for i := range m.alpnProtocols {
   152  		m.alpnProtocols[i] = randomString(rand.Intn(20)+1, rand)
   153  	}
   154  	if rand.Intn(10) > 5 {
   155  		m.scts = true
   156  	}
   157  
   158  	return reflect.ValueOf(m)
   159  }
   160  
   161  func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   162  	m := &serverHelloMsg{}
   163  	m.vers = uint16(rand.Intn(65536))
   164  	m.random = randomBytes(32, rand)
   165  	m.sessionId = randomBytes(rand.Intn(32), rand)
   166  	m.cipherSuite = uint16(rand.Int31())
   167  	m.compressionMethod = uint8(rand.Intn(256))
   168  
   169  	if rand.Intn(10) > 5 {
   170  		m.nextProtoNeg = true
   171  
   172  		n := rand.Intn(10)
   173  		m.nextProtos = make([]string, n)
   174  		for i := 0; i < n; i++ {
   175  			m.nextProtos[i] = randomString(20, rand)
   176  		}
   177  	}
   178  
   179  	if rand.Intn(10) > 5 {
   180  		m.ocspStapling = true
   181  	}
   182  	if rand.Intn(10) > 5 {
   183  		m.ticketSupported = true
   184  	}
   185  	m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
   186  
   187  	if rand.Intn(10) > 5 {
   188  		numSCTs := rand.Intn(4)
   189  		m.scts = make([][]byte, numSCTs)
   190  		for i := range m.scts {
   191  			m.scts[i] = randomBytes(rand.Intn(500), rand)
   192  		}
   193  	}
   194  
   195  	return reflect.ValueOf(m)
   196  }
   197  
   198  func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   199  	m := &certificateMsg{}
   200  	numCerts := rand.Intn(20)
   201  	m.certificates = make([][]byte, numCerts)
   202  	for i := 0; i < numCerts; i++ {
   203  		m.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
   204  	}
   205  	return reflect.ValueOf(m)
   206  }
   207  
   208  func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   209  	m := &certificateRequestMsg{}
   210  	m.certificateTypes = randomBytes(rand.Intn(5)+1, rand)
   211  	numCAs := rand.Intn(100)
   212  	m.certificateAuthorities = make([][]byte, numCAs)
   213  	for i := 0; i < numCAs; i++ {
   214  		m.certificateAuthorities[i] = randomBytes(rand.Intn(15)+1, rand)
   215  	}
   216  	return reflect.ValueOf(m)
   217  }
   218  
   219  func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   220  	m := &certificateVerifyMsg{}
   221  	m.signature = randomBytes(rand.Intn(15)+1, rand)
   222  	return reflect.ValueOf(m)
   223  }
   224  
   225  func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   226  	m := &certificateStatusMsg{}
   227  	if rand.Intn(10) > 5 {
   228  		m.statusType = statusTypeOCSP
   229  		m.response = randomBytes(rand.Intn(10)+1, rand)
   230  	} else {
   231  		m.statusType = 42
   232  	}
   233  	return reflect.ValueOf(m)
   234  }
   235  
   236  func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   237  	m := &clientKeyExchangeMsg{}
   238  	m.ciphertext = randomBytes(rand.Intn(1000)+1, rand)
   239  	return reflect.ValueOf(m)
   240  }
   241  
   242  func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   243  	m := &finishedMsg{}
   244  	m.verifyData = randomBytes(12, rand)
   245  	return reflect.ValueOf(m)
   246  }
   247  
   248  func (*nextProtoMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   249  	m := &nextProtoMsg{}
   250  	m.proto = randomString(rand.Intn(255), rand)
   251  	return reflect.ValueOf(m)
   252  }
   253  
   254  func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   255  	m := &newSessionTicketMsg{}
   256  	m.ticket = randomBytes(rand.Intn(4), rand)
   257  	return reflect.ValueOf(m)
   258  }
   259  
   260  func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value {
   261  	s := &sessionState{}
   262  	s.vers = uint16(rand.Intn(10000))
   263  	s.cipherSuite = uint16(rand.Intn(10000))
   264  	s.masterSecret = randomBytes(rand.Intn(100), rand)
   265  	numCerts := rand.Intn(20)
   266  	s.certificates = make([][]byte, numCerts)
   267  	for i := 0; i < numCerts; i++ {
   268  		s.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
   269  	}
   270  	return reflect.ValueOf(s)
   271  }
   272  
// TestRejectEmptySCTList marshals a ServerHello carrying one 4-byte SCT, then
// patches the encoded bytes so the SCT extension holds an empty SCT list
// instead. It fixes up both length fields and checks that unmarshal rejects
// the result.
func TestRejectEmptySCTList(t *testing.T) {
	// https://tools.ietf.org/html/rfc6962#section-3.3.1 specifies that
	// empty SCT lists are invalid.

	var random [32]byte
	// A distinctive byte pattern so bytes.Index below can locate the SCT.
	sct := []byte{0x42, 0x42, 0x42, 0x42}
	serverHello := serverHelloMsg{
		vers:   VersionTLS12,
		random: random[:],
		scts:   [][]byte{sct},
	}
	serverHelloBytes := serverHello.marshal()

	var serverHelloCopy serverHelloMsg
	if !serverHelloCopy.unmarshal(serverHelloBytes) {
		t.Fatal("Failed to unmarshal initial message")
	}

	// Change serverHelloBytes so that the SCT list is empty
	i := bytes.Index(serverHelloBytes, sct)
	if i < 0 {
		t.Fatal("Cannot find SCT in ServerHello")
	}

	// NOTE(review): the offsets assume the extension is laid out as
	// type(2) | ext length(2) | list length(2) | SCT length(2) | SCT data.
	// Under that layout, [:i-6] keeps everything up to and including the
	// extension type, and [i+4:] resumes just after the 4 SCT data bytes.
	// Confirm against serverHelloMsg.marshal.
	var serverHelloEmptySCT []byte
	serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...)
	// Append the extension length and SCT list length for an empty list.
	serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...)
	serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...)

	// Update the handshake message length.
	// Bytes 1-3 hold a 24-bit big-endian length that excludes the 4-byte
	// header (byte 0 plus these three).
	serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16)
	serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8)
	serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4)

	// Update the extensions length
	// NOTE(review): offset 42 presumably equals 4 (header) + 2 (version) +
	// 32 (random) + 1 (session ID length, empty here) + 2 (cipher suite) +
	// 1 (compression method). It therefore holds only because this message has
	// no session ID. 44 = 42 + the 2-byte extensions length field itself.
	serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8)
	serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44))

	if serverHelloCopy.unmarshal(serverHelloEmptySCT) {
		t.Fatal("Unmarshaled ServerHello with empty SCT list")
	}
}
   316  
   317  func TestRejectEmptySCT(t *testing.T) {
   318  	// Not only must the SCT list be non-empty, but the SCT elements must
   319  	// not be zero length.
   320  
   321  	var random [32]byte
   322  	serverHello := serverHelloMsg{
   323  		vers:   VersionTLS12,
   324  		random: random[:],
   325  		scts:   [][]byte{nil},
   326  	}
   327  	serverHelloBytes := serverHello.marshal()
   328  
   329  	var serverHelloCopy serverHelloMsg
   330  	if serverHelloCopy.unmarshal(serverHelloBytes) {
   331  		t.Fatal("Unmarshaled ServerHello with zero-length SCT")
   332  	}
   333  }