github.com/Hyperledger-TWGC/tjfoc-gm@v1.4.0/gmtls/ticket.go (about) 1 /* 2 Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved. 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 */ 15 16 package gmtls 17 18 import ( 19 "bytes" 20 "crypto/aes" 21 "crypto/cipher" 22 "crypto/hmac" 23 "crypto/sha256" 24 "crypto/subtle" 25 "errors" 26 "io" 27 ) 28 29 // sessionState contains the information that is serialized into a session 30 // ticket in order to later resume a connection. 31 type sessionState struct { 32 vers uint16 33 cipherSuite uint16 34 masterSecret []byte 35 certificates [][]byte 36 // usedOldKey is true if the ticket from which this session came from 37 // was encrypted with an older key and thus should be refreshed. 38 usedOldKey bool 39 } 40 41 func (s *sessionState) equal(i interface{}) bool { 42 s1, ok := i.(*sessionState) 43 if !ok { 44 return false 45 } 46 47 if s.vers != s1.vers || 48 s.cipherSuite != s1.cipherSuite || 49 !bytes.Equal(s.masterSecret, s1.masterSecret) { 50 return false 51 } 52 53 if len(s.certificates) != len(s1.certificates) { 54 return false 55 } 56 57 for i := range s.certificates { 58 if !bytes.Equal(s.certificates[i], s1.certificates[i]) { 59 return false 60 } 61 } 62 63 return true 64 } 65 66 func (s *sessionState) marshal() []byte { 67 length := 2 + 2 + 2 + len(s.masterSecret) + 2 68 for _, cert := range s.certificates { 69 length += 4 + len(cert) 70 } 71 72 ret := make([]byte, length) 73 x := ret 74 x[0] = byte(s.vers >> 8) 75 x[1] = byte(s.vers) 76 x[2] = byte(s.cipherSuite >> 8) 77 x[3] = byte(s.cipherSuite) 78 x[4] = byte(len(s.masterSecret) >> 8) 79 x[5] = byte(len(s.masterSecret)) 80 x = x[6:] 81 copy(x, s.masterSecret) 82 x = x[len(s.masterSecret):] 83 84 x[0] = byte(len(s.certificates) >> 8) 85 x[1] = byte(len(s.certificates)) 86 x = x[2:] 87 88 for _, cert := range s.certificates { 89 x[0] = byte(len(cert) >> 24) 90 x[1] = byte(len(cert) >> 16) 91 x[2] = byte(len(cert) >> 8) 92 x[3] = byte(len(cert)) 93 copy(x[4:], cert) 94 x = x[4+len(cert):] 95 } 96 97 return ret 98 } 99 100 func (s *sessionState) unmarshal(data []byte) bool { 101 if len(data) < 8 { 102 return false 103 } 104 105 s.vers = uint16(data[0])<<8 | uint16(data[1]) 106 s.cipherSuite = uint16(data[2])<<8 | uint16(data[3]) 107 masterSecretLen := int(data[4])<<8 | int(data[5]) 108 data = data[6:] 109 if len(data) < masterSecretLen { 110 return false 111 } 112 113 s.masterSecret = data[:masterSecretLen] 114 data = data[masterSecretLen:] 115 116 if len(data) < 2 { 117 return false 118 } 119 120 numCerts := int(data[0])<<8 | int(data[1]) 121 data = data[2:] 122 123 s.certificates = make([][]byte, numCerts) 124 for i := range s.certificates { 125 if len(data) < 4 { 126 return false 127 } 128 certLen := int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3]) 129 data = data[4:] 130 if certLen < 0 { 131 return false 132 } 133 if len(data) < certLen { 134 return false 135 } 136 s.certificates[i] = data[:certLen] 137 data = data[certLen:] 138 } 139 140 return len(data) == 0 141 } 142 143 func (c *Conn) encryptTicket(state *sessionState) ([]byte, error) { 144 serialized := state.marshal() 145 encrypted := make([]byte, ticketKeyNameLen+aes.BlockSize+len(serialized)+sha256.Size) 146 keyName := encrypted[:ticketKeyNameLen] 147 iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize] 148 macBytes := encrypted[len(encrypted)-sha256.Size:] 149 150 if _, err := io.ReadFull(c.config.rand(), iv); err != nil { 151 return nil, err 152 } 153 key := c.config.ticketKeys()[0] 154 copy(keyName, key.keyName[:]) 155 block, err := aes.NewCipher(key.aesKey[:]) 156 if err != nil { 157 return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error()) 158 } 159 cipher.NewCTR(block, iv).XORKeyStream(encrypted[ticketKeyNameLen+aes.BlockSize:], serialized) 160 161 mac := hmac.New(sha256.New, key.hmacKey[:]) 162 mac.Write(encrypted[:len(encrypted)-sha256.Size]) 163 mac.Sum(macBytes[:0]) 164 165 return encrypted, nil 166 } 167 168 func (c *Conn) decryptTicket(encrypted []byte) (*sessionState, bool) { 169 if c.config.SessionTicketsDisabled || 170 len(encrypted) < ticketKeyNameLen+aes.BlockSize+sha256.Size { 171 return nil, false 172 } 173 174 keyName := encrypted[:ticketKeyNameLen] 175 iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize] 176 macBytes := encrypted[len(encrypted)-sha256.Size:] 177 178 keys := c.config.ticketKeys() 179 keyIndex := -1 180 for i, candidateKey := range keys { 181 if bytes.Equal(keyName, candidateKey.keyName[:]) { 182 keyIndex = i 183 break 184 } 185 } 186 187 if keyIndex == -1 { 188 return nil, false 189 } 190 key := &keys[keyIndex] 191 192 mac := hmac.New(sha256.New, key.hmacKey[:]) 193 mac.Write(encrypted[:len(encrypted)-sha256.Size]) 194 expected := mac.Sum(nil) 195 196 if subtle.ConstantTimeCompare(macBytes, expected) != 1 { 197 return nil, false 198 } 199 200 block, err := aes.NewCipher(key.aesKey[:]) 201 if err != nil { 202 return nil, false 203 } 204 ciphertext := encrypted[ticketKeyNameLen+aes.BlockSize : len(encrypted)-sha256.Size] 205 plaintext := ciphertext 206 cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext) 207 208 state := &sessionState{usedOldKey: keyIndex > 0} 209 ok := state.unmarshal(plaintext) 210 return state, ok 211 }