github.com/hellobchain/newcryptosm@v0.0.0-20221019060107-edb949a317e9/tls/prf.go (about)

     1  /*
     2  Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved.
     3  Licensed under the Apache License, Version 2.0 (the "License");
     4  you may not use this file except in compliance with the License.
     5  You may obtain a copy of the License at
     6  
     7  	http://www.apache.org/licenses/LICENSE-2.0
     8  
     9  Unless required by applicable law or agreed to in writing, software
    10  distributed under the License is distributed on an "AS IS" BASIS,
    11  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  See the License for the specific language governing permissions and
    13  limitations under the License.
    14  */
    15  
    16  package tls
    17  
    18  import (
    19  	"crypto/hmac"
    20  	"crypto/md5"
    21  	"crypto/sha1"
    22  	"crypto/sha256"
    23  	"errors"
    24  	"github.com/hellobchain/newcryptosm"
    25  	"github.com/hellobchain/newcryptosm/sm3"
    26  	"hash"
    27  )
    28  
    29  // Split a premaster secret in two as specified in RFC 4346, section 5.
    30  func splitPreMasterSecret(secret []byte) (s1, s2 []byte) {
    31  	s1 = secret[0 : (len(secret)+1)/2]
    32  	s2 = secret[len(secret)/2:]
    33  	return
    34  }
    35  
    36  // pHash implements the P_hash function, as defined in RFC 4346, section 5.
    37  func pHash(result, secret, seed []byte, hash func() hash.Hash) {
    38  	h := hmac.New(hash, secret)
    39  	h.Write(seed)
    40  	a := h.Sum(nil)
    41  
    42  	j := 0
    43  	for j < len(result) {
    44  		h.Reset()
    45  		h.Write(a)
    46  		h.Write(seed)
    47  		b := h.Sum(nil)
    48  		todo := len(b)
    49  		if j+todo > len(result) {
    50  			todo = len(result) - j
    51  		}
    52  		copy(result[j:j+todo], b)
    53  		j += todo
    54  
    55  		h.Reset()
    56  		h.Write(a)
    57  		a = h.Sum(nil)
    58  	}
    59  }
    60  
    61  // prf10 implements the TLS 1.0 pseudo-random function, as defined in RFC 2246, section 5.
    62  func prf10(result, secret, label, seed []byte) {
    63  	hashSHA1 := sha1.New
    64  	hashMD5 := md5.New
    65  
    66  	labelAndSeed := make([]byte, len(label)+len(seed))
    67  	copy(labelAndSeed, label)
    68  	copy(labelAndSeed[len(label):], seed)
    69  
    70  	s1, s2 := splitPreMasterSecret(secret)
    71  	pHash(result, s1, labelAndSeed, hashMD5)
    72  	result2 := make([]byte, len(result))
    73  	pHash(result2, s2, labelAndSeed, hashSHA1)
    74  
    75  	for i, b := range result2 {
    76  		result[i] ^= b
    77  	}
    78  }
    79  
    80  // prf12 implements the TLS 1.2 pseudo-random function, as defined in RFC 5246, section 5.
    81  func prf12(hashFunc func() hash.Hash) func(result, secret, label, seed []byte) {
    82  	return func(result, secret, label, seed []byte) {
    83  		labelAndSeed := make([]byte, len(label)+len(seed))
    84  		copy(labelAndSeed, label)
    85  		copy(labelAndSeed[len(label):], seed)
    86  
    87  		pHash(result, secret, labelAndSeed, hashFunc)
    88  	}
    89  }
    90  
    91  // prf30 implements the SSL 3.0 pseudo-random function, as defined in
    92  // www.mozilla.org/projects/security/pki/nss/ssl/draft302.txt section 6.
    93  func prf30(result, secret, label, seed []byte) {
    94  	hashSHA1 := sha1.New()
    95  	hashMD5 := md5.New()
    96  
    97  	done := 0
    98  	i := 0
    99  	// RFC 5246 section 6.3 says that the largest PRF output needed is 128
   100  	// bytes. Since no more ciphersuites will be added to SSLv3, this will
   101  	// remain true. Each iteration gives us 16 bytes so 10 iterations will
   102  	// be sufficient.
   103  	var b [11]byte
   104  	for done < len(result) {
   105  		for j := 0; j <= i; j++ {
   106  			b[j] = 'A' + byte(i)
   107  		}
   108  
   109  		hashSHA1.Reset()
   110  		hashSHA1.Write(b[:i+1])
   111  		hashSHA1.Write(secret)
   112  		hashSHA1.Write(seed)
   113  		digest := hashSHA1.Sum(nil)
   114  
   115  		hashMD5.Reset()
   116  		hashMD5.Write(secret)
   117  		hashMD5.Write(digest)
   118  
   119  		done += copy(result[done:], hashMD5.Sum(nil))
   120  		i++
   121  	}
   122  }
   123  
