github.com/Psiphon-Labs/tls-tris@v0.0.0-20230824155421-58bf6d336a9a/ticket.go (about)

     1  // Copyright 2012 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/aes"
    10  	"crypto/cipher"
    11  	"crypto/hmac"
    12  	"crypto/sha256"
    13  	"crypto/subtle"
    14  	"errors"
    15  	"io"
    16  
    17  	// [Psiphon]
    18  	"crypto/rand"
    19  	"crypto/x509"
    20  	"math/big"
    21  	math_rand "math/rand"
    22  )
    23  
// [Psiphon]
// obfuscateSessionTickets controls whether sessionState.marshal leaves
// zero padding of a randomly selected size at the end of the serialized
// TLS 1.2 session state, so that the resulting ticket has a more typical size.
var obfuscateSessionTickets = true
    26  
// A SessionTicketSealer provides a way to securely encapsulate
// session state for storage on the client. All methods are safe for
// concurrent use.
type SessionTicketSealer interface {
	// Seal returns a session ticket value that can later be passed to Unseal
	// to recover the content, usually by encrypting it. The ticket will be sent
	// to the client to be stored, and will be sent back in plaintext, so it can
	// be read and modified by an attacker.
	Seal(cs *ConnectionState, content []byte) (ticket []byte, err error)

	// Unseal returns the contents of a session ticket. The ticket can't be
	// safely assumed to have been generated by Seal.
	// success reports whether the ticket could be unsealed; if it could not,
	// the connection will proceed with a complete handshake.
	Unseal(chi *ClientHelloInfo, ticket []byte) (content []byte, success bool)
}
    43  
