github.com/goproxy0/go@v0.0.0-20171111080102-49cc0c489d2c/src/crypto/tls/ticket.go (about)

     1  // Copyright 2012 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/aes"
    10  	"crypto/cipher"
    11  	"crypto/hmac"
    12  	"crypto/sha256"
    13  	"crypto/subtle"
    14  	"errors"
    15  	"io"
    16  )
    17  
    18  // A SessionTicketSealer provides a way to securely encapsulate
    19  // session state for storage on the client. All methods are safe for
    20  // concurrent use.
    21  type SessionTicketSealer interface {
    22  	// Seal returns a session ticket value that can be later passed to Unseal
    23  	// to recover the content, usually by encrypting it. The ticket will be sent
    24  	// to the client to be stored, and will be sent back in plaintext, so it can
    25  	// be read and modified by an attacker.
    26  	Seal(cs *ConnectionState, content []byte) (ticket []byte, err error)
    27  
    28  	// Unseal returns a session ticket contents. The ticket can't be safely
    29  	// assumed to have been generated by Seal.
    30  	// If unable to unseal the ticket, the connection will proceed with a
    31  	// complete handshake.
    32  	Unseal(chi *ClientHelloInfo, ticket []byte) (content []byte, success bool)
    33  }
    34  
// sessionState contains the information that is serialized into a session
// ticket in order to later resume a connection. marshal/unmarshal encode
// vers, cipherSuite, masterSecret and certificates (big-endian, with length
// prefixes); usedOldKey is not part of the encoding.
type sessionState struct {
	vers         uint16   // protocol version; first two bytes of the encoding
	cipherSuite  uint16   // cipher suite identifier
	masterSecret []byte   // encoded with a 16-bit length prefix
	certificates [][]byte // 16-bit count, then each with a 32-bit length prefix
	// usedOldKey is true if the ticket this session came from was encrypted
	// with an older key and thus should be refreshed. It is in-memory only:
	// marshal does not write it and equal does not compare it.
	usedOldKey bool
}
    46  
    47  func (s *sessionState) equal(i interface{}) bool {
    48  	s1, ok := i.(*sessionState)
    49  	if !ok {
    50  		return false
    51  	}
    52  
    53  	if s.vers != s1.vers ||
    54  		s.cipherSuite != s1.cipherSuite ||
    55  		!bytes.Equal(s.masterSecret, s1.masterSecret) {
    56  		return false
    57  	}
    58  
    59  	if len(s.certificates) != len(s1.certificates) {
    60  		return false
    61  	}
    62  
    63  	for i := range s.certificates {
    64  		if !bytes.Equal(s.certificates[i], s1.certificates[i]) {
    65  			return false
    66  		}
    67  	}
    68  
    69  	return true
    70  }
    71  
    72  func (s *sessionState) marshal() []byte {
    73  	length := 2 + 2 + 2 + len(s.masterSecret) + 2
    74  	for _, cert := range s.certificates {
    75  		length += 4 + len(cert)
    76  	}
    77  
    78  	ret := make([]byte, length)
    79  	x := ret
    80  	x[0] = byte(s.vers >> 8)
    81  	x[1] = byte(s.vers)
    82  	x[2] = byte(s.cipherSuite >> 8)
    83  	x[3] = byte(s.cipherSuite)
    84  	x[4] = byte(len(s.masterSecret) >> 8)
    85  	x[5] = byte(len(s.masterSecret))
    86  	x = x[6:]
    87  	copy(x, s.masterSecret)
    88  	x = x[len(s.masterSecret):]
    89  
    90  	x[0] = byte(len(s.certificates) >> 8)
    91  	x[1] = byte(len(s.certificates))
    92  	x = x[2:]
    93  
    94  	for _, cert := range s.certificates {
    95  		x[0] = byte(len(cert) >> 24)
    96  		x[1] = byte(len(cert) >> 16)
    97  		x[2] = byte(len(cert) >> 8)
    98  		x[3] = byte(len(cert))
    99  		copy(x[4:], cert)
   100  		x = x[4+len(cert):]
   101  	}
   102  
   103  	return ret
   104  }
   105  
   106  func (s *sessionState) unmarshal(data []byte) alert {
   107  	if len(data) < 8 {
   108  		return alertDecodeError
   109  	}
   110  
   111  	s.vers = uint16(data[0])<<8 | uint16(data[1])
   112  	s.cipherSuite = uint16(data[2])<<8 | uint16(data[3])
   113  	masterSecretLen := int(data[4])<<8 | int(data[5])
   114  	data = data[6:]
   115  	if len(data) < masterSecretLen {
   116  		return alertDecodeError
   117  	}
   118  
   119  	s.masterSecret = data[:masterSecretLen]
   120  	data = data[masterSecretLen:]
   121  
   122  	if len(data) < 2 {
   123  		return alertDecodeError
   124  	}
   125  
   126  	numCerts := int(data[0])<<8 | int(data[1])
   127  	data = data[2:]
   128  
   129  	s.certificates = make([][]byte, numCerts)
   130  	for i := range s.certificates {
   131  		if len(data) < 4 {
   132  			return alertDecodeError
   133  		}
   134  		certLen := int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3])
   135  		data = data[4:]
   136  		if certLen < 0 {
   137  			return alertDecodeError
   138  		}
   139  		if len(data) < certLen {
   140  			return alertDecodeError
   141  		}
   142  		s.certificates[i] = data[:certLen]
   143  		data = data[certLen:]
   144  	}
   145  
   146  	if len(data) != 0 {
   147  		return alertDecodeError
   148  	}
   149  	return alertSuccess
   150  }
   151  
