github.com/icodeface/tls@v0.0.0-20230910023335-34df9250cd12/ticket.go (about) 1 // Copyright 2012 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package tls 6 7 import ( 8 "bytes" 9 "crypto/aes" 10 "crypto/cipher" 11 "crypto/hmac" 12 "crypto/sha256" 13 "crypto/subtle" 14 "errors" 15 "github.com/icodeface/tls/internal/x/crypto/cryptobyte" 16 "io" 17 ) 18 19 // sessionState contains the information that is serialized into a session 20 // ticket in order to later resume a connection. 21 type sessionState struct { 22 vers uint16 23 cipherSuite uint16 24 masterSecret []byte 25 certificates [][]byte 26 // usedOldKey is true if the ticket from which this session came from 27 // was encrypted with an older key and thus should be refreshed. 28 usedOldKey bool 29 } 30 31 func (s *sessionState) marshal() []byte { 32 length := 2 + 2 + 2 + len(s.masterSecret) + 2 33 for _, cert := range s.certificates { 34 length += 4 + len(cert) 35 } 36 37 ret := make([]byte, length) 38 x := ret 39 x[0] = byte(s.vers >> 8) 40 x[1] = byte(s.vers) 41 x[2] = byte(s.cipherSuite >> 8) 42 x[3] = byte(s.cipherSuite) 43 x[4] = byte(len(s.masterSecret) >> 8) 44 x[5] = byte(len(s.masterSecret)) 45 x = x[6:] 46 copy(x, s.masterSecret) 47 x = x[len(s.masterSecret):] 48 49 x[0] = byte(len(s.certificates) >> 8) 50 x[1] = byte(len(s.certificates)) 51 x = x[2:] 52 53 for _, cert := range s.certificates { 54 x[0] = byte(len(cert) >> 24) 55 x[1] = byte(len(cert) >> 16) 56 x[2] = byte(len(cert) >> 8) 57 x[3] = byte(len(cert)) 58 copy(x[4:], cert) 59 x = x[4+len(cert):] 60 } 61 62 return ret 63 } 64 65 func (s *sessionState) unmarshal(data []byte) bool { 66 if len(data) < 8 { 67 return false 68 } 69 70 s.vers = uint16(data[0])<<8 | uint16(data[1]) 71 s.cipherSuite = uint16(data[2])<<8 | uint16(data[3]) 72 masterSecretLen := int(data[4])<<8 | int(data[5]) 73 data = data[6:] 74 if len(data) < masterSecretLen { 75 return false 76 } 77 78 s.masterSecret = data[:masterSecretLen] 79 data = data[masterSecretLen:] 80 81 if len(data) < 2 { 82 return false 83 } 84 85 numCerts := int(data[0])<<8 | int(data[1]) 86 data = data[2:] 87 88 s.certificates = make([][]byte, numCerts) 89 for i := range s.certificates { 90 if len(data) < 4 { 91 return false 92 } 93 certLen := int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3]) 94 data = data[4:] 95 if certLen < 0 { 96 return false 97 } 98 if len(data) < certLen { 99 return false 100 } 101 s.certificates[i] = data[:certLen] 102 data = data[certLen:] 103 } 104 105 return len(data) == 0 106 } 107 108 // sessionStateTLS13 is the content of a TLS 1.3 session ticket. Its first 109 // version (revision = 0) doesn't carry any of the information needed for 0-RTT 110 // validation and the nonce is always empty. 111 type sessionStateTLS13 struct { 112 // uint8 version = 0x0304; 113 // uint8 revision = 0; 114 cipherSuite uint16 115 createdAt uint64 116 resumptionSecret []byte // opaque resumption_master_secret<1..2^8-1>; 117 certificate Certificate // CertificateEntry certificate_list<0..2^24-1>; 118 } 119 120 func (m *sessionStateTLS13) marshal() []byte { 121 var b cryptobyte.Builder 122 b.AddUint16(VersionTLS13) 123 b.AddUint8(0) // revision 124 b.AddUint16(m.cipherSuite) 125 addUint64(&b, m.createdAt) 126 b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { 127 b.AddBytes(m.resumptionSecret) 128 }) 129 marshalCertificate(&b, m.certificate) 130 return b.BytesOrPanic() 131 } 132 133 func (m *sessionStateTLS13) unmarshal(data []byte) bool { 134 *m = sessionStateTLS13{} 135 s := cryptobyte.String(data) 136 var version uint16 137 var revision uint8 138 return s.ReadUint16(&version) && 139 version == VersionTLS13 && 140 s.ReadUint8(&revision) && 141 revision == 0 && 142 s.ReadUint16(&m.cipherSuite) && 143 readUint64(&s, &m.createdAt) && 144 readUint8LengthPrefixed(&s, &m.resumptionSecret) && 145 len(m.resumptionSecret) != 0 && 146 unmarshalCertificate(&s, &m.certificate) && 147 s.Empty() 148 } 149 150 func (c *Conn) encryptTicket(state []byte) ([]byte, error) { 151 encrypted := make([]byte, ticketKeyNameLen+aes.BlockSize+len(state)+sha256.Size) 152 keyName := encrypted[:ticketKeyNameLen] 153 iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize] 154 macBytes := encrypted[len(encrypted)-sha256.Size:] 155 156 if _, err := io.ReadFull(c.config.rand(), iv); err != nil { 157 return nil, err 158 } 159 key := c.config.ticketKeys()[0] 160 copy(keyName, key.keyName[:]) 161 block, err := aes.NewCipher(key.aesKey[:]) 162 if err != nil { 163 return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error()) 164 } 165 cipher.NewCTR(block, iv).XORKeyStream(encrypted[ticketKeyNameLen+aes.BlockSize:], state) 166 167 mac := hmac.New(sha256.New, key.hmacKey[:]) 168 mac.Write(encrypted[:len(encrypted)-sha256.Size]) 169 mac.Sum(macBytes[:0]) 170 171 return encrypted, nil 172 } 173 174 func (c *Conn) decryptTicket(encrypted []byte) (plaintext []byte, usedOldKey bool) { 175 if len(encrypted) < ticketKeyNameLen+aes.BlockSize+sha256.Size { 176 return nil, false 177 } 178 179 keyName := encrypted[:ticketKeyNameLen] 180 iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize] 181 macBytes := encrypted[len(encrypted)-sha256.Size:] 182 ciphertext := encrypted[ticketKeyNameLen+aes.BlockSize : len(encrypted)-sha256.Size] 183 184 keys := c.config.ticketKeys() 185 keyIndex := -1 186 for i, candidateKey := range keys { 187 if bytes.Equal(keyName, candidateKey.keyName[:]) { 188 keyIndex = i 189 break 190 } 191 } 192 193 if keyIndex == -1 { 194 return nil, false 195 } 196 key := &keys[keyIndex] 197 198 mac := hmac.New(sha256.New, key.hmacKey[:]) 199 mac.Write(encrypted[:len(encrypted)-sha256.Size]) 200 expected := mac.Sum(nil) 201 202 if subtle.ConstantTimeCompare(macBytes, expected) != 1 { 203 return nil, false 204 } 205 206 block, err := aes.NewCipher(key.aesKey[:]) 207 if err != nil { 208 return nil, false 209 } 210 plaintext = make([]byte, len(ciphertext)) 211 cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext) 212 213 return plaintext, keyIndex > 0 214 }