github.com/zmap/zcrypto@v0.0.0-20240512203510-0fef58d9a9db/tls/ticket.go (about)

     1  // Copyright 2012 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/aes"
    10  	"crypto/cipher"
    11  	"crypto/hmac"
    12  	"crypto/sha256"
    13  	"crypto/subtle"
    14  	"errors"
    15  	"io"
    16  )
    17  
    18  // sessionState contains the information that is serialized into a session
    19  // ticket in order to later resume a connection.
    20  type sessionState struct {
    21  	vers                 uint16
    22  	cipherSuite          uint16
    23  	masterSecret         []byte
    24  	certificates         [][]byte
    25  	extendedMasterSecret bool
    26  }
    27  
    28  func (s *sessionState) equal(i interface{}) bool {
    29  	s1, ok := i.(*sessionState)
    30  	if !ok {
    31  		return false
    32  	}
    33  
    34  	if s.vers != s1.vers ||
    35  		s.cipherSuite != s1.cipherSuite ||
    36  		!bytes.Equal(s.masterSecret, s1.masterSecret) ||
    37  		s.extendedMasterSecret != s1.extendedMasterSecret {
    38  		return false
    39  	}
    40  
    41  	if len(s.certificates) != len(s1.certificates) {
    42  		return false
    43  	}
    44  
    45  	for i := range s.certificates {
    46  		if !bytes.Equal(s.certificates[i], s1.certificates[i]) {
    47  			return false
    48  		}
    49  	}
    50  
    51  	return true
    52  }
    53  
    54  func (s *sessionState) marshal() []byte {
    55  	length := 2 + 2 + 2 + len(s.masterSecret) + 2
    56  	for _, cert := range s.certificates {
    57  		length += 4 + len(cert)
    58  	}
    59  
    60  	ret := make([]byte, length)
    61  	x := ret
    62  	x[0] = byte(s.vers >> 8)
    63  	x[1] = byte(s.vers)
    64  	x[2] = byte(s.cipherSuite >> 8)
    65  	x[3] = byte(s.cipherSuite)
    66  	x[4] = byte(len(s.masterSecret) >> 8)
    67  	x[5] = byte(len(s.masterSecret))
    68  	x = x[6:]
    69  	copy(x, s.masterSecret)
    70  	x = x[len(s.masterSecret):]
    71  
    72  	x[0] = byte(len(s.certificates) >> 8)
    73  	x[1] = byte(len(s.certificates))
    74  	x = x[2:]
    75  
    76  	for _, cert := range s.certificates {
    77  		x[0] = byte(len(cert) >> 24)
    78  		x[1] = byte(len(cert) >> 16)
    79  		x[2] = byte(len(cert) >> 8)
    80  		x[3] = byte(len(cert))
    81  		copy(x[4:], cert)
    82  		x = x[4+len(cert):]
    83  	}
    84  
    85  	if s.extendedMasterSecret {
    86  		x[0] = 1
    87  	}
    88  
    89  	return ret
    90  }
    91  
    92  func (s *sessionState) unmarshal(data []byte) bool {
    93  	if len(data) < 8 {
    94  		return false
    95  	}
    96  
    97  	s.vers = uint16(data[0])<<8 | uint16(data[1])
    98  	s.cipherSuite = uint16(data[2])<<8 | uint16(data[3])
    99  	masterSecretLen := int(data[4])<<8 | int(data[5])
   100  	data = data[6:]
   101  	if len(data) < masterSecretLen {
   102  		return false
   103  	}
   104  
   105  	s.masterSecret = data[:masterSecretLen]
   106  	data = data[masterSecretLen:]
   107  
   108  	if len(data) < 2 {
   109  		return false
   110  	}
   111  
   112  	numCerts := int(data[0])<<8 | int(data[1])
   113  	data = data[2:]
   114  
   115  	s.certificates = make([][]byte, numCerts)
   116  	for i := range s.certificates {
   117  		if len(data) < 4 {
   118  			return false
   119  		}
   120  		certLen := int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3])
   121  		data = data[4:]
   122  		if certLen < 0 {
   123  			return false
   124  		}
   125  		if len(data) < certLen {
   126  			return false
   127  		}
   128  		s.certificates[i] = data[:certLen]
   129  		data = data[certLen:]
   130  	}
   131  
   132  	if len(data) > 0 {
   133  		return false
   134  	}
   135  
   136  	return true
   137  }
   138  
   139  func (c *Conn) encryptTicket(state *sessionState) ([]byte, error) {
   140  	serialized := state.marshal()
   141  	encrypted := make([]byte, aes.BlockSize+len(serialized)+sha256.Size)
   142  	iv := encrypted[:aes.BlockSize]
   143  	macBytes := encrypted[len(encrypted)-sha256.Size:]
   144  
   145  	if _, err := io.ReadFull(c.config.rand(), iv); err != nil {
   146  		return nil, err
   147  	}
   148  	block, err := aes.NewCipher(c.config.SessionTicketKey[:16])
   149  	if err != nil {
   150  		return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error())
   151  	}
   152  	cipher.NewCTR(block, iv).XORKeyStream(encrypted[aes.BlockSize:], serialized)
   153  
   154  	mac := hmac.New(sha256.New, c.config.SessionTicketKey[16:32])
   155  	mac.Write(encrypted[:len(encrypted)-sha256.Size])
   156  	mac.Sum(macBytes[:0])
   157  
   158  	return encrypted, nil
   159  }
   160  
   161  func (c *Conn) decryptTicket(encrypted []byte) (*sessionState, bool) {
   162  	if c.config.SessionTicketsDisabled ||
   163  		len(encrypted) < aes.BlockSize+sha256.Size {
   164  		return nil, false
   165  	}
   166  
   167  	iv := encrypted[:aes.BlockSize]
   168  	macBytes := encrypted[len(encrypted)-sha256.Size:]
   169  
   170  	mac := hmac.New(sha256.New, c.config.SessionTicketKey[16:32])
   171  	mac.Write(encrypted[:len(encrypted)-sha256.Size])
   172  	expected := mac.Sum(nil)
   173  
   174  	if subtle.ConstantTimeCompare(macBytes, expected) != 1 {
   175  		return nil, false
   176  	}
   177  
   178  	block, err := aes.NewCipher(c.config.SessionTicketKey[:16])
   179  	if err != nil {
   180  		return nil, false
   181  	}
   182  	ciphertext := encrypted[aes.BlockSize : len(encrypted)-sha256.Size]
   183  	plaintext := ciphertext
   184  	cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext)
   185  
   186  	state := new(sessionState)
   187  	ok := state.unmarshal(plaintext)
   188  	return state, ok
   189  }