github.com/mattn/go@v0.0.0-20171011075504-07f7db3ea99f/src/crypto/rsa/pss.go (about) 1 // Copyright 2013 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package rsa 6 7 // This file implements the PSS signature scheme [1]. 8 // 9 // [1] https://www.emc.com/collateral/white-papers/h11300-pkcs-1v2-2-rsa-cryptography-standard-wp.pdf 10 11 import ( 12 "bytes" 13 "crypto" 14 "errors" 15 "hash" 16 "io" 17 "math/big" 18 ) 19 20 func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byte, error) { 21 // See [1], section 9.1.1 22 hLen := hash.Size() 23 sLen := len(salt) 24 emLen := (emBits + 7) / 8 25 26 // 1. If the length of M is greater than the input limitation for the 27 // hash function (2^61 - 1 octets for SHA-1), output "message too 28 // long" and stop. 29 // 30 // 2. Let mHash = Hash(M), an octet string of length hLen. 31 32 if len(mHash) != hLen { 33 return nil, errors.New("crypto/rsa: input must be hashed message") 34 } 35 36 // 3. If emLen < hLen + sLen + 2, output "encoding error" and stop. 37 38 if emLen < hLen+sLen+2 { 39 return nil, errors.New("crypto/rsa: encoding error") 40 } 41 42 em := make([]byte, emLen) 43 db := em[:emLen-sLen-hLen-2+1+sLen] 44 h := em[emLen-sLen-hLen-2+1+sLen : emLen-1] 45 46 // 4. Generate a random octet string salt of length sLen; if sLen = 0, 47 // then salt is the empty string. 48 // 49 // 5. Let 50 // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt; 51 // 52 // M' is an octet string of length 8 + hLen + sLen with eight 53 // initial zero octets. 54 // 55 // 6. Let H = Hash(M'), an octet string of length hLen. 56 57 var prefix [8]byte 58 59 hash.Write(prefix[:]) 60 hash.Write(mHash) 61 hash.Write(salt) 62 63 h = hash.Sum(h[:0]) 64 hash.Reset() 65 66 // 7. Generate an octet string PS consisting of emLen - sLen - hLen - 2 67 // zero octets. The length of PS may be 0. 68 // 69 // 8. Let DB = PS || 0x01 || salt; DB is an octet string of length 70 // emLen - hLen - 1. 71 72 db[emLen-sLen-hLen-2] = 0x01 73 copy(db[emLen-sLen-hLen-1:], salt) 74 75 // 9. Let dbMask = MGF(H, emLen - hLen - 1). 76 // 77 // 10. Let maskedDB = DB \xor dbMask. 78 79 mgf1XOR(db, hash, h) 80 81 // 11. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in 82 // maskedDB to zero. 83 84 db[0] &= (0xFF >> uint(8*emLen-emBits)) 85 86 // 12. Let EM = maskedDB || H || 0xbc. 87 em[emLen-1] = 0xBC 88 89 // 13. Output EM. 90 return em, nil 91 } 92 93 func emsaPSSVerify(mHash, em []byte, emBits, sLen int, hash hash.Hash) error { 94 // 1. If the length of M is greater than the input limitation for the 95 // hash function (2^61 - 1 octets for SHA-1), output "inconsistent" 96 // and stop. 97 // 98 // 2. Let mHash = Hash(M), an octet string of length hLen. 99 hLen := hash.Size() 100 if hLen != len(mHash) { 101 return ErrVerification 102 } 103 104 // 3. If emLen < hLen + sLen + 2, output "inconsistent" and stop. 105 emLen := (emBits + 7) / 8 106 if emLen < hLen+sLen+2 { 107 return ErrVerification 108 } 109 110 // 4. If the rightmost octet of EM does not have hexadecimal value 111 // 0xbc, output "inconsistent" and stop. 112 if em[len(em)-1] != 0xBC { 113 return ErrVerification 114 } 115 116 // 5. Let maskedDB be the leftmost emLen - hLen - 1 octets of EM, and 117 // let H be the next hLen octets. 118 db := em[:emLen-hLen-1] 119 h := em[emLen-hLen-1 : len(em)-1] 120 121 // 6. If the leftmost 8 * emLen - emBits bits of the leftmost octet in 122 // maskedDB are not all equal to zero, output "inconsistent" and 123 // stop. 124 if em[0]&(0xFF<<uint(8-(8*emLen-emBits))) != 0 { 125 return ErrVerification 126 } 127 128 // 7. Let dbMask = MGF(H, emLen - hLen - 1). 129 // 130 // 8. Let DB = maskedDB \xor dbMask. 131 mgf1XOR(db, hash, h) 132 133 // 9. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB 134 // to zero. 135 db[0] &= (0xFF >> uint(8*emLen-emBits)) 136 137 if sLen == PSSSaltLengthAuto { 138 FindSaltLength: 139 for sLen = emLen - (hLen + 2); sLen >= 0; sLen-- { 140 switch db[emLen-hLen-sLen-2] { 141 case 1: 142 break FindSaltLength 143 case 0: 144 continue 145 default: 146 return ErrVerification 147 } 148 } 149 if sLen < 0 { 150 return ErrVerification 151 } 152 } else { 153 // 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not zero 154 // or if the octet at position emLen - hLen - sLen - 1 (the leftmost 155 // position is "position 1") does not have hexadecimal value 0x01, 156 // output "inconsistent" and stop. 157 for _, e := range db[:emLen-hLen-sLen-2] { 158 if e != 0x00 { 159 return ErrVerification 160 } 161 } 162 if db[emLen-hLen-sLen-2] != 0x01 { 163 return ErrVerification 164 } 165 } 166 167 // 11. Let salt be the last sLen octets of DB. 168 salt := db[len(db)-sLen:] 169 170 // 12. Let 171 // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ; 172 // M' is an octet string of length 8 + hLen + sLen with eight 173 // initial zero octets. 174 // 175 // 13. Let H' = Hash(M'), an octet string of length hLen. 176 var prefix [8]byte 177 hash.Write(prefix[:]) 178 hash.Write(mHash) 179 hash.Write(salt) 180 181 h0 := hash.Sum(nil) 182 183 // 14. If H = H', output "consistent." Otherwise, output "inconsistent." 184 if !bytes.Equal(h0, h) { 185 return ErrVerification 186 } 187 return nil 188 } 189 190 // signPSSWithSalt calculates the signature of hashed using PSS [1] with specified salt. 191 // Note that hashed must be the result of hashing the input message using the 192 // given hash function. salt is a random sequence of bytes whose length will be 193 // later used to verify the signature. 194 func signPSSWithSalt(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) (s []byte, err error) { 195 nBits := priv.N.BitLen() 196 em, err := emsaPSSEncode(hashed, nBits-1, salt, hash.New()) 197 if err != nil { 198 return 199 } 200 m := new(big.Int).SetBytes(em) 201 c, err := decryptAndCheck(rand, priv, m) 202 if err != nil { 203 return 204 } 205 s = make([]byte, (nBits+7)/8) 206 copyWithLeftPad(s, c.Bytes()) 207 return 208 } 209 210 const ( 211 // PSSSaltLengthAuto causes the salt in a PSS signature to be as large 212 // as possible when signing, and to be auto-detected when verifying. 213 PSSSaltLengthAuto = 0 214 // PSSSaltLengthEqualsHash causes the salt length to equal the length 215 // of the hash used in the signature. 216 PSSSaltLengthEqualsHash = -1 217 ) 218 219 // PSSOptions contains options for creating and verifying PSS signatures. 220 type PSSOptions struct { 221 // SaltLength controls the length of the salt used in the PSS 222 // signature. It can either be a number of bytes, or one of the special 223 // PSSSaltLength constants. 224 SaltLength int 225 226 // Hash, if not zero, overrides the hash function passed to SignPSS. 227 // This is the only way to specify the hash function when using the 228 // crypto.Signer interface. 229 Hash crypto.Hash 230 } 231 232 // HashFunc returns pssOpts.Hash so that PSSOptions implements 233 // crypto.SignerOpts. 234 func (pssOpts *PSSOptions) HashFunc() crypto.Hash { 235 return pssOpts.Hash 236 } 237 238 func (opts *PSSOptions) saltLength() int { 239 if opts == nil { 240 return PSSSaltLengthAuto 241 } 242 return opts.SaltLength 243 } 244 245 // SignPSS calculates the signature of hashed using RSASSA-PSS [1]. 246 // Note that hashed must be the result of hashing the input message using the 247 // given hash function. The opts argument may be nil, in which case sensible 248 // defaults are used. 249 func SignPSS(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed []byte, opts *PSSOptions) ([]byte, error) { 250 saltLength := opts.saltLength() 251 switch saltLength { 252 case PSSSaltLengthAuto: 253 saltLength = (priv.N.BitLen()+7)/8 - 2 - hash.Size() 254 case PSSSaltLengthEqualsHash: 255 saltLength = hash.Size() 256 } 257 258 if opts != nil && opts.Hash != 0 { 259 hash = opts.Hash 260 } 261 262 salt := make([]byte, saltLength) 263 if _, err := io.ReadFull(rand, salt); err != nil { 264 return nil, err 265 } 266 return signPSSWithSalt(rand, priv, hash, hashed, salt) 267 } 268 269 // VerifyPSS verifies a PSS signature. 270 // hashed is the result of hashing the input message using the given hash 271 // function and sig is the signature. A valid signature is indicated by 272 // returning a nil error. The opts argument may be nil, in which case sensible 273 // defaults are used. 274 func VerifyPSS(pub *PublicKey, hash crypto.Hash, hashed []byte, sig []byte, opts *PSSOptions) error { 275 return verifyPSS(pub, hash, hashed, sig, opts.saltLength()) 276 } 277 278 // verifyPSS verifies a PSS signature with the given salt length. 279 func verifyPSS(pub *PublicKey, hash crypto.Hash, hashed []byte, sig []byte, saltLen int) error { 280 nBits := pub.N.BitLen() 281 if len(sig) != (nBits+7)/8 { 282 return ErrVerification 283 } 284 s := new(big.Int).SetBytes(sig) 285 m := encrypt(new(big.Int), pub, s) 286 emBits := nBits - 1 287 emLen := (emBits + 7) / 8 288 if emLen < len(m.Bytes()) { 289 return ErrVerification 290 } 291 em := make([]byte, emLen) 292 copyWithLeftPad(em, m.Bytes()) 293 if saltLen == PSSSaltLengthEqualsHash { 294 saltLen = hash.Size() 295 } 296 return emsaPSSVerify(hashed, em, emBits, saltLen, hash.New()) 297 }