github.com/mtsmfm/go/src@v0.0.0-20221020090648-44bdcb9f8fde/crypto/rsa/pss.go (about) 1 // Copyright 2013 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package rsa 6 7 // This file implements the RSASSA-PSS signature scheme according to RFC 8017. 8 9 import ( 10 "bytes" 11 "crypto" 12 "crypto/internal/boring" 13 "errors" 14 "hash" 15 "io" 16 "math/big" 17 ) 18 19 // Per RFC 8017, Section 9.1 20 // 21 // EM = MGF1 xor DB || H( 8*0x00 || mHash || salt ) || 0xbc 22 // 23 // where 24 // 25 // DB = PS || 0x01 || salt 26 // 27 // and PS can be empty so 28 // 29 // emLen = dbLen + hLen + 1 = psLen + sLen + hLen + 2 30 // 31 32 func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byte, error) { 33 // See RFC 8017, Section 9.1.1. 34 35 hLen := hash.Size() 36 sLen := len(salt) 37 emLen := (emBits + 7) / 8 38 39 // 1. If the length of M is greater than the input limitation for the 40 // hash function (2^61 - 1 octets for SHA-1), output "message too 41 // long" and stop. 42 // 43 // 2. Let mHash = Hash(M), an octet string of length hLen. 44 45 if len(mHash) != hLen { 46 return nil, errors.New("crypto/rsa: input must be hashed with given hash") 47 } 48 49 // 3. If emLen < hLen + sLen + 2, output "encoding error" and stop. 50 51 if emLen < hLen+sLen+2 { 52 return nil, errors.New("crypto/rsa: key size too small for PSS signature") 53 } 54 55 em := make([]byte, emLen) 56 psLen := emLen - sLen - hLen - 2 57 db := em[:psLen+1+sLen] 58 h := em[psLen+1+sLen : emLen-1] 59 60 // 4. Generate a random octet string salt of length sLen; if sLen = 0, 61 // then salt is the empty string. 62 // 63 // 5. Let 64 // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt; 65 // 66 // M' is an octet string of length 8 + hLen + sLen with eight 67 // initial zero octets. 68 // 69 // 6. Let H = Hash(M'), an octet string of length hLen. 70 71 var prefix [8]byte 72 73 hash.Write(prefix[:]) 74 hash.Write(mHash) 75 hash.Write(salt) 76 77 h = hash.Sum(h[:0]) 78 hash.Reset() 79 80 // 7. Generate an octet string PS consisting of emLen - sLen - hLen - 2 81 // zero octets. The length of PS may be 0. 82 // 83 // 8. Let DB = PS || 0x01 || salt; DB is an octet string of length 84 // emLen - hLen - 1. 85 86 db[psLen] = 0x01 87 copy(db[psLen+1:], salt) 88 89 // 9. Let dbMask = MGF(H, emLen - hLen - 1). 90 // 91 // 10. Let maskedDB = DB \xor dbMask. 92 93 mgf1XOR(db, hash, h) 94 95 // 11. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in 96 // maskedDB to zero. 97 98 db[0] &= 0xff >> (8*emLen - emBits) 99 100 // 12. Let EM = maskedDB || H || 0xbc. 101 em[emLen-1] = 0xbc 102 103 // 13. Output EM. 104 return em, nil 105 } 106 107 func emsaPSSVerify(mHash, em []byte, emBits, sLen int, hash hash.Hash) error { 108 // See RFC 8017, Section 9.1.2. 109 110 hLen := hash.Size() 111 if sLen == PSSSaltLengthEqualsHash { 112 sLen = hLen 113 } 114 emLen := (emBits + 7) / 8 115 if emLen != len(em) { 116 return errors.New("rsa: internal error: inconsistent length") 117 } 118 119 // 1. If the length of M is greater than the input limitation for the 120 // hash function (2^61 - 1 octets for SHA-1), output "inconsistent" 121 // and stop. 122 // 123 // 2. Let mHash = Hash(M), an octet string of length hLen. 124 if hLen != len(mHash) { 125 return ErrVerification 126 } 127 128 // 3. If emLen < hLen + sLen + 2, output "inconsistent" and stop. 129 if emLen < hLen+sLen+2 { 130 return ErrVerification 131 } 132 133 // 4. If the rightmost octet of EM does not have hexadecimal value 134 // 0xbc, output "inconsistent" and stop. 135 if em[emLen-1] != 0xbc { 136 return ErrVerification 137 } 138 139 // 5. Let maskedDB be the leftmost emLen - hLen - 1 octets of EM, and 140 // let H be the next hLen octets. 141 db := em[:emLen-hLen-1] 142 h := em[emLen-hLen-1 : emLen-1] 143 144 // 6. If the leftmost 8 * emLen - emBits bits of the leftmost octet in 145 // maskedDB are not all equal to zero, output "inconsistent" and 146 // stop. 147 var bitMask byte = 0xff >> (8*emLen - emBits) 148 if em[0] & ^bitMask != 0 { 149 return ErrVerification 150 } 151 152 // 7. Let dbMask = MGF(H, emLen - hLen - 1). 153 // 154 // 8. Let DB = maskedDB \xor dbMask. 155 mgf1XOR(db, hash, h) 156 157 // 9. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB 158 // to zero. 159 db[0] &= bitMask 160 161 // If we don't know the salt length, look for the 0x01 delimiter. 162 if sLen == PSSSaltLengthAuto { 163 psLen := bytes.IndexByte(db, 0x01) 164 if psLen < 0 { 165 return ErrVerification 166 } 167 sLen = len(db) - psLen - 1 168 } 169 170 // 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not zero 171 // or if the octet at position emLen - hLen - sLen - 1 (the leftmost 172 // position is "position 1") does not have hexadecimal value 0x01, 173 // output "inconsistent" and stop. 174 psLen := emLen - hLen - sLen - 2 175 for _, e := range db[:psLen] { 176 if e != 0x00 { 177 return ErrVerification 178 } 179 } 180 if db[psLen] != 0x01 { 181 return ErrVerification 182 } 183 184 // 11. Let salt be the last sLen octets of DB. 185 salt := db[len(db)-sLen:] 186 187 // 12. Let 188 // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ; 189 // M' is an octet string of length 8 + hLen + sLen with eight 190 // initial zero octets. 191 // 192 // 13. Let H' = Hash(M'), an octet string of length hLen. 193 var prefix [8]byte 194 hash.Write(prefix[:]) 195 hash.Write(mHash) 196 hash.Write(salt) 197 198 h0 := hash.Sum(nil) 199 200 // 14. If H = H', output "consistent." Otherwise, output "inconsistent." 201 if !bytes.Equal(h0, h) { // TODO: constant time? 202 return ErrVerification 203 } 204 return nil 205 } 206 207 // signPSSWithSalt calculates the signature of hashed using PSS with specified salt. 208 // Note that hashed must be the result of hashing the input message using the 209 // given hash function. salt is a random sequence of bytes whose length will be 210 // later used to verify the signature. 211 func signPSSWithSalt(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) ([]byte, error) { 212 emBits := priv.N.BitLen() - 1 213 em, err := emsaPSSEncode(hashed, emBits, salt, hash.New()) 214 if err != nil { 215 return nil, err 216 } 217 218 if boring.Enabled { 219 bkey, err := boringPrivateKey(priv) 220 if err != nil { 221 return nil, err 222 } 223 // Note: BoringCrypto takes care of the "AndCheck" part of "decryptAndCheck". 224 // (It's not just decrypt.) 225 s, err := boring.DecryptRSANoPadding(bkey, em) 226 if err != nil { 227 return nil, err 228 } 229 return s, nil 230 } 231 232 m := new(big.Int).SetBytes(em) 233 c, err := decryptAndCheck(rand, priv, m) 234 if err != nil { 235 return nil, err 236 } 237 s := make([]byte, priv.Size()) 238 return c.FillBytes(s), nil 239 } 240 241 const ( 242 // PSSSaltLengthAuto causes the salt in a PSS signature to be as large 243 // as possible when signing, and to be auto-detected when verifying. 244 PSSSaltLengthAuto = 0 245 // PSSSaltLengthEqualsHash causes the salt length to equal the length 246 // of the hash used in the signature. 247 PSSSaltLengthEqualsHash = -1 248 ) 249 250 // PSSOptions contains options for creating and verifying PSS signatures. 251 type PSSOptions struct { 252 // SaltLength controls the length of the salt used in the PSS signature. It 253 // can either be a positive number of bytes, or one of the special 254 // PSSSaltLength constants. 255 SaltLength int 256 257 // Hash is the hash function used to generate the message digest. If not 258 // zero, it overrides the hash function passed to SignPSS. It's required 259 // when using PrivateKey.Sign. 260 Hash crypto.Hash 261 } 262 263 // HashFunc returns opts.Hash so that PSSOptions implements crypto.SignerOpts. 264 func (opts *PSSOptions) HashFunc() crypto.Hash { 265 return opts.Hash 266 } 267 268 func (opts *PSSOptions) saltLength() int { 269 if opts == nil { 270 return PSSSaltLengthAuto 271 } 272 return opts.SaltLength 273 } 274 275 var invalidSaltLenErr = errors.New("crypto/rsa: PSSOptions.SaltLength cannot be negative") 276 277 // SignPSS calculates the signature of digest using PSS. 278 // 279 // digest must be the result of hashing the input message using the given hash 280 // function. The opts argument may be nil, in which case sensible defaults are 281 // used. If opts.Hash is set, it overrides hash. 282 func SignPSS(rand io.Reader, priv *PrivateKey, hash crypto.Hash, digest []byte, opts *PSSOptions) ([]byte, error) { 283 if boring.Enabled && rand == boring.RandReader { 284 bkey, err := boringPrivateKey(priv) 285 if err != nil { 286 return nil, err 287 } 288 return boring.SignRSAPSS(bkey, hash, digest, opts.saltLength()) 289 } 290 boring.UnreachableExceptTests() 291 292 if opts != nil && opts.Hash != 0 { 293 hash = opts.Hash 294 } 295 296 saltLength := opts.saltLength() 297 switch saltLength { 298 case PSSSaltLengthAuto: 299 saltLength = (priv.N.BitLen()-1+7)/8 - 2 - hash.Size() 300 case PSSSaltLengthEqualsHash: 301 saltLength = hash.Size() 302 default: 303 // If we get here saltLength is either > 0 or < -1, in the 304 // latter case we fail out. 305 if saltLength <= 0 { 306 return nil, invalidSaltLenErr 307 } 308 } 309 salt := make([]byte, saltLength) 310 if _, err := io.ReadFull(rand, salt); err != nil { 311 return nil, err 312 } 313 return signPSSWithSalt(rand, priv, hash, digest, salt) 314 } 315 316 // VerifyPSS verifies a PSS signature. 317 // 318 // A valid signature is indicated by returning a nil error. digest must be the 319 // result of hashing the input message using the given hash function. The opts 320 // argument may be nil, in which case sensible defaults are used. opts.Hash is 321 // ignored. 322 func VerifyPSS(pub *PublicKey, hash crypto.Hash, digest []byte, sig []byte, opts *PSSOptions) error { 323 if boring.Enabled { 324 bkey, err := boringPublicKey(pub) 325 if err != nil { 326 return err 327 } 328 if err := boring.VerifyRSAPSS(bkey, hash, digest, sig, opts.saltLength()); err != nil { 329 return ErrVerification 330 } 331 return nil 332 } 333 if len(sig) != pub.Size() { 334 return ErrVerification 335 } 336 // Salt length must be either one of the special constants (-1 or 0) 337 // or otherwise positive. If it is < PSSSaltLengthEqualsHash (-1) 338 // we return an error. 339 if opts.saltLength() < PSSSaltLengthEqualsHash { 340 return invalidSaltLenErr 341 } 342 s := new(big.Int).SetBytes(sig) 343 m := encrypt(new(big.Int), pub, s) 344 emBits := pub.N.BitLen() - 1 345 emLen := (emBits + 7) / 8 346 if m.BitLen() > emLen*8 { 347 return ErrVerification 348 } 349 em := m.FillBytes(make([]byte, emLen)) 350 return emsaPSSVerify(digest, em, emBits, opts.saltLength(), hash.New()) 351 }