github.com/Psiphon-Labs/tls-tris@v0.0.0-20230824155421-58bf6d336a9a/key_agreement.go (about) 1 // Copyright 2010 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package tls 6 7 import ( 8 "crypto" 9 "crypto/elliptic" 10 "crypto/md5" 11 "crypto/rsa" 12 "crypto/sha1" 13 "errors" 14 "io" 15 "math/big" 16 17 "golang.org/x/crypto/curve25519" 18 ) 19 20 var errClientKeyExchange = errors.New("tls: invalid ClientKeyExchange message") 21 var errServerKeyExchange = errors.New("tls: invalid ServerKeyExchange message") 22 23 // rsaKeyAgreement implements the standard TLS key agreement where the client 24 // encrypts the pre-master secret to the server's public key. 25 type rsaKeyAgreement struct{} 26 27 func (ka rsaKeyAgreement) generateServerKeyExchange(config *Config, sk crypto.PrivateKey, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) { 28 return nil, nil 29 } 30 31 func (ka rsaKeyAgreement) processClientKeyExchange(config *Config, sk crypto.PrivateKey, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) { 32 if len(ckx.ciphertext) < 2 { 33 return nil, errClientKeyExchange 34 } 35 36 ciphertext := ckx.ciphertext 37 if version != VersionSSL30 { 38 ciphertextLen := int(ckx.ciphertext[0])<<8 | int(ckx.ciphertext[1]) 39 if ciphertextLen != len(ckx.ciphertext)-2 { 40 return nil, errClientKeyExchange 41 } 42 ciphertext = ckx.ciphertext[2:] 43 } 44 priv, ok := sk.(crypto.Decrypter) 45 if !ok { 46 return nil, errors.New("tls: certificate private key does not implement crypto.Decrypter") 47 } 48 // Perform constant time RSA PKCS#1 v1.5 decryption 49 preMasterSecret, err := priv.Decrypt(config.rand(), ciphertext, &rsa.PKCS1v15DecryptOptions{SessionKeyLen: 48}) 50 if err != nil { 51 return nil, err 52 } 53 // We don't check the version number in the premaster secret. For one, 54 // by checking it, we would leak information about the validity of the 55 // encrypted pre-master secret. Secondly, it provides only a small 56 // benefit against a downgrade attack and some implementations send the 57 // wrong version anyway. See the discussion at the end of section 58 // 7.4.7.1 of RFC 4346. 59 return preMasterSecret, nil 60 } 61 62 func (ka rsaKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, pk crypto.PublicKey, skx *serverKeyExchangeMsg) error { 63 return errors.New("tls: unexpected ServerKeyExchange") 64 } 65 66 func (ka rsaKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, pk crypto.PublicKey) ([]byte, *clientKeyExchangeMsg, error) { 67 preMasterSecret := make([]byte, 48) 68 preMasterSecret[0] = byte(clientHello.vers >> 8) 69 preMasterSecret[1] = byte(clientHello.vers) 70 _, err := io.ReadFull(config.rand(), preMasterSecret[2:]) 71 if err != nil { 72 return nil, nil, err 73 } 74 75 // [Psiphon] 76 // Backport fix: https://github.com/golang/go/commit/58bc454a11d4b3dbc03f44dfcabb9068a9c076f4 77 rsaKey, ok := pk.(*rsa.PublicKey) 78 if !ok { 79 return nil, nil, errors.New("tls: server certificate contains incorrect key type for selected ciphersuite") 80 } 81 encrypted, err := rsa.EncryptPKCS1v15(config.rand(), rsaKey, preMasterSecret) 82 if err != nil { 83 return nil, nil, err 84 } 85 ckx := new(clientKeyExchangeMsg) 86 ckx.ciphertext = make([]byte, len(encrypted)+2) 87 ckx.ciphertext[0] = byte(len(encrypted) >> 8) 88 ckx.ciphertext[1] = byte(len(encrypted)) 89 copy(ckx.ciphertext[2:], encrypted) 90 return preMasterSecret, ckx, nil 91 } 92 93 // sha1Hash calculates a SHA1 hash over the given byte slices. 94 func sha1Hash(slices [][]byte) []byte { 95 hsha1 := sha1.New() 96 for _, slice := range slices { 97 hsha1.Write(slice) 98 } 99 return hsha1.Sum(nil) 100 } 101 102 // md5SHA1Hash implements TLS 1.0's hybrid hash function which consists of the 103 // concatenation of an MD5 and SHA1 hash. 104 func md5SHA1Hash(slices [][]byte) []byte { 105 md5sha1 := make([]byte, md5.Size+sha1.Size) 106 hmd5 := md5.New() 107 for _, slice := range slices { 108 hmd5.Write(slice) 109 } 110 copy(md5sha1, hmd5.Sum(nil)) 111 copy(md5sha1[md5.Size:], sha1Hash(slices)) 112 return md5sha1 113 } 114 115 // hashForServerKeyExchange hashes the given slices and returns their digest 116 // using the given hash function. 117 func hashForServerKeyExchange(sigType uint8, hashFunc crypto.Hash, version uint16, slices ...[]byte) ([]byte, error) { 118 if version >= VersionTLS12 { 119 h := hashFunc.New() 120 for _, slice := range slices { 121 h.Write(slice) 122 } 123 digest := h.Sum(nil) 124 return digest, nil 125 } 126 if sigType == signatureECDSA { 127 return sha1Hash(slices), nil 128 } 129 return md5SHA1Hash(slices), nil 130 } 131 132 func curveForCurveID(id CurveID) (elliptic.Curve, bool) { 133 switch id { 134 case CurveP256: 135 return elliptic.P256(), true 136 case CurveP384: 137 return elliptic.P384(), true 138 case CurveP521: 139 return elliptic.P521(), true 140 default: 141 return nil, false 142 } 143 144 } 145 146 // ecdheKeyAgreement implements a TLS key agreement where the server 147 // generates an ephemeral EC public/private key pair and signs it. The 148 // pre-master secret is then calculated using ECDH. The signature may 149 // either be ECDSA or RSA. 150 type ecdheKeyAgreement struct { 151 version uint16 152 isRSA bool 153 privateKey []byte 154 curveid CurveID 155 156 // publicKey is used to store the peer's public value when X25519 is 157 // being used. 158 publicKey []byte 159 // x and y are used to store the peer's public value when one of the 160 // NIST curves is being used. 161 x, y *big.Int 162 } 163 164 func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *Config, sk crypto.PrivateKey, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) { 165 preferredCurves := config.curvePreferences() 166 167 NextCandidate: 168 for _, candidate := range preferredCurves { 169 for _, c := range clientHello.supportedCurves { 170 if candidate == c { 171 ka.curveid = c 172 break NextCandidate 173 } 174 } 175 } 176 177 if ka.curveid == 0 { 178 return nil, errors.New("tls: no supported elliptic curves offered") 179 } 180 181 var ecdhePublic []byte 182 183 if ka.curveid == X25519 { 184 var scalar, public [32]byte 185 if _, err := io.ReadFull(config.rand(), scalar[:]); err != nil { 186 return nil, err 187 } 188 189 curve25519.ScalarBaseMult(&public, &scalar) 190 ka.privateKey = scalar[:] 191 ecdhePublic = public[:] 192 } else { 193 curve, ok := curveForCurveID(ka.curveid) 194 if !ok { 195 return nil, errors.New("tls: preferredCurves includes unsupported curve") 196 } 197 198 var x, y *big.Int 199 var err error 200 ka.privateKey, x, y, err = elliptic.GenerateKey(curve, config.rand()) 201 if err != nil { 202 return nil, err 203 } 204 ecdhePublic = elliptic.Marshal(curve, x, y) 205 } 206 207 // http://tools.ietf.org/html/rfc4492#section-5.4 208 serverECDHParams := make([]byte, 1+2+1+len(ecdhePublic)) 209 serverECDHParams[0] = 3 // named curve 210 serverECDHParams[1] = byte(ka.curveid >> 8) 211 serverECDHParams[2] = byte(ka.curveid) 212 serverECDHParams[3] = byte(len(ecdhePublic)) 213 copy(serverECDHParams[4:], ecdhePublic) 214 215 priv, ok := sk.(crypto.Signer) 216 if !ok { 217 return nil, errors.New("tls: certificate private key does not implement crypto.Signer") 218 } 219 220 signatureAlgorithm, sigType, hashFunc, err := pickSignatureAlgorithm(priv.Public(), clientHello.supportedSignatureAlgorithms, supportedSignatureAlgorithms, ka.version) 221 if err != nil { 222 return nil, err 223 } 224 if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA { 225 return nil, errors.New("tls: certificate cannot be used with the selected cipher suite") 226 } 227 228 digest, err := hashForServerKeyExchange(sigType, hashFunc, ka.version, clientHello.random, hello.random, serverECDHParams) 229 if err != nil { 230 return nil, err 231 } 232 233 var sig []byte 234 signOpts := crypto.SignerOpts(hashFunc) 235 if sigType == signatureRSAPSS { 236 signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: hashFunc} 237 } 238 sig, err = priv.Sign(config.rand(), digest, signOpts) 239 if err != nil { 240 return nil, errors.New("tls: failed to sign ECDHE parameters: " + err.Error()) 241 } 242 243 skx := new(serverKeyExchangeMsg) 244 sigAndHashLen := 0 245 if ka.version >= VersionTLS12 { 246 sigAndHashLen = 2 247 } 248 skx.key = make([]byte, len(serverECDHParams)+sigAndHashLen+2+len(sig)) 249 copy(skx.key, serverECDHParams) 250 k := skx.key[len(serverECDHParams):] 251 if ka.version >= VersionTLS12 { 252 k[0] = byte(signatureAlgorithm >> 8) 253 k[1] = byte(signatureAlgorithm) 254 k = k[2:] 255 } 256 k[0] = byte(len(sig) >> 8) 257 k[1] = byte(len(sig)) 258 copy(k[2:], sig) 259 260 return skx, nil 261 } 262 263 func (ka *ecdheKeyAgreement) processClientKeyExchange(config *Config, sk crypto.PrivateKey, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) { 264 if len(ckx.ciphertext) == 0 || int(ckx.ciphertext[0]) != len(ckx.ciphertext)-1 { 265 return nil, errClientKeyExchange 266 } 267 268 if ka.curveid == X25519 { 269 if len(ckx.ciphertext) != 1+32 { 270 return nil, errClientKeyExchange 271 } 272 273 var theirPublic, sharedKey, scalar [32]byte 274 copy(theirPublic[:], ckx.ciphertext[1:]) 275 copy(scalar[:], ka.privateKey) 276 curve25519.ScalarMult(&sharedKey, &scalar, &theirPublic) 277 return sharedKey[:], nil 278 } 279 280 curve, ok := curveForCurveID(ka.curveid) 281 if !ok { 282 panic("internal error") 283 } 284 x, y := elliptic.Unmarshal(curve, ckx.ciphertext[1:]) // Unmarshal also checks whether the given point is on the curve 285 if x == nil { 286 return nil, errClientKeyExchange 287 } 288 x, _ = curve.ScalarMult(x, y, ka.privateKey) 289 curveSize := (curve.Params().BitSize + 7) >> 3 290 xBytes := x.Bytes() 291 if len(xBytes) == curveSize { 292 return xBytes, nil 293 } 294 preMasterSecret := make([]byte, curveSize) 295 copy(preMasterSecret[len(preMasterSecret)-len(xBytes):], xBytes) 296 return preMasterSecret, nil 297 } 298 299 func (ka *ecdheKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, pk crypto.PublicKey, skx *serverKeyExchangeMsg) error { 300 if len(skx.key) < 4 { 301 return errServerKeyExchange 302 } 303 if skx.key[0] != 3 { // named curve 304 return errors.New("tls: server selected unsupported curve") 305 } 306 ka.curveid = CurveID(skx.key[1])<<8 | CurveID(skx.key[2]) 307 308 publicLen := int(skx.key[3]) 309 if publicLen+4 > len(skx.key) { 310 return errServerKeyExchange 311 } 312 serverECDHParams := skx.key[:4+publicLen] 313 publicKey := serverECDHParams[4:] 314 315 sig := skx.key[4+publicLen:] 316 if len(sig) < 2 { 317 return errServerKeyExchange 318 } 319 320 if ka.curveid == X25519 { 321 if len(publicKey) != 32 { 322 return errors.New("tls: bad X25519 public value") 323 } 324 ka.publicKey = publicKey 325 } else { 326 curve, ok := curveForCurveID(ka.curveid) 327 if !ok { 328 return errors.New("tls: server selected unsupported curve") 329 } 330 ka.x, ka.y = elliptic.Unmarshal(curve, publicKey) // Unmarshal also checks whether the given point is on the curve 331 if ka.x == nil { 332 return errServerKeyExchange 333 } 334 } 335 336 var signatureAlgorithm SignatureScheme 337 if ka.version >= VersionTLS12 { 338 // handle SignatureAndHashAlgorithm 339 signatureAlgorithm = SignatureScheme(sig[0])<<8 | SignatureScheme(sig[1]) 340 sig = sig[2:] 341 if len(sig) < 2 { 342 return errServerKeyExchange 343 } 344 } 345 _, sigType, hashFunc, err := pickSignatureAlgorithm(pk, []SignatureScheme{signatureAlgorithm}, clientHello.supportedSignatureAlgorithms, ka.version) 346 if err != nil { 347 return err 348 } 349 if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA { 350 return errServerKeyExchange 351 } 352 353 sigLen := int(sig[0])<<8 | int(sig[1]) 354 if sigLen+2 != len(sig) { 355 return errServerKeyExchange 356 } 357 sig = sig[2:] 358 359 digest, err := hashForServerKeyExchange(sigType, hashFunc, ka.version, clientHello.random, serverHello.random, serverECDHParams) 360 if err != nil { 361 return err 362 } 363 return verifyHandshakeSignature(sigType, pk, hashFunc, digest, sig) 364 } 365 366 func (ka *ecdheKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, pk crypto.PublicKey) ([]byte, *clientKeyExchangeMsg, error) { 367 if ka.curveid == 0 { 368 return nil, nil, errors.New("tls: missing ServerKeyExchange message") 369 } 370 371 var serialized, preMasterSecret []byte 372 373 if ka.curveid == X25519 { 374 var ourPublic, theirPublic, sharedKey, scalar [32]byte 375 376 if _, err := io.ReadFull(config.rand(), scalar[:]); err != nil { 377 return nil, nil, err 378 } 379 380 copy(theirPublic[:], ka.publicKey) 381 curve25519.ScalarBaseMult(&ourPublic, &scalar) 382 curve25519.ScalarMult(&sharedKey, &scalar, &theirPublic) 383 serialized = ourPublic[:] 384 preMasterSecret = sharedKey[:] 385 } else { 386 curve, ok := curveForCurveID(ka.curveid) 387 if !ok { 388 panic("internal error") 389 } 390 priv, mx, my, err := elliptic.GenerateKey(curve, config.rand()) 391 if err != nil { 392 return nil, nil, err 393 } 394 x, _ := curve.ScalarMult(ka.x, ka.y, priv) 395 preMasterSecret = make([]byte, (curve.Params().BitSize+7)>>3) 396 xBytes := x.Bytes() 397 copy(preMasterSecret[len(preMasterSecret)-len(xBytes):], xBytes) 398 399 serialized = elliptic.Marshal(curve, mx, my) 400 } 401 402 ckx := new(clientKeyExchangeMsg) 403 ckx.ciphertext = make([]byte, 1+len(serialized)) 404 ckx.ciphertext[0] = byte(len(serialized)) 405 copy(ckx.ciphertext[1:], serialized) 406 407 return preMasterSecret, ckx, nil 408 }