github.com/geraldss/go/src@v0.0.0-20210511222824-ac7d0ebfc235/crypto/rsa/pss.go (about) 1 // Copyright 2013 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package rsa 6 7 // This file implements the RSASSA-PSS signature scheme according to RFC 8017. 8 9 import ( 10 "bytes" 11 "crypto" 12 "errors" 13 "hash" 14 "io" 15 "math/big" 16 ) 17 18 // Per RFC 8017, Section 9.1 19 // 20 // EM = MGF1 xor DB || H( 8*0x00 || mHash || salt ) || 0xbc 21 // 22 // where 23 // 24 // DB = PS || 0x01 || salt 25 // 26 // and PS can be empty so 27 // 28 // emLen = dbLen + hLen + 1 = psLen + sLen + hLen + 2 29 // 30 31 func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byte, error) { 32 // See RFC 8017, Section 9.1.1. 33 34 hLen := hash.Size() 35 sLen := len(salt) 36 emLen := (emBits + 7) / 8 37 38 // 1. If the length of M is greater than the input limitation for the 39 // hash function (2^61 - 1 octets for SHA-1), output "message too 40 // long" and stop. 41 // 42 // 2. Let mHash = Hash(M), an octet string of length hLen. 43 44 if len(mHash) != hLen { 45 return nil, errors.New("crypto/rsa: input must be hashed with given hash") 46 } 47 48 // 3. If emLen < hLen + sLen + 2, output "encoding error" and stop. 49 50 if emLen < hLen+sLen+2 { 51 return nil, errors.New("crypto/rsa: key size too small for PSS signature") 52 } 53 54 em := make([]byte, emLen) 55 psLen := emLen - sLen - hLen - 2 56 db := em[:psLen+1+sLen] 57 h := em[psLen+1+sLen : emLen-1] 58 59 // 4. Generate a random octet string salt of length sLen; if sLen = 0, 60 // then salt is the empty string. 61 // 62 // 5. Let 63 // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt; 64 // 65 // M' is an octet string of length 8 + hLen + sLen with eight 66 // initial zero octets. 67 // 68 // 6. Let H = Hash(M'), an octet string of length hLen. 69 70 var prefix [8]byte 71 72 hash.Write(prefix[:]) 73 hash.Write(mHash) 74 hash.Write(salt) 75 76 h = hash.Sum(h[:0]) 77 hash.Reset() 78 79 // 7. Generate an octet string PS consisting of emLen - sLen - hLen - 2 80 // zero octets. The length of PS may be 0. 81 // 82 // 8. Let DB = PS || 0x01 || salt; DB is an octet string of length 83 // emLen - hLen - 1. 84 85 db[psLen] = 0x01 86 copy(db[psLen+1:], salt) 87 88 // 9. Let dbMask = MGF(H, emLen - hLen - 1). 89 // 90 // 10. Let maskedDB = DB \xor dbMask. 91 92 mgf1XOR(db, hash, h) 93 94 // 11. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in 95 // maskedDB to zero. 96 97 db[0] &= 0xff >> (8*emLen - emBits) 98 99 // 12. Let EM = maskedDB || H || 0xbc. 100 em[emLen-1] = 0xbc 101 102 // 13. Output EM. 103 return em, nil 104 } 105 106 func emsaPSSVerify(mHash, em []byte, emBits, sLen int, hash hash.Hash) error { 107 // See RFC 8017, Section 9.1.2. 108 109 hLen := hash.Size() 110 if sLen == PSSSaltLengthEqualsHash { 111 sLen = hLen 112 } 113 emLen := (emBits + 7) / 8 114 if emLen != len(em) { 115 return errors.New("rsa: internal error: inconsistent length") 116 } 117 118 // 1. If the length of M is greater than the input limitation for the 119 // hash function (2^61 - 1 octets for SHA-1), output "inconsistent" 120 // and stop. 121 // 122 // 2. Let mHash = Hash(M), an octet string of length hLen. 123 if hLen != len(mHash) { 124 return ErrVerification 125 } 126 127 // 3. If emLen < hLen + sLen + 2, output "inconsistent" and stop. 128 if emLen < hLen+sLen+2 { 129 return ErrVerification 130 } 131 132 // 4. If the rightmost octet of EM does not have hexadecimal value 133 // 0xbc, output "inconsistent" and stop. 134 if em[emLen-1] != 0xbc { 135 return ErrVerification 136 } 137 138 // 5. Let maskedDB be the leftmost emLen - hLen - 1 octets of EM, and 139 // let H be the next hLen octets. 140 db := em[:emLen-hLen-1] 141 h := em[emLen-hLen-1 : emLen-1] 142 143 // 6. If the leftmost 8 * emLen - emBits bits of the leftmost octet in 144 // maskedDB are not all equal to zero, output "inconsistent" and 145 // stop. 146 var bitMask byte = 0xff >> (8*emLen - emBits) 147 if em[0] & ^bitMask != 0 { 148 return ErrVerification 149 } 150 151 // 7. Let dbMask = MGF(H, emLen - hLen - 1). 152 // 153 // 8. Let DB = maskedDB \xor dbMask. 154 mgf1XOR(db, hash, h) 155 156 // 9. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB 157 // to zero. 158 db[0] &= bitMask 159 160 // If we don't know the salt length, look for the 0x01 delimiter. 161 if sLen == PSSSaltLengthAuto { 162 psLen := bytes.IndexByte(db, 0x01) 163 if psLen < 0 { 164 return ErrVerification 165 } 166 sLen = len(db) - psLen - 1 167 } 168 169 // 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not zero 170 // or if the octet at position emLen - hLen - sLen - 1 (the leftmost 171 // position is "position 1") does not have hexadecimal value 0x01, 172 // output "inconsistent" and stop. 173 psLen := emLen - hLen - sLen - 2 174 for _, e := range db[:psLen] { 175 if e != 0x00 { 176 return ErrVerification 177 } 178 } 179 if db[psLen] != 0x01 { 180 return ErrVerification 181 } 182 183 // 11. Let salt be the last sLen octets of DB. 184 salt := db[len(db)-sLen:] 185 186 // 12. Let 187 // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ; 188 // M' is an octet string of length 8 + hLen + sLen with eight 189 // initial zero octets. 190 // 191 // 13. Let H' = Hash(M'), an octet string of length hLen. 192 var prefix [8]byte 193 hash.Write(prefix[:]) 194 hash.Write(mHash) 195 hash.Write(salt) 196 197 h0 := hash.Sum(nil) 198 199 // 14. If H = H', output "consistent." Otherwise, output "inconsistent." 200 if !bytes.Equal(h0, h) { // TODO: constant time? 201 return ErrVerification 202 } 203 return nil 204 } 205 206 // signPSSWithSalt calculates the signature of hashed using PSS with specified salt. 207 // Note that hashed must be the result of hashing the input message using the 208 // given hash function. salt is a random sequence of bytes whose length will be 209 // later used to verify the signature. 210 func signPSSWithSalt(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) ([]byte, error) { 211 emBits := priv.N.BitLen() - 1 212 em, err := emsaPSSEncode(hashed, emBits, salt, hash.New()) 213 if err != nil { 214 return nil, err 215 } 216 m := new(big.Int).SetBytes(em) 217 c, err := decryptAndCheck(rand, priv, m) 218 if err != nil { 219 return nil, err 220 } 221 s := make([]byte, priv.Size()) 222 return c.FillBytes(s), nil 223 } 224 225 const ( 226 // PSSSaltLengthAuto causes the salt in a PSS signature to be as large 227 // as possible when signing, and to be auto-detected when verifying. 228 PSSSaltLengthAuto = 0 229 // PSSSaltLengthEqualsHash causes the salt length to equal the length 230 // of the hash used in the signature. 231 PSSSaltLengthEqualsHash = -1 232 ) 233 234 // PSSOptions contains options for creating and verifying PSS signatures. 235 type PSSOptions struct { 236 // SaltLength controls the length of the salt used in the PSS 237 // signature. It can either be a number of bytes, or one of the special 238 // PSSSaltLength constants. 239 SaltLength int 240 241 // Hash is the hash function used to generate the message digest. If not 242 // zero, it overrides the hash function passed to SignPSS. It's required 243 // when using PrivateKey.Sign. 244 Hash crypto.Hash 245 } 246 247 // HashFunc returns opts.Hash so that PSSOptions implements crypto.SignerOpts. 248 func (opts *PSSOptions) HashFunc() crypto.Hash { 249 return opts.Hash 250 } 251 252 func (opts *PSSOptions) saltLength() int { 253 if opts == nil { 254 return PSSSaltLengthAuto 255 } 256 return opts.SaltLength 257 } 258 259 // SignPSS calculates the signature of digest using PSS. 260 // 261 // digest must be the result of hashing the input message using the given hash 262 // function. The opts argument may be nil, in which case sensible defaults are 263 // used. If opts.Hash is set, it overrides hash. 264 func SignPSS(rand io.Reader, priv *PrivateKey, hash crypto.Hash, digest []byte, opts *PSSOptions) ([]byte, error) { 265 if opts != nil && opts.Hash != 0 { 266 hash = opts.Hash 267 } 268 269 saltLength := opts.saltLength() 270 switch saltLength { 271 case PSSSaltLengthAuto: 272 saltLength = priv.Size() - 2 - hash.Size() 273 case PSSSaltLengthEqualsHash: 274 saltLength = hash.Size() 275 } 276 277 salt := make([]byte, saltLength) 278 if _, err := io.ReadFull(rand, salt); err != nil { 279 return nil, err 280 } 281 return signPSSWithSalt(rand, priv, hash, digest, salt) 282 } 283 284 // VerifyPSS verifies a PSS signature. 285 // 286 // A valid signature is indicated by returning a nil error. digest must be the 287 // result of hashing the input message using the given hash function. The opts 288 // argument may be nil, in which case sensible defaults are used. opts.Hash is 289 // ignored. 290 func VerifyPSS(pub *PublicKey, hash crypto.Hash, digest []byte, sig []byte, opts *PSSOptions) error { 291 if len(sig) != pub.Size() { 292 return ErrVerification 293 } 294 s := new(big.Int).SetBytes(sig) 295 m := encrypt(new(big.Int), pub, s) 296 emBits := pub.N.BitLen() - 1 297 emLen := (emBits + 7) / 8 298 if m.BitLen() > emLen*8 { 299 return ErrVerification 300 } 301 em := m.FillBytes(make([]byte, emLen)) 302 return emsaPSSVerify(digest, em, emBits, opts.saltLength(), hash.New()) 303 }