github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/gmtls/ticket.go (about)

     1  // Copyright 2022 s1ren@github.com/hxx258456.
     2  
     3  /*
     4  gmtls是基于`golang/go`的`tls`包实现的国密改造版本。
     5  对应版权声明: thrid_licenses/github.com/golang/go/LICENSE
     6  */
     7  
     8  package gmtls
     9  
    10  import (
    11  	"bytes"
    12  	"crypto/cipher"
    13  	"crypto/hmac"
    14  	"crypto/subtle"
    15  	"errors"
    16  	"io"
    17  
    18  	"github.com/hxx258456/ccgo/sm3"
    19  	"github.com/hxx258456/ccgo/sm4"
    20  	"golang.org/x/crypto/cryptobyte"
    21  )
    22  
    23  // sessionState contains the information that is serialized into a session
    24  // ticket in order to later resume a connection.
    25  type sessionState struct {
    26  	vers         uint16
    27  	cipherSuite  uint16
    28  	createdAt    uint64
    29  	masterSecret []byte // opaque master_secret<1..2^16-1>;
    30  	// struct { opaque certificate<1..2^24-1> } Certificate;
    31  	certificates [][]byte // Certificate certificate_list<0..2^24-1>;
    32  
    33  	// usedOldKey is true if the ticket from which this session came from
    34  	// was encrypted with an older key and thus should be refreshed.
    35  	usedOldKey bool
    36  }
    37  
    38  func (m *sessionState) marshal() []byte {
    39  	var b cryptobyte.Builder
    40  	b.AddUint16(m.vers)
    41  	b.AddUint16(m.cipherSuite)
    42  	addUint64(&b, m.createdAt)
    43  	b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
    44  		b.AddBytes(m.masterSecret)
    45  	})
    46  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
    47  		for _, cert := range m.certificates {
    48  			b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
    49  				b.AddBytes(cert)
    50  			})
    51  		}
    52  	})
    53  	return b.BytesOrPanic()
    54  }
    55  
    56  func (m *sessionState) unmarshal(data []byte) bool {
    57  	*m = sessionState{usedOldKey: m.usedOldKey}
    58  	s := cryptobyte.String(data)
    59  	if ok := s.ReadUint16(&m.vers) &&
    60  		s.ReadUint16(&m.cipherSuite) &&
    61  		readUint64(&s, &m.createdAt) &&
    62  		readUint16LengthPrefixed(&s, &m.masterSecret) &&
    63  		len(m.masterSecret) != 0; !ok {
    64  		return false
    65  	}
    66  	var certList cryptobyte.String
    67  	if !s.ReadUint24LengthPrefixed(&certList) {
    68  		return false
    69  	}
    70  	for !certList.Empty() {
    71  		var cert []byte
    72  		if !readUint24LengthPrefixed(&certList, &cert) {
    73  			return false
    74  		}
    75  		m.certificates = append(m.certificates, cert)
    76  	}
    77  	return s.Empty()
    78  }
    79  
    80  // sessionStateTLS13 is the content of a TLS 1.3 session ticket. Its first
    81  // version (revision = 0) doesn't carry any of the information needed for 0-RTT
    82  // validation and the nonce is always empty.
    83  type sessionStateTLS13 struct {
    84  	// uint8 version  = 0x0304;
    85  	// uint8 revision = 0;
    86  	cipherSuite      uint16
    87  	createdAt        uint64
    88  	resumptionSecret []byte      // opaque resumption_master_secret<1..2^8-1>;
    89  	certificate      Certificate // CertificateEntry certificate_list<0..2^24-1>;
    90  }
    91  
    92  func (m *sessionStateTLS13) marshal() []byte {
    93  	var b cryptobyte.Builder
    94  	b.AddUint16(VersionTLS13)
    95  	b.AddUint8(0) // revision
    96  	b.AddUint16(m.cipherSuite)
    97  	addUint64(&b, m.createdAt)
    98  	b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
    99  		b.AddBytes(m.resumptionSecret)
   100  	})
   101  	marshalCertificate(&b, m.certificate)
   102  	return b.BytesOrPanic()
   103  }
   104  
   105  func (m *sessionStateTLS13) unmarshal(data []byte) bool {
   106  	*m = sessionStateTLS13{}
   107  	s := cryptobyte.String(data)
   108  	var version uint16
   109  	var revision uint8
   110  	return s.ReadUint16(&version) &&
   111  		version == VersionTLS13 &&
   112  		s.ReadUint8(&revision) &&
   113  		revision == 0 &&
   114  		s.ReadUint16(&m.cipherSuite) &&
   115  		readUint64(&s, &m.createdAt) &&
   116  		readUint8LengthPrefixed(&s, &m.resumptionSecret) &&
   117  		len(m.resumptionSecret) != 0 &&
   118  		unmarshalCertificate(&s, &m.certificate) &&
   119  		s.Empty()
   120  }
   121  
   122  // 会话票据加密
   123  //  golang原来的实现是用aes加密,sha256散列,国密改造改为 sm4 + sm3
   124  func (c *Conn) encryptTicket(state []byte) ([]byte, error) {
   125  	if len(c.ticketKeys) == 0 {
   126  		return nil, errors.New("gmtls: internal error: session ticket keys unavailable")
   127  	}
   128  	// encrypted : ticketKeyName(16) + iv(16) + state对称加密结果 + 散列(32)
   129  	// encrypted := make([]byte, ticketKeyNameLen+aes.BlockSize+len(state)+sha256.Size)
   130  	encrypted := make([]byte, ticketKeyNameLen+sm4.BlockSize+len(state)+sm3.Size)
   131  	// 前16个字节放ticketKeyName
   132  	keyName := encrypted[:ticketKeyNameLen]
   133  	// 16~32 放iv
   134  	iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+sm4.BlockSize]
   135  	// 最后32个字节放mac认证码
   136  	macBytes := encrypted[len(encrypted)-sm3.Size:]
   137  	// 生成随机字节数组填入iv
   138  	if _, err := io.ReadFull(c.config.rand(), iv); err != nil {
   139  		return nil, err
   140  	}
   141  	// 当前连接的ticketKeys在前面读取ClientHello之后的处理中已经初始化。
   142  	// 这里拿到第一个ticketKey。
   143  	key := c.ticketKeys[0]
   144  	// 填入keyname
   145  	copy(keyName, key.keyName[:])
   146  	block, err := sm4.NewCipher(key.sm4Key[:])
   147  	if err != nil {
   148  		return nil, errors.New("gmtls: failed to create cipher while encrypting ticket: " + err.Error())
   149  	}
   150  	// encrypted的 32 ~ 倒数32 填入state对称加密结果
   151  	cipher.NewCTR(block, iv).XORKeyStream(encrypted[ticketKeyNameLen+sm4.BlockSize:], state)
   152  	// 使用sm3作为mac认证码函数
   153  	mac := hmac.New(sm3.New, key.hmacKey[:])
   154  	// 写入 encrypted 前三部分内容: ticketKeyName(16) + iv(16) + state对称加密结果
   155  	mac.Write(encrypted[:len(encrypted)-sm3.Size])
   156  	// 生成认证码填入macBytes
   157  	mac.Sum(macBytes[:0])
   158  
   159  	return encrypted, nil
   160  }
   161  
   162  // 会话票据解密
   163  //  golang原来的实现是用aes加密,sha256散列,国密改造改为 sm4 + sm3
   164  func (c *Conn) decryptTicket(encrypted []byte) (plaintext []byte, usedOldKey bool) {
   165  	if len(encrypted) < ticketKeyNameLen+sm4.BlockSize+sm3.Size {
   166  		return nil, false
   167  	}
   168  	// 获取keyname
   169  	keyName := encrypted[:ticketKeyNameLen]
   170  	// 获取iv
   171  	iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+sm4.BlockSize]
   172  	// 获取认证码
   173  	macBytes := encrypted[len(encrypted)-sm3.Size:]
   174  	// 获取秘文
   175  	ciphertext := encrypted[ticketKeyNameLen+sm4.BlockSize : len(encrypted)-sm3.Size]
   176  	// 根据keyname获取key
   177  	keyIndex := -1
   178  	for i, candidateKey := range c.ticketKeys {
   179  		if bytes.Equal(keyName, candidateKey.keyName[:]) {
   180  			keyIndex = i
   181  			break
   182  		}
   183  	}
   184  	if keyIndex == -1 {
   185  		return nil, false
   186  	}
   187  	key := &c.ticketKeys[keyIndex]
   188  	// 重新生成认证码
   189  	mac := hmac.New(sm3.New, key.hmacKey[:])
   190  	mac.Write(encrypted[:len(encrypted)-sm3.Size])
   191  	expected := mac.Sum(nil)
   192  	// 比较认证码
   193  	if subtle.ConstantTimeCompare(macBytes, expected) != 1 {
   194  		return nil, false
   195  	}
   196  	// 对称解密
   197  	block, err := sm4.NewCipher(key.sm4Key[:])
   198  	if err != nil {
   199  		return nil, false
   200  	}
   201  	plaintext = make([]byte, len(ciphertext))
   202  	cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext)
   203  
   204  	return plaintext, keyIndex > 0
   205  }