github.com/icodeface/tls@v0.0.0-20230910023335-34df9250cd12/ticket.go (about)

     1  // Copyright 2012 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/aes"
    10  	"crypto/cipher"
    11  	"crypto/hmac"
    12  	"crypto/sha256"
    13  	"crypto/subtle"
    14  	"errors"
    15  	"github.com/icodeface/tls/internal/x/crypto/cryptobyte"
    16  	"io"
    17  )
    18  
    19  // sessionState contains the information that is serialized into a session
    20  // ticket in order to later resume a connection.
    21  type sessionState struct {
    22  	vers         uint16
    23  	cipherSuite  uint16
    24  	masterSecret []byte
    25  	certificates [][]byte
    26  	// usedOldKey is true if the ticket from which this session came from
    27  	// was encrypted with an older key and thus should be refreshed.
    28  	usedOldKey bool
    29  }
    30  
    31  func (s *sessionState) marshal() []byte {
    32  	length := 2 + 2 + 2 + len(s.masterSecret) + 2
    33  	for _, cert := range s.certificates {
    34  		length += 4 + len(cert)
    35  	}
    36  
    37  	ret := make([]byte, length)
    38  	x := ret
    39  	x[0] = byte(s.vers >> 8)
    40  	x[1] = byte(s.vers)
    41  	x[2] = byte(s.cipherSuite >> 8)
    42  	x[3] = byte(s.cipherSuite)
    43  	x[4] = byte(len(s.masterSecret) >> 8)
    44  	x[5] = byte(len(s.masterSecret))
    45  	x = x[6:]
    46  	copy(x, s.masterSecret)
    47  	x = x[len(s.masterSecret):]
    48  
    49  	x[0] = byte(len(s.certificates) >> 8)
    50  	x[1] = byte(len(s.certificates))
    51  	x = x[2:]
    52  
    53  	for _, cert := range s.certificates {
    54  		x[0] = byte(len(cert) >> 24)
    55  		x[1] = byte(len(cert) >> 16)
    56  		x[2] = byte(len(cert) >> 8)
    57  		x[3] = byte(len(cert))
    58  		copy(x[4:], cert)
    59  		x = x[4+len(cert):]
    60  	}
    61  
    62  	return ret
    63  }
    64  
    65  func (s *sessionState) unmarshal(data []byte) bool {
    66  	if len(data) < 8 {
    67  		return false
    68  	}
    69  
    70  	s.vers = uint16(data[0])<<8 | uint16(data[1])
    71  	s.cipherSuite = uint16(data[2])<<8 | uint16(data[3])
    72  	masterSecretLen := int(data[4])<<8 | int(data[5])
    73  	data = data[6:]
    74  	if len(data) < masterSecretLen {
    75  		return false
    76  	}
    77  
    78  	s.masterSecret = data[:masterSecretLen]
    79  	data = data[masterSecretLen:]
    80  
    81  	if len(data) < 2 {
    82  		return false
    83  	}
    84  
    85  	numCerts := int(data[0])<<8 | int(data[1])
    86  	data = data[2:]
    87  
    88  	s.certificates = make([][]byte, numCerts)
    89  	for i := range s.certificates {
    90  		if len(data) < 4 {
    91  			return false
    92  		}
    93  		certLen := int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3])
    94  		data = data[4:]
    95  		if certLen < 0 {
    96  			return false
    97  		}
    98  		if len(data) < certLen {
    99  			return false
   100  		}
   101  		s.certificates[i] = data[:certLen]
   102  		data = data[certLen:]
   103  	}
   104  
   105  	return len(data) == 0
   106  }
   107  