// sessionState contains the information that is serialized into a session
// ticket in order to later resume a connection.
//
// marshal and unmarshal define the wire layout; usedOldKey is neither
// serialized by marshal nor compared by equal.
type sessionState struct {
	vers         uint16   // protocol version; only the low 15 bits survive a marshal/unmarshal round trip
	cipherSuite  uint16   // cipher suite identifier, serialized as 2 bytes
	usedEMS      bool     // serialized as bit 0x80 of the first version byte
	masterSecret []byte   // serialized with a 2-byte length prefix
	certificates [][]byte // serialized as a 2-byte count, then a 4-byte length before each entry
	// usedOldKey is true if the ticket from which this session came from
	// was encrypted with an older key and thus should be refreshed.
	usedOldKey bool
}
    56  
    57  func (s *sessionState) equal(i interface{}) bool {
    58  	s1, ok := i.(*sessionState)
    59  	if !ok {
    60  		return false
    61  	}
    62  
    63  	if s.vers != s1.vers ||
    64  		s.usedEMS != s1.usedEMS ||
    65  		s.cipherSuite != s1.cipherSuite ||
    66  		!bytes.Equal(s.masterSecret, s1.masterSecret) {
    67  		return false
    68  	}
    69  
    70  	if len(s.certificates) != len(s1.certificates) {
    71  		return false
    72  	}
    73  
    74  	for i := range s.certificates {
    75  		if !bytes.Equal(s.certificates[i], s1.certificates[i]) {
    76  			return false
    77  		}
    78  	}
    79  
    80  	return true
    81  }
    82  
    83  func (s *sessionState) marshal() []byte {
    84  	length := 2 + 2 + 2 + len(s.masterSecret) + 2
    85  	for _, cert := range s.certificates {
    86  		length += 4 + len(cert)
    87  	}
    88  
    89  	// [Psiphon]
    90  	// Pad golang TLS session ticket to a more typical size.
    91  	if obfuscateSessionTickets {
    92  		paddedSizes := []int{160, 176, 192, 208, 218, 224, 240, 255}
    93  		initialSize := 120
    94  		randomInt, err := rand.Int(rand.Reader, big.NewInt(int64(len(paddedSizes))))
    95  		index := 0
    96  		if err == nil {
    97  			index = int(randomInt.Int64())
    98  		} else {
    99  			index = math_rand.Intn(len(paddedSizes))
   100  		}
   101  		paddingSize := paddedSizes[index] - initialSize
   102  		length += paddingSize
   103  	}
   104  
   105  	ret := make([]byte, length)
   106  	x := ret
   107  	was_used := byte(0)
   108  	if s.usedEMS {
   109  		was_used = byte(0x80)
   110  	}
   111  
   112  	x[0] = byte(s.vers>>8) | byte(was_used)
   113  	x[1] = byte(s.vers)
   114  	x[2] = byte(s.cipherSuite >> 8)
   115  	x[3] = byte(s.cipherSuite)
   116  	x[4] = byte(len(s.masterSecret) >> 8)
   117  	x[5] = byte(len(s.masterSecret))
   118  	x = x[6:]
   119  	copy(x, s.masterSecret)
   120  	x = x[len(s.masterSecret):]
   121  
   122  	x[0] = byte(len(s.certificates) >> 8)
   123  	x[1] = byte(len(s.certificates))
   124  	x = x[2:]
   125  
   126  	for _, cert := range s.certificates {
   127  		x[0] = byte(len(cert) >> 24)
   128  		x[1] = byte(len(cert) >> 16)
   129  		x[2] = byte(len(cert) >> 8)
   130  		x[3] = byte(len(cert))
   131  		copy(x[4:], cert)
   132  		x = x[4+len(cert):]
   133  	}
   134  
   135  	return ret
   136  }
   137  
   138  func (s *sessionState) unmarshal(data []byte) alert {
   139  	if len(data) < 8 {
   140  		return alertDecodeError
   141  	}
   142  
   143  	s.vers = (uint16(data[0])<<8 | uint16(data[1])) & 0x7fff
   144  	s.cipherSuite = uint16(data[2])<<8 | uint16(data[3])
   145  	s.usedEMS = (data[0] & 0x80) == 0x80
   146  	masterSecretLen := int(data[4])<<8 | int(data[5])
   147  	data = data[6:]
   148  	if len(data) < masterSecretLen {
   149  		return alertDecodeError
   150  	}
   151  
   152  	s.masterSecret = data[:masterSecretLen]
   153  	data = data[masterSecretLen:]
   154  
   155  	if len(data) < 2 {
   156  		return alertDecodeError
   157  	}
   158  
   159  	numCerts := int(data[0])<<8 | int(data[1])
   160  	data = data[2:]
   161  
   162  	s.certificates = make([][]byte, numCerts)
   163  	for i := range s.certificates {
   164  		if len(data) < 4 {
   165  			return alertDecodeError
   166  		}
   167  		certLen := int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3])
   168  		data = data[4:]
   169  		if certLen < 0 {
   170  			return alertDecodeError
   171  		}
   172  		if len(data) < certLen {
   173  			return alertDecodeError
   174  		}
   175  		s.certificates[i] = data[:certLen]
   176  		data = data[certLen:]
   177  	}
   178  
   179  	// [Psiphon]
   180  	// Ignore padding for obfuscated session tickets
   181  	//if len(data) != 0 {
   182  	//	return alertDecodeError
   183  	//}
   184  	return alertSuccess
   185  }
   186  
// [Psiphon]
// ObfuscatedClientSessionState holds the client-side view of a session whose
// ticket was fabricated by NewObfuscatedClientSessionState instead of being
// issued by a server. Its fields are used to construct ClientSessionState
// objects for use in ClientSessionCaches.
//
// NewObfuscatedClientSessionState populates SessionTicket, Vers, CipherSuite
// and MasterSecret; it leaves ServerCertificates, VerifiedChains and UseEMS at
// their zero values.
type ObfuscatedClientSessionState struct {
	SessionTicket      []uint8
	Vers               uint16
	CipherSuite        uint16
	MasterSecret       []byte
	ServerCertificates []*x509.Certificate
	VerifiedChains     [][]*x509.Certificate
	UseEMS             bool
}
   197  
