github.com/zebozhuang/go@v0.0.0-20200207033046-f8a98f6f5c5d/src/crypto/tls/ticket.go (about)

     1  // Copyright 2012 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/aes"
    10  	"crypto/cipher"
    11  	"crypto/hmac"
    12  	"crypto/sha256"
    13  	"crypto/subtle"
    14  	"errors"
    15  	"io"
    16  )
    17  
    18  // sessionState contains the information that is serialized into a session
    19  // ticket in order to later resume a connection.
    20  type sessionState struct {
    21  	vers         uint16
    22  	cipherSuite  uint16
    23  	masterSecret []byte
    24  	certificates [][]byte
    25  	// usedOldKey is true if the ticket from which this session came from
    26  	// was encrypted with an older key and thus should be refreshed.
    27  	usedOldKey bool
    28  }
    29  
    30  func (s *sessionState) equal(i interface{}) bool {
    31  	s1, ok := i.(*sessionState)
    32  	if !ok {
    33  		return false
    34  	}
    35  
    36  	if s.vers != s1.vers ||
    37  		s.cipherSuite != s1.cipherSuite ||
    38  		!bytes.Equal(s.masterSecret, s1.masterSecret) {
    39  		return false
    40  	}
    41  
    42  	if len(s.certificates) != len(s1.certificates) {
    43  		return false
    44  	}
    45  
    46  	for i := range s.certificates {
    47  		if !bytes.Equal(s.certificates[i], s1.certificates[i]) {
    48  			return false
    49  		}
    50  	}
    51  
    52  	return true
    53  }
    54  
    55  func (s *sessionState) marshal() []byte {
    56  	length := 2 + 2 + 2 + len(s.masterSecret) + 2
    57  	for _, cert := range s.certificates {
    58  		length += 4 + len(cert)
    59  	}
    60  
    61  	ret := make([]byte, length)
    62  	x := ret
    63  	x[0] = byte(s.vers >> 8)
    64  	x[1] = byte(s.vers)
    65  	x[2] = byte(s.cipherSuite >> 8)
    66  	x[3] = byte(s.cipherSuite)
    67  	x[4] = byte(len(s.masterSecret) >> 8)
    68  	x[5] = byte(len(s.masterSecret))
    69  	x = x[6:]
    70  	copy(x, s.masterSecret)
    71  	x = x[len(s.masterSecret):]
    72  
    73  	x[0] = byte(len(s.certificates) >> 8)
    74  	x[1] = byte(len(s.certificates))
    75  	x = x[2:]
    76  
    77  	for _, cert := range s.certificates {
    78  		x[0] = byte(len(cert) >> 24)
    79  		x[1] = byte(len(cert) >> 16)
    80  		x[2] = byte(len(cert) >> 8)
    81  		x[3] = byte(len(cert))
    82  		copy(x[4:], cert)
    83  		x = x[4+len(cert):]
    84  	}
    85  
    86  	return ret
    87  }
    88  
    89  func (s *sessionState) unmarshal(data []byte) bool {
    90  	if len(data) < 8 {
    91  		return false
    92  	}
    93  
    94  	s.vers = uint16(data[0])<<8 | uint16(data[1])
    95  	s.cipherSuite = uint16(data[2])<<8 | uint16(data[3])
    96  	masterSecretLen := int(data[4])<<8 | int(data[5])
    97  	data = data[6:]
    98  	if len(data) < masterSecretLen {
    99  		return false
   100  	}
   101  
   102  	s.masterSecret = data[:masterSecretLen]
   103  	data = data[masterSecretLen:]
   104  
   105  	if len(data) < 2 {
   106  		return false
   107  	}
   108  
   109  	numCerts := int(data[0])<<8 | int(data[1])
   110  	data = data[2:]
   111  
   112  	s.certificates = make([][]byte, numCerts)
   113  	for i := range s.certificates {
   114  		if len(data) < 4 {
   115  			return false
   116  		}
   117  		certLen := int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3])
   118  		data = data[4:]
   119  		if certLen < 0 {
   120  			return false
   121  		}
   122  		if len(data) < certLen {
   123  			return false
   124  		}
   125  		s.certificates[i] = data[:certLen]
   126  		data = data[certLen:]
   127  	}
   128  
   129  	return len(data) == 0
   130  }
   131  
   132  func (c *Conn) encryptTicket(state *sessionState) ([]byte, error) {
   133  	serialized := state.marshal()
   134  	encrypted := make([]byte, ticketKeyNameLen+aes.BlockSize+len(serialized)+sha256.Size)
   135  	keyName := encrypted[:ticketKeyNameLen]
   136  	iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize]
   137  	macBytes := encrypted[len(encrypted)-sha256.Size:]
   138  
   139  	if _, err := io.ReadFull(c.config.rand(), iv); err != nil {
   140  		return nil, err
   141  	}
   142  	key := c.config.ticketKeys()[0]
   143  	copy(keyName, key.keyName[:])
   144  	block, err := aes.NewCipher(key.aesKey[:])
   145  	if err != nil {
   146  		return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error())
   147  	}
   148  	cipher.NewCTR(block, iv).XORKeyStream(encrypted[ticketKeyNameLen+aes.BlockSize:], serialized)
   149  
   150  	mac := hmac.New(sha256.New, key.hmacKey[:])
   151  	mac.Write(encrypted[:len(encrypted)-sha256.Size])
   152  	mac.Sum(macBytes[:0])
   153  
   154  	return encrypted, nil
   155  }
   156  
   157  func (c *Conn) decryptTicket(encrypted []byte) (*sessionState, bool) {
   158  	if c.config.SessionTicketsDisabled ||
   159  		len(encrypted) < ticketKeyNameLen+aes.BlockSize+sha256.Size {
   160  		return nil, false
   161  	}
   162  
   163  	keyName := encrypted[:ticketKeyNameLen]
   164  	iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize]
   165  	macBytes := encrypted[len(encrypted)-sha256.Size:]
   166  
   167  	keys := c.config.ticketKeys()
   168  	keyIndex := -1
   169  	for i, candidateKey := range keys {
   170  		if bytes.Equal(keyName, candidateKey.keyName[:]) {
   171  			keyIndex = i
   172  			break
   173  		}
   174  	}
   175  
   176  	if keyIndex == -1 {
   177  		return nil, false
   178  	}
   179  	key := &keys[keyIndex]
   180  
   181  	mac := hmac.New(sha256.New, key.hmacKey[:])
   182  	mac.Write(encrypted[:len(encrypted)-sha256.Size])
   183  	expected := mac.Sum(nil)
   184  
   185  	if subtle.ConstantTimeCompare(macBytes, expected) != 1 {
   186  		return nil, false
   187  	}
   188  
   189  	block, err := aes.NewCipher(key.aesKey[:])
   190  	if err != nil {
   191  		return nil, false
   192  	}
   193  	ciphertext := encrypted[ticketKeyNameLen+aes.BlockSize : len(encrypted)-sha256.Size]
   194  	plaintext := ciphertext
   195  	cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext)
   196  
   197  	state := &sessionState{usedOldKey: keyIndex > 0}
   198  	ok := state.unmarshal(plaintext)
   199  	return state, ok
   200  }