// Sizes used by the handshake code below. masterSecretLength and
// finishedVerifyLength are byte counts: both are passed to make([]byte, ...).
const (
	tlsRandomLength      = 32 // Length of a random nonce in TLS 1.1 (presumably bytes; TODO confirm at use sites).
	masterSecretLength   = 48 // Length in bytes of a master secret.
	finishedVerifyLength = 12 // Length in bytes of verify_data in a Finished message.
)
   129  
   130  var masterSecretLabel = []byte("master secret")
   131  var keyExpansionLabel = []byte("key expansion")
   132  var clientFinishedLabel = []byte("client finished")
   133  var serverFinishedLabel = []byte("server finished")
   134  
   135  func prfAndHashForVersion(version uint16, suite *cipherSuite) (func(result, secret, label, seed []byte), newcryptosm.Hash) {
   136  	switch version {
   137  	case VersionSSL30:
   138  		if suite.flags&suiteSM2 != 0 {
   139  			return prf12(sm3.New), newcryptosm.SM3
   140  		}
   141  		return prf30, newcryptosm.Hash(0)
   142  	case VersionTLS10, VersionTLS11:
   143  		if suite.flags&suiteSM2 != 0 {
   144  			return prf12(sm3.New), newcryptosm.SM3
   145  		}
   146  		return prf10, newcryptosm.Hash(0)
   147  	case VersionTLS12:
   148  		if suite.flags&suiteSM2 != 0 {
   149  			return prf12(sm3.New), newcryptosm.SM3
   150  		}
   151  		return prf12(sha256.New), newcryptosm.SHA256
   152  	default:
   153  		panic("unknown version")
   154  	}
   155  }
   156  
   157  func prfForVersion(version uint16, suite *cipherSuite) func(result, secret, label, seed []byte) {
   158  	prf, _ := prfAndHashForVersion(version, suite)
   159  	return prf
   160  }
   161  
   162  // masterFromPreMasterSecret generates the master secret from the pre-master
   163  // secret. See http://tools.ietf.org/html/rfc5246#section-8.1
   164  func masterFromPreMasterSecret(version uint16, suite *cipherSuite, preMasterSecret, clientRandom, serverRandom []byte) []byte {
   165  	seed := make([]byte, 0, len(clientRandom)+len(serverRandom))
   166  	seed = append(seed, clientRandom...)
   167  	seed = append(seed, serverRandom...)
   168  
   169  	masterSecret := make([]byte, masterSecretLength)
   170  	prfForVersion(version, suite)(masterSecret, preMasterSecret, masterSecretLabel, seed)
   171  	return masterSecret
   172  }
   173  
   174  // keysFromMasterSecret generates the connection keys from the master
   175  // secret, given the lengths of the MAC key, cipher key and IV, as defined in
   176  // RFC 2246, section 6.3.
   177  func keysFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte, macLen, keyLen, ivLen int) (clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV []byte) {
   178  	seed := make([]byte, 0, len(serverRandom)+len(clientRandom))
   179  	seed = append(seed, serverRandom...)
   180  	seed = append(seed, clientRandom...)
   181  
   182  	n := 2*macLen + 2*keyLen + 2*ivLen
   183  	keyMaterial := make([]byte, n)
   184  	prfForVersion(version, suite)(keyMaterial, masterSecret, keyExpansionLabel, seed)
   185  	clientMAC = keyMaterial[:macLen]
   186  	keyMaterial = keyMaterial[macLen:]
   187  	serverMAC = keyMaterial[:macLen]
   188  	keyMaterial = keyMaterial[macLen:]
   189  	clientKey = keyMaterial[:keyLen]
   190  	keyMaterial = keyMaterial[keyLen:]
   191  	serverKey = keyMaterial[:keyLen]
   192  	keyMaterial = keyMaterial[keyLen:]
   193  	clientIV = keyMaterial[:ivLen]
   194  	keyMaterial = keyMaterial[ivLen:]
   195  	serverIV = keyMaterial[:ivLen]
   196  	return
   197  }
   198  
   199  // lookupTLSHash looks up the corresponding crypto.Hash for a given
   200  // TLS hash identifier.
   201  func lookupTLSHash(hash uint8) (newcryptosm.Hash, error) {
   202  	switch hash {
   203  	case hashSHA1:
   204  		return newcryptosm.SHA1, nil
   205  	case hashSHA256:
   206  		return newcryptosm.SHA256, nil
   207  	case hashSHA384:
   208  		return newcryptosm.SHA384, nil
   209  	case hashSM3:
   210  		return newcryptosm.SM3, nil
   211  	default:
   212  		return 0, errors.New("tls: unsupported hash algorithm")
   213  	}
   214  }
   215  
   216  func newFinishedHash(version uint16, cipherSuite *cipherSuite) finishedHash {
   217  	var buffer []byte
   218  	if version == VersionSSL30 || version >= VersionTLS12 {
   219  		buffer = []byte{}
   220  	}
   221  
   222  	prf, hash := prfAndHashForVersion(version, cipherSuite)
   223  	if hash != 0 {
   224  		return finishedHash{hash.New(), hash.New(), nil, nil, buffer, version, prf}
   225  	}
   226  
   227  	return finishedHash{sha1.New(), sha1.New(), md5.New(), md5.New(), buffer, version, prf}
   228  }
   229  
   230  // A finishedHash calculates the hash of a set of handshake messages suitable
   231  // for including in a Finished message.
   232  type finishedHash struct {
   233  	client hash.Hash
   234  	server hash.Hash
   235  
   236  	// Prior to TLS 1.2, an additional MD5 hash is required.
   237  	clientMD5 hash.Hash
   238  	serverMD5 hash.Hash
   239  
   240  	// In TLS 1.2, a full buffer is sadly required.
   241  	buffer []byte
   242  
   243  	version uint16
   244  	prf     func(result, secret, label, seed []byte)
   245  }
   246  
   247  func (h *finishedHash) Write(msg []byte) (n int, err error) {
   248  	h.client.Write(msg)
   249  	h.server.Write(msg)
   250  
   251  	if h.version < VersionTLS12 {
   252  		h.clientMD5.Write(msg)
   253  		h.serverMD5.Write(msg)
   254  	}
   255  
   256  	if h.buffer != nil {
   257  		h.buffer = append(h.buffer, msg...)
   258  	}
   259  
   260  	return len(msg), nil
   261  }
   262  
   263  func (h finishedHash) Sum() []byte {
   264  	if h.version >= VersionTLS12 {
   265  		return h.client.Sum(nil)
   266  	}
   267  
   268  	out := make([]byte, 0, md5.Size+sha1.Size)
   269  	out = h.clientMD5.Sum(out)
   270  	return h.client.Sum(out)
   271  }
   272  
   273  // finishedSum30 calculates the contents of the verify_data member of a SSLv3
   274  // Finished message given the MD5 and SHA1 hashes of a set of handshake
   275  // messages.
   276  func finishedSum30(md5, sha1 hash.Hash, masterSecret []byte, magic []byte) []byte {
   277  	md5.Write(magic)
   278  	md5.Write(masterSecret)
   279  	md5.Write(ssl30Pad1[:])
   280  	md5Digest := md5.Sum(nil)
   281  
   282  	md5.Reset()
   283  	md5.Write(masterSecret)
   284  	md5.Write(ssl30Pad2[:])
   285  	md5.Write(md5Digest)
   286  	md5Digest = md5.Sum(nil)
   287  
   288  	sha1.Write(magic)
   289  	sha1.Write(masterSecret)
   290  	sha1.Write(ssl30Pad1[:40])
   291  	sha1Digest := sha1.Sum(nil)
   292  
   293  	sha1.Reset()
   294  	sha1.Write(masterSecret)
   295  	sha1.Write(ssl30Pad2[:40])
   296  	sha1.Write(sha1Digest)
   297  	sha1Digest = sha1.Sum(nil)
   298  
   299  	ret := make([]byte, len(md5Digest)+len(sha1Digest))
   300  	copy(ret, md5Digest)
   301  	copy(ret[len(md5Digest):], sha1Digest)
   302  	return ret
   303  }
   304  
   305  var ssl3ClientFinishedMagic = [4]byte{0x43, 0x4c, 0x4e, 0x54}
   306  var ssl3ServerFinishedMagic = [4]byte{0x53, 0x52, 0x56, 0x52}
   307  
   308  // clientSum returns the contents of the verify_data member of a client's
   309  // Finished message.
   310  func (h finishedHash) clientSum(masterSecret []byte) []byte {
   311  	if h.version == VersionSSL30 {
   312  		return finishedSum30(h.clientMD5, h.client, masterSecret, ssl3ClientFinishedMagic[:])
   313  	}
   314  
   315  	out := make([]byte, finishedVerifyLength)
   316  	h.prf(out, masterSecret, clientFinishedLabel, h.Sum())
   317  	return out
   318  }
   319  
   320  // serverSum returns the contents of the verify_data member of a server's
   321  // Finished message.
   322  func (h finishedHash) serverSum(masterSecret []byte) []byte {
   323  	if h.version == VersionSSL30 {
   324  		return finishedSum30(h.serverMD5, h.server, masterSecret, ssl3ServerFinishedMagic[:])
   325  	}
   326  
   327  	out := make([]byte, finishedVerifyLength)
   328  	h.prf(out, masterSecret, serverFinishedLabel, h.Sum())
   329  	return out
   330  }
   331  
   332  // selectClientCertSignatureAlgorithm returns a signatureAndHash to sign a
   333  // client's CertificateVerify with, or an error if none can be found.
   334  func (h finishedHash) selectClientCertSignatureAlgorithm(serverList []signatureAndHash, sigType uint8) (signatureAndHash, error) {
   335  	if h.version < VersionTLS12 {
   336  		// Nothing to negotiate before TLS 1.2.
   337  		return signatureAndHash{signature: sigType}, nil
   338  	}
   339  
   340  	for _, v := range serverList {
   341  		if v.signature == sigType && isSupportedSignatureAndHash(v, supportedSignatureAlgorithms) {
   342  			return v, nil
   343  		}
   344  	}
   345  	return signatureAndHash{}, errors.New("tls: no supported signature algorithm found for signing client certificate")
   346  }
   347  
   348  // hashForClientCertificate returns a digest, hash function, and TLS 1.2 hash
   349  // id suitable for signing by a TLS client certificate.
   350  func (h finishedHash) hashForClientCertificate(signatureAndHash signatureAndHash, masterSecret []byte) ([]byte, newcryptosm.Hash, error) {
   351  	if (h.version == VersionSSL30 || h.version >= VersionTLS12) && h.buffer == nil {
   352  		panic("a handshake hash for a client-certificate was requested after discarding the handshake buffer")
   353  	}
   354  
   355  	if h.version == VersionSSL30 {
   356  		if signatureAndHash.signature != signatureRSA {
   357  			return nil, 0, errors.New("tls: unsupported signature type for client certificate")
   358  		}
   359  
   360  		md5Hash := md5.New()
   361  		md5Hash.Write(h.buffer)
   362  		sha1Hash := sha1.New()
   363  		sha1Hash.Write(h.buffer)
   364  		return finishedSum30(md5Hash, sha1Hash, masterSecret, nil), newcryptosm.MD5SHA1, nil
   365  	}
   366  	if h.version >= VersionTLS12 {
   367  		hashAlg, err := lookupTLSHash(signatureAndHash.hash)
   368  		if err != nil {
   369  			return nil, 0, err
   370  		}
   371  		hash := hashAlg.New()
   372  		hash.Write(h.buffer)
   373  		return hash.Sum(nil), hashAlg, nil
   374  	}
   375  	if signatureAndHash.signature == signatureSM2 {
   376  		return h.server.Sum(nil), newcryptosm.SM3, nil
   377  	}
   378  	if signatureAndHash.signature == signatureECDSA {
   379  		return h.server.Sum(nil), newcryptosm.SHA1, nil
   380  	}
   381  
   382  	return h.Sum(), newcryptosm.MD5SHA1, nil
   383  }
   384  
// discardHandshakeBuffer is called when there is no more need to
// buffer the entirety of the handshake messages. It sets the buffer to nil,
// after which Write no longer accumulates messages and
// hashForClientCertificate panics for SSL 3.0 and TLS 1.2+ connections.
func (h *finishedHash) discardHandshakeBuffer() {
	h.buffer = nil
}