// obfuscatedSessionTicketCipherSuite is the cipher suite that
// NewObfuscatedClientSessionState records in the session state it fabricates,
// and the one ContainsObfuscatedSessionTicketCipherSuite searches for.
var obfuscatedSessionTicketCipherSuite = TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
   199  
   200  // [Psiphon]
   201  // NewObfuscatedClientSessionState produces obfuscated session tickets.
   202  //
   203  // Obfuscated Session Tickets
   204  //
   205  // Obfuscated session tickets is a network traffic obfuscation protocol that appears
   206  // to be valid TLS using session tickets. The client actually generates the session
   207  // ticket and encrypts it with a shared secret, enabling a TLS session that entirely
   208  // skips the most fingerprintable aspects of TLS.
   209  // The scheme is described here:
   210  // https://lists.torproject.org/pipermail/tor-dev/2016-September/011354.html
   211  //
   212  // Circumvention notes:
   213  //  - TLS session ticket implementations are widespread:
   214  //    https://istlsfastyet.com/#cdn-paas.
   215  //  - An adversary cannot easily block session ticket capability, as this requires
   216  //    a downgrade attack against TLS.
   217  //  - Anti-probing defence is provided, as the adversary must use the correct obfuscation
   218  //    shared secret to form valid obfuscation session ticket; otherwise server offers
   219  //    standard session tickets.
   220  //  - Limitation: an adversary with the obfuscation shared secret can decrypt the session
   221  //    ticket and observe the plaintext traffic. It's assumed that the adversary will not
   222  //    learn the obfuscated shared secret without also learning the address of the TLS
   223  //    server and blocking it anyway; it's also assumed that the TLS payload is not
   224  //    plaintext but is protected with some other security layer (e.g., SSH).
   225  //
   226  // Implementation notes:
   227  //   - The TLS ClientHello includes an SNI field, even when using session tickets, so
   228  //     the client should populate the ServerName.
   229  //   - Server should set its SetSessionTicketKeys with first a standard key, followed by
   230  //     the obfuscation shared secret.
   231  //   - Since the client creates the session ticket, it selects parameters that were not
   232  //     negotiated with the server, such as the cipher suite. It's implicitly assumed that
   233  //     the server can support the selected parameters.
   234  //   - Obfuscated session tickets are not supported for TLS 1.3 _clients_, which use a
   235  //     distinct scheme. Obfuscated session ticket support in this package is intended to
   236  //     support TLS 1.2 clients.
   237  //
   238  func NewObfuscatedClientSessionState(sharedSecret [32]byte) (*ObfuscatedClientSessionState, error) {
   239  
   240  	// Create a session ticket that wasn't actually issued by the server.
   241  	vers := uint16(VersionTLS12)
   242  	cipherSuite := obfuscatedSessionTicketCipherSuite
   243  	masterSecret := make([]byte, masterSecretLength)
   244  	_, err := rand.Read(masterSecret)
   245  	if err != nil {
   246  		return nil, err
   247  	}
   248  	serverState := &sessionState{
   249  		vers:         vers,
   250  		cipherSuite:  cipherSuite,
   251  		masterSecret: masterSecret,
   252  		certificates: nil,
   253  	}
   254  	c := &Conn{
   255  		config: &Config{
   256  			sessionTicketKeys: []ticketKey{ticketKeyFromBytes(sharedSecret)},
   257  		},
   258  	}
   259  	sessionTicket, err := c.encryptTicket(serverState.marshal())
   260  	if err != nil {
   261  		return nil, err
   262  	}
   263  
   264  	// ObfuscatedClientSessionState fields are used to construct
   265  	// ClientSessionState objects for use in ClientSessionCaches. The client will
   266  	// use this cache to pretend it got that session ticket from the server.
   267  	clientState := &ObfuscatedClientSessionState{
   268  		SessionTicket: sessionTicket,
   269  		Vers:          vers,
   270  		CipherSuite:   cipherSuite,
   271  		MasterSecret:  masterSecret,
   272  	}
   273  
   274  	return clientState, nil
   275  }
   276  
   277  func ContainsObfuscatedSessionTicketCipherSuite(cipherSuites []uint16) bool {
   278  	for _, cipherSuite := range cipherSuites {
   279  		if cipherSuite == obfuscatedSessionTicketCipherSuite {
   280  			return true
   281  		}
   282  	}
   283  	return false
   284  }
   285  