// sessionState13 holds the session state that its marshal method serializes
// into a ticket and unmarshal parses back. The "13" suffix suggests it is the
// TLS 1.3 counterpart of sessionState (confirm). All fields are encoded in
// declaration order, integers big-endian, and byte/string fields with a
// 16-bit length prefix. equal compares every field.
type sessionState13 struct {
	vers             uint16 // protocol version
	suite            uint16 // cipher suite identifier
	ageAdd           uint32 // NOTE(review): presumably the ticket age obfuscation value — confirm
	createdAt        uint64 // creation timestamp; units not established in this file — TODO confirm
	maxEarlyDataLen  uint32 // maximum early data length associated with the ticket
	resumptionSecret []byte // 16-bit length prefix; aliases the input after unmarshal
	alpnProtocol     string // 16-bit length prefix
	SNI              string // 16-bit length prefix
}
   162  
   163  func (s *sessionState13) equal(i interface{}) bool {
   164  	s1, ok := i.(*sessionState13)
   165  	if !ok {
   166  		return false
   167  	}
   168  
   169  	return s.vers == s1.vers &&
   170  		s.suite == s1.suite &&
   171  		s.ageAdd == s1.ageAdd &&
   172  		s.createdAt == s1.createdAt &&
   173  		s.maxEarlyDataLen == s1.maxEarlyDataLen &&
   174  		bytes.Equal(s.resumptionSecret, s1.resumptionSecret) &&
   175  		s.alpnProtocol == s1.alpnProtocol &&
   176  		s.SNI == s1.SNI
   177  }
   178  
   179  func (s *sessionState13) marshal() []byte {
   180  	length := 2 + 2 + 4 + 8 + 4 + 2 + len(s.resumptionSecret) + 2 + len(s.alpnProtocol) + 2 + len(s.SNI)
   181  
   182  	x := make([]byte, length)
   183  	x[0] = byte(s.vers >> 8)
   184  	x[1] = byte(s.vers)
   185  	x[2] = byte(s.suite >> 8)
   186  	x[3] = byte(s.suite)
   187  	x[4] = byte(s.ageAdd >> 24)
   188  	x[5] = byte(s.ageAdd >> 16)
   189  	x[6] = byte(s.ageAdd >> 8)
   190  	x[7] = byte(s.ageAdd)
   191  	x[8] = byte(s.createdAt >> 56)
   192  	x[9] = byte(s.createdAt >> 48)
   193  	x[10] = byte(s.createdAt >> 40)
   194  	x[11] = byte(s.createdAt >> 32)
   195  	x[12] = byte(s.createdAt >> 24)
   196  	x[13] = byte(s.createdAt >> 16)
   197  	x[14] = byte(s.createdAt >> 8)
   198  	x[15] = byte(s.createdAt)
   199  	x[16] = byte(s.maxEarlyDataLen >> 24)
   200  	x[17] = byte(s.maxEarlyDataLen >> 16)
   201  	x[18] = byte(s.maxEarlyDataLen >> 8)
   202  	x[19] = byte(s.maxEarlyDataLen)
   203  	x[20] = byte(len(s.resumptionSecret) >> 8)
   204  	x[21] = byte(len(s.resumptionSecret))
   205  	copy(x[22:], s.resumptionSecret)
   206  	z := x[22+len(s.resumptionSecret):]
   207  	z[0] = byte(len(s.alpnProtocol) >> 8)
   208  	z[1] = byte(len(s.alpnProtocol))
   209  	copy(z[2:], s.alpnProtocol)
   210  	z = z[2+len(s.alpnProtocol):]
   211  	z[0] = byte(len(s.SNI) >> 8)
   212  	z[1] = byte(len(s.SNI))
   213  	copy(z[2:], s.SNI)
   214  
   215  	return x
   216  }
   217  
   218  func (s *sessionState13) unmarshal(data []byte) alert {
   219  	if len(data) < 24 {
   220  		return alertDecodeError
   221  	}
   222  
   223  	s.vers = uint16(data[0])<<8 | uint16(data[1])
   224  	s.suite = uint16(data[2])<<8 | uint16(data[3])
   225  	s.ageAdd = uint32(data[4])<<24 | uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7])
   226  	s.createdAt = uint64(data[8])<<56 | uint64(data[9])<<48 | uint64(data[10])<<40 | uint64(data[11])<<32 |
   227  		uint64(data[12])<<24 | uint64(data[13])<<16 | uint64(data[14])<<8 | uint64(data[15])
   228  	s.maxEarlyDataLen = uint32(data[16])<<24 | uint32(data[17])<<16 | uint32(data[18])<<8 | uint32(data[19])
   229  
   230  	l := int(data[20])<<8 | int(data[21])
   231  	if len(data) < 22+l+2 {
   232  		return alertDecodeError
   233  	}
   234  	s.resumptionSecret = data[22 : 22+l]
   235  	z := data[22+l:]
   236  
   237  	l = int(z[0])<<8 | int(z[1])
   238  	if len(z) < 2+l+2 {
   239  		return alertDecodeError
   240  	}
   241  	s.alpnProtocol = string(z[2 : 2+l])
   242  	z = z[2+l:]
   243  
   244  	l = int(z[0])<<8 | int(z[1])
   245  	if len(z) != 2+l {
   246  		return alertDecodeError
   247  	}
   248  	s.SNI = string(z[2 : 2+l])
   249  
   250  	return alertSuccess
   251  }
   252  
   253  func (c *Conn) encryptTicket(serialized []byte) ([]byte, error) {
   254  	encrypted := make([]byte, ticketKeyNameLen+aes.BlockSize+len(serialized)+sha256.Size)
   255  	keyName := encrypted[:ticketKeyNameLen]
   256  	iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize]
   257  	macBytes := encrypted[len(encrypted)-sha256.Size:]
   258  
   259  	if _, err := io.ReadFull(c.config.rand(), iv); err != nil {
   260  		return nil, err
   261  	}
   262  	key := c.config.ticketKeys()[0]
   263  	copy(keyName, key.keyName[:])
   264  	block, err := aes.NewCipher(key.aesKey[:])
   265  	if err != nil {
   266  		return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error())
   267  	}
   268  	cipher.NewCTR(block, iv).XORKeyStream(encrypted[ticketKeyNameLen+aes.BlockSize:], serialized)
   269  
   270  	mac := hmac.New(sha256.New, key.hmacKey[:])
   271  	mac.Write(encrypted[:len(encrypted)-sha256.Size])
   272  	mac.Sum(macBytes[:0])
   273  
   274  	return encrypted, nil
   275  }
   276  
   277  func (c *Conn) decryptTicket(encrypted []byte) (serialized []byte, usedOldKey bool) {
   278  	if c.config.SessionTicketsDisabled ||
   279  		len(encrypted) < ticketKeyNameLen+aes.BlockSize+sha256.Size {
   280  		return nil, false
   281  	}
   282  
   283  	keyName := encrypted[:ticketKeyNameLen]
   284  	iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize]
   285  	macBytes := encrypted[len(encrypted)-sha256.Size:]
   286  
   287  	keys := c.config.ticketKeys()
   288  	keyIndex := -1
   289  	for i, candidateKey := range keys {
   290  		if bytes.Equal(keyName, candidateKey.keyName[:]) {
   291  			keyIndex = i
   292  			break
   293  		}
   294  	}
   295  
   296  	if keyIndex == -1 {
   297  		return nil, false
   298  	}
   299  	key := &keys[keyIndex]
   300  
   301  	mac := hmac.New(sha256.New, key.hmacKey[:])
   302  	mac.Write(encrypted[:len(encrypted)-sha256.Size])
   303  	expected := mac.Sum(nil)
   304  
   305  	if subtle.ConstantTimeCompare(macBytes, expected) != 1 {
   306  		return nil, false
   307  	}
   308  
   309  	block, err := aes.NewCipher(key.aesKey[:])
   310  	if err != nil {
   311  		return nil, false
   312  	}
   313  	ciphertext := encrypted[ticketKeyNameLen+aes.BlockSize : len(encrypted)-sha256.Size]
   314  	plaintext := ciphertext
   315  	cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext)
   316  
   317  	return plaintext, keyIndex > 0
   318  }