github.com/Hyperledger-TWGC/tjfoc-gm@v1.4.0/gmtls/ticket.go (about)

     1  /*
     2  Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved.
     3  Licensed under the Apache License, Version 2.0 (the "License");
     4  you may not use this file except in compliance with the License.
     5  You may obtain a copy of the License at
     6  
     7  	http://www.apache.org/licenses/LICENSE-2.0
     8  
     9  Unless required by applicable law or agreed to in writing, software
    10  distributed under the License is distributed on an "AS IS" BASIS,
    11  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  See the License for the specific language governing permissions and
    13  limitations under the License.
    14  */
    15  
    16  package gmtls
    17  
    18  import (
    19  	"bytes"
    20  	"crypto/aes"
    21  	"crypto/cipher"
    22  	"crypto/hmac"
    23  	"crypto/sha256"
    24  	"crypto/subtle"
    25  	"errors"
    26  	"io"
    27  )
    28  
    29  // sessionState contains the information that is serialized into a session
    30  // ticket in order to later resume a connection.
    31  type sessionState struct {
    32  	vers         uint16
    33  	cipherSuite  uint16
    34  	masterSecret []byte
    35  	certificates [][]byte
    36  	// usedOldKey is true if the ticket from which this session came from
    37  	// was encrypted with an older key and thus should be refreshed.
    38  	usedOldKey bool
    39  }
    40  
    41  func (s *sessionState) equal(i interface{}) bool {
    42  	s1, ok := i.(*sessionState)
    43  	if !ok {
    44  		return false
    45  	}
    46  
    47  	if s.vers != s1.vers ||
    48  		s.cipherSuite != s1.cipherSuite ||
    49  		!bytes.Equal(s.masterSecret, s1.masterSecret) {
    50  		return false
    51  	}
    52  
    53  	if len(s.certificates) != len(s1.certificates) {
    54  		return false
    55  	}
    56  
    57  	for i := range s.certificates {
    58  		if !bytes.Equal(s.certificates[i], s1.certificates[i]) {
    59  			return false
    60  		}
    61  	}
    62  
    63  	return true
    64  }
    65  
    66  func (s *sessionState) marshal() []byte {
    67  	length := 2 + 2 + 2 + len(s.masterSecret) + 2
    68  	for _, cert := range s.certificates {
    69  		length += 4 + len(cert)
    70  	}
    71  
    72  	ret := make([]byte, length)
    73  	x := ret
    74  	x[0] = byte(s.vers >> 8)
    75  	x[1] = byte(s.vers)
    76  	x[2] = byte(s.cipherSuite >> 8)
    77  	x[3] = byte(s.cipherSuite)
    78  	x[4] = byte(len(s.masterSecret) >> 8)
    79  	x[5] = byte(len(s.masterSecret))
    80  	x = x[6:]
    81  	copy(x, s.masterSecret)
    82  	x = x[len(s.masterSecret):]
    83  
    84  	x[0] = byte(len(s.certificates) >> 8)
    85  	x[1] = byte(len(s.certificates))
    86  	x = x[2:]
    87  
    88  	for _, cert := range s.certificates {
    89  		x[0] = byte(len(cert) >> 24)
    90  		x[1] = byte(len(cert) >> 16)
    91  		x[2] = byte(len(cert) >> 8)
    92  		x[3] = byte(len(cert))
    93  		copy(x[4:], cert)
    94  		x = x[4+len(cert):]
    95  	}
    96  
    97  	return ret
    98  }
    99  
   100  func (s *sessionState) unmarshal(data []byte) bool {
   101  	if len(data) < 8 {
   102  		return false
   103  	}
   104  
   105  	s.vers = uint16(data[0])<<8 | uint16(data[1])
   106  	s.cipherSuite = uint16(data[2])<<8 | uint16(data[3])
   107  	masterSecretLen := int(data[4])<<8 | int(data[5])
   108  	data = data[6:]
   109  	if len(data) < masterSecretLen {
   110  		return false
   111  	}
   112  
   113  	s.masterSecret = data[:masterSecretLen]
   114  	data = data[masterSecretLen:]
   115  
   116  	if len(data) < 2 {
   117  		return false
   118  	}
   119  
   120  	numCerts := int(data[0])<<8 | int(data[1])
   121  	data = data[2:]
   122  
   123  	s.certificates = make([][]byte, numCerts)
   124  	for i := range s.certificates {
   125  		if len(data) < 4 {
   126  			return false
   127  		}
   128  		certLen := int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3])
   129  		data = data[4:]
   130  		if certLen < 0 {
   131  			return false
   132  		}
   133  		if len(data) < certLen {
   134  			return false
   135  		}
   136  		s.certificates[i] = data[:certLen]
   137  		data = data[certLen:]
   138  	}
   139  
   140  	return len(data) == 0
   141  }
   142  
   143  func (c *Conn) encryptTicket(state *sessionState) ([]byte, error) {
   144  	serialized := state.marshal()
   145  	encrypted := make([]byte, ticketKeyNameLen+aes.BlockSize+len(serialized)+sha256.Size)
   146  	keyName := encrypted[:ticketKeyNameLen]
   147  	iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize]
   148  	macBytes := encrypted[len(encrypted)-sha256.Size:]
   149  
   150  	if _, err := io.ReadFull(c.config.rand(), iv); err != nil {
   151  		return nil, err
   152  	}
   153  	key := c.config.ticketKeys()[0]
   154  	copy(keyName, key.keyName[:])
   155  	block, err := aes.NewCipher(key.aesKey[:])
   156  	if err != nil {
   157  		return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error())
   158  	}
   159  	cipher.NewCTR(block, iv).XORKeyStream(encrypted[ticketKeyNameLen+aes.BlockSize:], serialized)
   160  
   161  	mac := hmac.New(sha256.New, key.hmacKey[:])
   162  	mac.Write(encrypted[:len(encrypted)-sha256.Size])
   163  	mac.Sum(macBytes[:0])
   164  
   165  	return encrypted, nil
   166  }
   167  
   168  func (c *Conn) decryptTicket(encrypted []byte) (*sessionState, bool) {
   169  	if c.config.SessionTicketsDisabled ||
   170  		len(encrypted) < ticketKeyNameLen+aes.BlockSize+sha256.Size {
   171  		return nil, false
   172  	}
   173  
   174  	keyName := encrypted[:ticketKeyNameLen]
   175  	iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize]
   176  	macBytes := encrypted[len(encrypted)-sha256.Size:]
   177  
   178  	keys := c.config.ticketKeys()
   179  	keyIndex := -1
   180  	for i, candidateKey := range keys {
   181  		if bytes.Equal(keyName, candidateKey.keyName[:]) {
   182  			keyIndex = i
   183  			break
   184  		}
   185  	}
   186  
   187  	if keyIndex == -1 {
   188  		return nil, false
   189  	}
   190  	key := &keys[keyIndex]
   191  
   192  	mac := hmac.New(sha256.New, key.hmacKey[:])
   193  	mac.Write(encrypted[:len(encrypted)-sha256.Size])
   194  	expected := mac.Sum(nil)
   195  
   196  	if subtle.ConstantTimeCompare(macBytes, expected) != 1 {
   197  		return nil, false
   198  	}
   199  
   200  	block, err := aes.NewCipher(key.aesKey[:])
   201  	if err != nil {
   202  		return nil, false
   203  	}
   204  	ciphertext := encrypted[ticketKeyNameLen+aes.BlockSize : len(encrypted)-sha256.Size]
   205  	plaintext := ciphertext
   206  	cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext)
   207  
   208  	state := &sessionState{usedOldKey: keyIndex > 0}
   209  	ok := state.unmarshal(plaintext)
   210  	return state, ok
   211  }