// sessionState13 is the session state serialized by its marshal method and
// parsed by its unmarshal method. Integers are written big-endian; the three
// variable-length fields each carry a 2-byte length prefix.
// NOTE(review): presumably the TLS 1.3 counterpart of sessionState — confirm
// against the handshake code.
type sessionState13 struct {
	vers            uint16 // serialized as 2 bytes
	suite           uint16 // serialized as 2 bytes
	ageAdd          uint32 // serialized as 4 bytes
	createdAt       uint64 // serialized as 8 bytes; units are not visible in this file
	maxEarlyDataLen uint32 // serialized as 4 bytes
	pskSecret       []byte // length-prefixed; compared in constant time by equal
	alpnProtocol    string // length-prefixed
	SNI             string // length-prefixed; must end exactly at the end of the input
}
   296  
   297  func (s *sessionState13) equal(i interface{}) bool {
   298  	s1, ok := i.(*sessionState13)
   299  	if !ok {
   300  		return false
   301  	}
   302  
   303  	return s.vers == s1.vers &&
   304  		s.suite == s1.suite &&
   305  		s.ageAdd == s1.ageAdd &&
   306  		s.createdAt == s1.createdAt &&
   307  		s.maxEarlyDataLen == s1.maxEarlyDataLen &&
   308  		subtle.ConstantTimeCompare(s.pskSecret, s1.pskSecret) == 1 &&
   309  		s.alpnProtocol == s1.alpnProtocol &&
   310  		s.SNI == s1.SNI
   311  }
   312  
   313  func (s *sessionState13) marshal() []byte {
   314  	length := 2 + 2 + 4 + 8 + 4 + 2 + len(s.pskSecret) + 2 + len(s.alpnProtocol) + 2 + len(s.SNI)
   315  
   316  	x := make([]byte, length)
   317  	x[0] = byte(s.vers >> 8)
   318  	x[1] = byte(s.vers)
   319  	x[2] = byte(s.suite >> 8)
   320  	x[3] = byte(s.suite)
   321  	x[4] = byte(s.ageAdd >> 24)
   322  	x[5] = byte(s.ageAdd >> 16)
   323  	x[6] = byte(s.ageAdd >> 8)
   324  	x[7] = byte(s.ageAdd)
   325  	x[8] = byte(s.createdAt >> 56)
   326  	x[9] = byte(s.createdAt >> 48)
   327  	x[10] = byte(s.createdAt >> 40)
   328  	x[11] = byte(s.createdAt >> 32)
   329  	x[12] = byte(s.createdAt >> 24)
   330  	x[13] = byte(s.createdAt >> 16)
   331  	x[14] = byte(s.createdAt >> 8)
   332  	x[15] = byte(s.createdAt)
   333  	x[16] = byte(s.maxEarlyDataLen >> 24)
   334  	x[17] = byte(s.maxEarlyDataLen >> 16)
   335  	x[18] = byte(s.maxEarlyDataLen >> 8)
   336  	x[19] = byte(s.maxEarlyDataLen)
   337  	x[20] = byte(len(s.pskSecret) >> 8)
   338  	x[21] = byte(len(s.pskSecret))
   339  	copy(x[22:], s.pskSecret)
   340  	z := x[22+len(s.pskSecret):]
   341  	z[0] = byte(len(s.alpnProtocol) >> 8)
   342  	z[1] = byte(len(s.alpnProtocol))
   343  	copy(z[2:], s.alpnProtocol)
   344  	z = z[2+len(s.alpnProtocol):]
   345  	z[0] = byte(len(s.SNI) >> 8)
   346  	z[1] = byte(len(s.SNI))
   347  	copy(z[2:], s.SNI)
   348  
   349  	return x
   350  }
   351  
   352  func (s *sessionState13) unmarshal(data []byte) alert {
   353  	if len(data) < 24 {
   354  		return alertDecodeError
   355  	}
   356  
   357  	s.vers = uint16(data[0])<<8 | uint16(data[1])
   358  	s.suite = uint16(data[2])<<8 | uint16(data[3])
   359  	s.ageAdd = uint32(data[4])<<24 | uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7])
   360  	s.createdAt = uint64(data[8])<<56 | uint64(data[9])<<48 | uint64(data[10])<<40 | uint64(data[11])<<32 |
   361  		uint64(data[12])<<24 | uint64(data[13])<<16 | uint64(data[14])<<8 | uint64(data[15])
   362  	s.maxEarlyDataLen = uint32(data[16])<<24 | uint32(data[17])<<16 | uint32(data[18])<<8 | uint32(data[19])
   363  
   364  	l := int(data[20])<<8 | int(data[21])
   365  	if len(data) < 22+l+2 {
   366  		return alertDecodeError
   367  	}
   368  	s.pskSecret = data[22 : 22+l]
   369  	z := data[22+l:]
   370  
   371  	l = int(z[0])<<8 | int(z[1])
   372  	if len(z) < 2+l+2 {
   373  		return alertDecodeError
   374  	}
   375  	s.alpnProtocol = string(z[2 : 2+l])
   376  	z = z[2+l:]
   377  
   378  	l = int(z[0])<<8 | int(z[1])
   379  	if len(z) != 2+l {
   380  		return alertDecodeError
   381  	}
   382  	s.SNI = string(z[2 : 2+l])
   383  
   384  	return alertSuccess
   385  }
   386  
   387  func (c *Conn) encryptTicket(serialized []byte) ([]byte, error) {
   388  	encrypted := make([]byte, ticketKeyNameLen+aes.BlockSize+len(serialized)+sha256.Size)
   389  	keyName := encrypted[:ticketKeyNameLen]
   390  	iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize]
   391  	macBytes := encrypted[len(encrypted)-sha256.Size:]
   392  
   393  	if _, err := io.ReadFull(c.config.rand(), iv); err != nil {
   394  		return nil, err
   395  	}
   396  	key := c.config.ticketKeys()[0]
   397  	copy(keyName, key.keyName[:])
   398  	block, err := aes.NewCipher(key.aesKey[:])
   399  	if err != nil {
   400  		return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error())
   401  	}
   402  	cipher.NewCTR(block, iv).XORKeyStream(encrypted[ticketKeyNameLen+aes.BlockSize:], serialized)
   403  
   404  	mac := hmac.New(sha256.New, key.hmacKey[:])
   405  	mac.Write(encrypted[:len(encrypted)-sha256.Size])
   406  	mac.Sum(macBytes[:0])
   407  
   408  	return encrypted, nil
   409  }
   410  
   411  func (c *Conn) decryptTicket(encrypted []byte) (serialized []byte, usedOldKey bool) {
   412  	if c.config.SessionTicketsDisabled ||
   413  		len(encrypted) < ticketKeyNameLen+aes.BlockSize+sha256.Size {
   414  		return nil, false
   415  	}
   416  
   417  	keyName := encrypted[:ticketKeyNameLen]
   418  	iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize]
   419  	macBytes := encrypted[len(encrypted)-sha256.Size:]
   420  
   421  	keys := c.config.ticketKeys()
   422  	keyIndex := -1
   423  	for i, candidateKey := range keys {
   424  		if bytes.Equal(keyName, candidateKey.keyName[:]) {
   425  			keyIndex = i
   426  			break
   427  		}
   428  	}
   429  
   430  	if keyIndex == -1 {
   431  		return nil, false
   432  	}
   433  	key := &keys[keyIndex]
   434  
   435  	mac := hmac.New(sha256.New, key.hmacKey[:])
   436  	mac.Write(encrypted[:len(encrypted)-sha256.Size])
   437  	expected := mac.Sum(nil)
   438  
   439  	if subtle.ConstantTimeCompare(macBytes, expected) != 1 {
   440  		return nil, false
   441  	}
   442  
   443  	block, err := aes.NewCipher(key.aesKey[:])
   444  	if err != nil {
   445  		return nil, false
   446  	}
   447  	ciphertext := encrypted[ticketKeyNameLen+aes.BlockSize : len(encrypted)-sha256.Size]
   448  	plaintext := ciphertext
   449  	cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext)
   450  
   451  	return plaintext, keyIndex > 0
   452  }