github.com/Psiphon-Labs/tls-tris@v0.0.0-20230824155421-58bf6d336a9a/handshake_messages_test.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"bytes"
     9  	"math/rand"
    10  	"reflect"
    11  	"strings"
    12  	"testing"
    13  	"testing/quick"
    14  )
    15  
    16  var tests = []interface{}{
    17  	&clientHelloMsg{},
    18  	&serverHelloMsg{},
    19  	&finishedMsg{},
    20  
    21  	&certificateMsg{},
    22  	&certificateRequestMsg{},
    23  	&certificateRequestMsg13{},
    24  	&certificateVerifyMsg{},
    25  	&certificateStatusMsg{},
    26  	&clientKeyExchangeMsg{},
    27  	&nextProtoMsg{},
    28  	&newSessionTicketMsg{},
    29  	&sessionState{},
    30  	&encryptedExtensionsMsg{},
    31  	&certificateMsg13{},
    32  	&newSessionTicketMsg13{},
    33  	&sessionState13{},
    34  }
    35  
    36  type testMessage interface {
    37  	marshal() []byte
    38  	unmarshal([]byte) alert
    39  	equal(interface{}) bool
    40  }
    41  
    42  func TestMarshalUnmarshal(t *testing.T) {
    43  	rand := rand.New(rand.NewSource(0))
    44  
    45  	for i, iface := range tests {
    46  		ty := reflect.ValueOf(iface).Type()
    47  
    48  		n := 100
    49  		if testing.Short() {
    50  			n = 5
    51  		}
    52  		for j := 0; j < n; j++ {
    53  			v, ok := quick.Value(ty, rand)
    54  			if !ok {
    55  				t.Errorf("#%d: failed to create value", i)
    56  				break
    57  			}
    58  
    59  			m1 := v.Interface().(testMessage)
    60  			marshaled := m1.marshal()
    61  			m2 := iface.(testMessage)
    62  			if m2.unmarshal(marshaled) != alertSuccess {
    63  				t.Errorf("#%d.%d failed to unmarshal %#v %x", i, j, m1, marshaled)
    64  				break
    65  			}
    66  			m2.marshal() // to fill any marshal cache in the message
    67  
    68  			if !m1.equal(m2) {
    69  				t.Errorf("#%d.%d got:%#v want:%#v %x", i, j, m2, m1, marshaled)
    70  				break
    71  			}
    72  
    73  			if i >= 3 {
    74  				// The first three message types (ClientHello,
    75  				// ServerHello and Finished) are allowed to
    76  				// have parsable prefixes because the extension
    77  				// data is optional and the length of the
    78  				// Finished varies across versions.
    79  				for j := 0; j < len(marshaled); j++ {
    80  					if m2.unmarshal(marshaled[0:j]) == alertSuccess {
    81  						t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
    82  						break
    83  					}
    84  				}
    85  			}
    86  		}
    87  	}
    88  }
    89  
    90  func TestFuzz(t *testing.T) {
    91  	rand := rand.New(rand.NewSource(0))
    92  	for _, iface := range tests {
    93  		m := iface.(testMessage)
    94  
    95  		for j := 0; j < 1000; j++ {
    96  			len := rand.Intn(100)
    97  			bytes := randomBytes(len, rand)
    98  			// This just looks for crashes due to bounds errors etc.
    99  			m.unmarshal(bytes)
   100  		}
   101  	}
   102  }
   103  
   104  func randomBytes(n int, rand *rand.Rand) []byte {
   105  	r := make([]byte, n)
   106  	if _, err := rand.Read(r); err != nil {
   107  		panic("rand.Read failed: " + err.Error())
   108  	}
   109  	return r
   110  }
   111  
   112  func randomString(n int, rand *rand.Rand) string {
   113  	b := randomBytes(n, rand)
   114  	return string(b)
   115  }
   116  
   117  func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   118  	m := &clientHelloMsg{}
   119  	m.vers = uint16(rand.Intn(65536))
   120  	m.random = randomBytes(32, rand)
   121  	m.sessionId = randomBytes(rand.Intn(32), rand)
   122  	m.cipherSuites = make([]uint16, rand.Intn(63)+1)
   123  	for i := 0; i < len(m.cipherSuites); i++ {
   124  		cs := uint16(rand.Int31())
   125  		if cs == scsvRenegotiation {
   126  			cs += 1
   127  		}
   128  		m.cipherSuites[i] = cs
   129  	}
   130  	m.compressionMethods = randomBytes(rand.Intn(63)+1, rand)
   131  	if rand.Intn(10) > 5 {
   132  		m.nextProtoNeg = true
   133  	}
   134  	if rand.Intn(10) > 5 {
   135  		m.serverName = randomString(rand.Intn(255), rand)
   136  		for strings.HasSuffix(m.serverName, ".") {
   137  			m.serverName = m.serverName[:len(m.serverName)-1]
   138  		}
   139  	}
   140  	m.ocspStapling = rand.Intn(10) > 5
   141  	m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
   142  	m.supportedCurves = make([]CurveID, rand.Intn(5)+1)
   143  	for i := range m.supportedCurves {
   144  		m.supportedCurves[i] = CurveID(rand.Intn(30000))
   145  	}
   146  	if rand.Intn(10) > 5 {
   147  		m.ticketSupported = true
   148  		if rand.Intn(10) > 5 {
   149  			m.sessionTicket = randomBytes(rand.Intn(300), rand)
   150  		}
   151  	}
   152  	if rand.Intn(10) > 5 {
   153  		m.supportedSignatureAlgorithms = supportedSignatureAlgorithms
   154  	}
   155  	m.alpnProtocols = make([]string, rand.Intn(5))
   156  	for i := range m.alpnProtocols {
   157  		m.alpnProtocols[i] = randomString(rand.Intn(20)+1, rand)
   158  	}
   159  	if rand.Intn(10) > 5 {
   160  		m.scts = true
   161  	}
   162  	m.keyShares = make([]keyShare, rand.Intn(4))
   163  	for i := range m.keyShares {
   164  		m.keyShares[i].group = CurveID(rand.Intn(30000))
   165  		m.keyShares[i].data = randomBytes(rand.Intn(300)+1, rand)
   166  	}
   167  	m.supportedVersions = make([]uint16, rand.Intn(5))
   168  	for i := range m.supportedVersions {
   169  		m.supportedVersions[i] = uint16(rand.Intn(30000))
   170  	}
   171  	if rand.Intn(10) > 5 {
   172  		m.earlyData = true
   173  	}
   174  	if rand.Intn(10) > 5 {
   175  		m.delegatedCredential = true
   176  	}
   177  	if rand.Intn(10) > 5 {
   178  		m.extendedMSSupported = true
   179  	}
   180  	return reflect.ValueOf(m)
   181  }
   182  
   183  func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   184  	m := &serverHelloMsg{}
   185  	m.vers = uint16(rand.Intn(65536))
   186  	m.random = randomBytes(32, rand)
   187  	m.sessionId = randomBytes(rand.Intn(32), rand)
   188  	m.cipherSuite = uint16(rand.Int31())
   189  	m.compressionMethod = uint8(rand.Intn(256))
   190  
   191  	if rand.Intn(10) > 5 {
   192  		m.nextProtoNeg = true
   193  
   194  		n := rand.Intn(10)
   195  		m.nextProtos = make([]string, n)
   196  		for i := 0; i < n; i++ {
   197  			m.nextProtos[i] = randomString(20, rand)
   198  		}
   199  	}
   200  
   201  	if rand.Intn(10) > 5 {
   202  		m.ocspStapling = true
   203  	}
   204  	if rand.Intn(10) > 5 {
   205  		m.ticketSupported = true
   206  	}
   207  	if rand.Intn(10) > 5 {
   208  		m.extendedMSSupported = true
   209  	}
   210  	m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
   211  
   212  	if rand.Intn(10) > 5 {
   213  		numSCTs := rand.Intn(4)
   214  		m.scts = make([][]byte, numSCTs)
   215  		for i := range m.scts {
   216  			m.scts[i] = randomBytes(rand.Intn(500)+1, rand)
   217  		}
   218  	}
   219  
   220  	if rand.Intn(10) > 5 {
   221  		m.keyShare.group = CurveID(rand.Intn(30000) + 1)
   222  		m.keyShare.data = randomBytes(rand.Intn(300)+1, rand)
   223  	}
   224  	if rand.Intn(10) > 5 {
   225  		m.psk = true
   226  		m.pskIdentity = uint16(rand.Int31())
   227  	}
   228  
   229  	return reflect.ValueOf(m)
   230  }
   231  
   232  func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   233  	m := &encryptedExtensionsMsg{}
   234  	if rand.Intn(10) > 5 {
   235  		m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
   236  	}
   237  	if rand.Intn(10) > 5 {
   238  		m.earlyData = true
   239  	}
   240  
   241  	return reflect.ValueOf(m)
   242  }
   243  
   244  func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   245  	m := &certificateMsg{}
   246  	numCerts := rand.Intn(20)
   247  	m.certificates = make([][]byte, numCerts)
   248  	for i := 0; i < numCerts; i++ {
   249  		m.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
   250  	}
   251  	return reflect.ValueOf(m)
   252  }
   253  
   254  func (*certificateMsg13) Generate(rand *rand.Rand, size int) reflect.Value {
   255  	m := &certificateMsg13{}
   256  	numCerts := rand.Intn(20)
   257  	m.certificates = make([]certificateEntry, numCerts)
   258  	for i := 0; i < numCerts; i++ {
   259  		m.certificates[i].data = randomBytes(rand.Intn(10)+1, rand)
   260  		if rand.Intn(2) == 1 {
   261  			m.certificates[i].ocspStaple = randomBytes(rand.Intn(10)+1, rand)
   262  		}
   263  
   264  		numScts := rand.Intn(3)
   265  		for j := 0; j < numScts; j++ {
   266  			m.certificates[i].sctList = append(m.certificates[i].sctList, randomBytes(rand.Intn(10)+1, rand))
   267  		}
   268  	}
   269  	m.requestContext = randomBytes(rand.Intn(5), rand)
   270  	return reflect.ValueOf(m)
   271  }
   272  
   273  func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   274  	m := &certificateRequestMsg{}
   275  	m.certificateTypes = randomBytes(rand.Intn(5)+1, rand)
   276  	numCAs := rand.Intn(100)
   277  	m.certificateAuthorities = make([][]byte, numCAs)
   278  	for i := 0; i < numCAs; i++ {
   279  		m.certificateAuthorities[i] = randomBytes(rand.Intn(15)+1, rand)
   280  	}
   281  	return reflect.ValueOf(m)
   282  }
   283  
   284  func (*certificateRequestMsg13) Generate(rand *rand.Rand, size int) reflect.Value {
   285  	m := &certificateRequestMsg13{}
   286  	m.requestContext = randomBytes(rand.Intn(5), rand)
   287  	m.supportedSignatureAlgorithms = supportedSignatureAlgorithms
   288  	numCAs := rand.Intn(100)
   289  	m.certificateAuthorities = make([][]byte, numCAs)
   290  	for i := 0; i < numCAs; i++ {
   291  		m.certificateAuthorities[i] = randomBytes(rand.Intn(15)+1, rand)
   292  	}
   293  	return reflect.ValueOf(m)
   294  }
   295  
   296  func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   297  	m := &certificateVerifyMsg{}
   298  	m.signature = randomBytes(rand.Intn(15)+1, rand)
   299  	return reflect.ValueOf(m)
   300  }
   301  
   302  func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   303  	m := &certificateStatusMsg{}
   304  	if rand.Intn(10) > 5 {
   305  		m.statusType = statusTypeOCSP
   306  		m.response = randomBytes(rand.Intn(10)+1, rand)
   307  	} else {
   308  		m.statusType = 42
   309  	}
   310  	return reflect.ValueOf(m)
   311  }
   312  
   313  func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   314  	m := &clientKeyExchangeMsg{}
   315  	m.ciphertext = randomBytes(rand.Intn(1000)+1, rand)
   316  	return reflect.ValueOf(m)
   317  }
   318  
   319  func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   320  	m := &finishedMsg{}
   321  	m.verifyData = randomBytes(12, rand)
   322  	return reflect.ValueOf(m)
   323  }
   324  
   325  func (*nextProtoMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   326  	m := &nextProtoMsg{}
   327  	m.proto = randomString(rand.Intn(255), rand)
   328  	return reflect.ValueOf(m)
   329  }
   330  
   331  func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   332  	m := &newSessionTicketMsg{}
   333  	m.ticket = randomBytes(rand.Intn(4), rand)
   334  	return reflect.ValueOf(m)
   335  }
   336  
   337  func (*newSessionTicketMsg13) Generate(rand *rand.Rand, size int) reflect.Value {
   338  	m := &newSessionTicketMsg13{}
   339  	m.ageAdd = uint32(rand.Intn(0xffffffff))
   340  	m.lifetime = uint32(rand.Intn(0xffffffff))
   341  	m.nonce = randomBytes(1+rand.Intn(255), rand)
   342  	m.ticket = randomBytes(1+rand.Intn(40), rand)
   343  	if rand.Intn(10) > 5 {
   344  		m.withEarlyDataInfo = true
   345  		m.maxEarlyDataLength = uint32(rand.Intn(0xffffffff))
   346  	}
   347  	return reflect.ValueOf(m)
   348  }
   349  
   350  func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value {
   351  	s := &sessionState{}
   352  	s.vers = uint16(rand.Intn(10000))
   353  	s.cipherSuite = uint16(rand.Intn(10000))
   354  	s.masterSecret = randomBytes(rand.Intn(100), rand)
   355  	if rand.Intn(10) > 5 {
   356  		s.usedEMS = true
   357  	}
   358  	numCerts := rand.Intn(20)
   359  	s.certificates = make([][]byte, numCerts)
   360  	for i := 0; i < numCerts; i++ {
   361  		s.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
   362  	}
   363  	return reflect.ValueOf(s)
   364  }
   365  
   366  func (*sessionState13) Generate(rand *rand.Rand, size int) reflect.Value {
   367  	s := &sessionState13{}
   368  	s.vers = uint16(rand.Intn(10000))
   369  	s.suite = uint16(rand.Intn(10000))
   370  	s.ageAdd = uint32(rand.Intn(0xffffffff))
   371  	s.maxEarlyDataLen = uint32(rand.Intn(0xffffffff))
   372  	s.createdAt = uint64(rand.Int63n(0xfffffffffffffff))
   373  	s.pskSecret = randomBytes(rand.Intn(100), rand)
   374  	s.alpnProtocol = randomString(rand.Intn(100), rand)
   375  	s.SNI = randomString(rand.Intn(100), rand)
   376  	return reflect.ValueOf(s)
   377  }
   378  
   379  func TestRejectEmptySCTList(t *testing.T) {
   380  	// https://tools.ietf.org/html/rfc6962#section-3.3.1 specifies that
   381  	// empty SCT lists are invalid.
   382  
   383  	var random [32]byte
   384  	sct := []byte{0x42, 0x42, 0x42, 0x42}
   385  	serverHello := serverHelloMsg{
   386  		vers:   VersionTLS12,
   387  		random: random[:],
   388  		scts:   [][]byte{sct},
   389  	}
   390  	serverHelloBytes := serverHello.marshal()
   391  
   392  	var serverHelloCopy serverHelloMsg
   393  	if serverHelloCopy.unmarshal(serverHelloBytes) != alertSuccess {
   394  		t.Fatal("Failed to unmarshal initial message")
   395  	}
   396  
   397  	// Change serverHelloBytes so that the SCT list is empty
   398  	i := bytes.Index(serverHelloBytes, sct)
   399  	if i < 0 {
   400  		t.Fatal("Cannot find SCT in ServerHello")
   401  	}
   402  
   403  	var serverHelloEmptySCT []byte
   404  	serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...)
   405  	// Append the extension length and SCT list length for an empty list.
   406  	serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...)
   407  	serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...)
   408  
   409  	// Update the handshake message length.
   410  	serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16)
   411  	serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8)
   412  	serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4)
   413  
   414  	// Update the extensions length
   415  	serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8)
   416  	serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44))
   417  
   418  	if serverHelloCopy.unmarshal(serverHelloEmptySCT) == alertSuccess {
   419  		t.Fatal("Unmarshaled ServerHello with empty SCT list")
   420  	}
   421  }
   422  
   423  func TestRejectEmptySCT(t *testing.T) {
   424  	// Not only must the SCT list be non-empty, but the SCT elements must
   425  	// not be zero length.
   426  
   427  	var random [32]byte
   428  	serverHello := serverHelloMsg{
   429  		vers:   VersionTLS12,
   430  		random: random[:],
   431  		scts:   [][]byte{nil},
   432  	}
   433  	serverHelloBytes := serverHello.marshal()
   434  
   435  	var serverHelloCopy serverHelloMsg
   436  	if serverHelloCopy.unmarshal(serverHelloBytes) == alertSuccess {
   437  		t.Fatal("Unmarshaled ServerHello with zero-length SCT")
   438  	}
   439  }