// sessionStateTLS13 is the content of a TLS 1.3 session ticket. Its first
// version (revision = 0) doesn't carry any of the information needed for 0-RTT
// validation and the nonce is always empty.
//
// marshal writes, in order: a uint16 protocol version (VersionTLS13), a uint8
// revision (0), cipherSuite, createdAt, resumptionSecret with an 8-bit length
// prefix, and certificate. unmarshal rejects any other version or revision,
// an empty resumptionSecret, and trailing bytes.
//
// NOTE(review): createdAt is assigned outside this file and is presumably a
// Unix time in seconds — confirm against callers.
type sessionStateTLS13 struct {
	// uint16 version = 0x0304;
	// uint8 revision = 0;
	cipherSuite      uint16
	createdAt        uint64
	resumptionSecret []byte      // opaque resumption_master_secret<1..2^8-1>;
	certificate      Certificate // CertificateEntry certificate_list<0..2^24-1>;
}
   119  
   120  func (m *sessionStateTLS13) marshal() []byte {
   121  	var b cryptobyte.Builder
   122  	b.AddUint16(VersionTLS13)
   123  	b.AddUint8(0) // revision
   124  	b.AddUint16(m.cipherSuite)
   125  	addUint64(&b, m.createdAt)
   126  	b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   127  		b.AddBytes(m.resumptionSecret)
   128  	})
   129  	marshalCertificate(&b, m.certificate)
   130  	return b.BytesOrPanic()
   131  }
   132  
   133  func (m *sessionStateTLS13) unmarshal(data []byte) bool {
   134  	*m = sessionStateTLS13{}
   135  	s := cryptobyte.String(data)
   136  	var version uint16
   137  	var revision uint8
   138  	return s.ReadUint16(&version) &&
   139  		version == VersionTLS13 &&
   140  		s.ReadUint8(&revision) &&
   141  		revision == 0 &&
   142  		s.ReadUint16(&m.cipherSuite) &&
   143  		readUint64(&s, &m.createdAt) &&
   144  		readUint8LengthPrefixed(&s, &m.resumptionSecret) &&
   145  		len(m.resumptionSecret) != 0 &&
   146  		unmarshalCertificate(&s, &m.certificate) &&
   147  		s.Empty()
   148  }
   149  
   150  func (c *Conn) encryptTicket(state []byte) ([]byte, error) {
   151  	encrypted := make([]byte, ticketKeyNameLen+aes.BlockSize+len(state)+sha256.Size)
   152  	keyName := encrypted[:ticketKeyNameLen]
   153  	iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize]
   154  	macBytes := encrypted[len(encrypted)-sha256.Size:]
   155  
   156  	if _, err := io.ReadFull(c.config.rand(), iv); err != nil {
   157  		return nil, err
   158  	}
   159  	key := c.config.ticketKeys()[0]
   160  	copy(keyName, key.keyName[:])
   161  	block, err := aes.NewCipher(key.aesKey[:])
   162  	if err != nil {
   163  		return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error())
   164  	}
   165  	cipher.NewCTR(block, iv).XORKeyStream(encrypted[ticketKeyNameLen+aes.BlockSize:], state)
   166  
   167  	mac := hmac.New(sha256.New, key.hmacKey[:])
   168  	mac.Write(encrypted[:len(encrypted)-sha256.Size])
   169  	mac.Sum(macBytes[:0])
   170  
   171  	return encrypted, nil
   172  }
   173  
   174  func (c *Conn) decryptTicket(encrypted []byte) (plaintext []byte, usedOldKey bool) {
   175  	if len(encrypted) < ticketKeyNameLen+aes.BlockSize+sha256.Size {
   176  		return nil, false
   177  	}
   178  
   179  	keyName := encrypted[:ticketKeyNameLen]
   180  	iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize]
   181  	macBytes := encrypted[len(encrypted)-sha256.Size:]
   182  	ciphertext := encrypted[ticketKeyNameLen+aes.BlockSize : len(encrypted)-sha256.Size]
   183  
   184  	keys := c.config.ticketKeys()
   185  	keyIndex := -1
   186  	for i, candidateKey := range keys {
   187  		if bytes.Equal(keyName, candidateKey.keyName[:]) {
   188  			keyIndex = i
   189  			break
   190  		}
   191  	}
   192  
   193  	if keyIndex == -1 {
   194  		return nil, false
   195  	}
   196  	key := &keys[keyIndex]
   197  
   198  	mac := hmac.New(sha256.New, key.hmacKey[:])
   199  	mac.Write(encrypted[:len(encrypted)-sha256.Size])
   200  	expected := mac.Sum(nil)
   201  
   202  	if subtle.ConstantTimeCompare(macBytes, expected) != 1 {
   203  		return nil, false
   204  	}
   205  
   206  	block, err := aes.NewCipher(key.aesKey[:])
   207  	if err != nil {
   208  		return nil, false
   209  	}
   210  	plaintext = make([]byte, len(ciphertext))
   211  	cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext)
   212  
   213  	return plaintext, keyIndex > 0
   214  }