github.com/cloudflare/circl@v1.5.0/blindsign/blindrsa/internal/common/common.go (about)

     1  package common
     2  
     3  import (
     4  	"crypto"
     5  	"crypto/rand"
     6  	"crypto/rsa"
     7  	"crypto/sha256"
     8  	"crypto/sha512"
     9  	"crypto/subtle"
    10  	"errors"
    11  	"hash"
    12  	"io"
    13  	"math/big"
    14  
    15  	"github.com/cloudflare/circl/blindsign/blindrsa/internal/keys"
    16  )
    17  
    18  // ConvertHashFunction converts a crypto.Hash function to an equivalent hash.Hash type.
    19  func ConvertHashFunction(hash crypto.Hash) hash.Hash {
    20  	switch hash {
    21  	case crypto.SHA256:
    22  		return sha256.New()
    23  	case crypto.SHA384:
    24  		return sha512.New384()
    25  	case crypto.SHA512:
    26  		return sha512.New()
    27  	default:
    28  		panic(ErrUnsupportedHashFunction)
    29  	}
    30  }
    31  
    32  // EncodeMessageEMSAPSS hashes the input message and then encodes it using PSS encoding.
    33  func EncodeMessageEMSAPSS(message []byte, N *big.Int, hash hash.Hash, salt []byte) ([]byte, error) {
    34  	hash.Reset() // Ensure the hash state is cleared
    35  	hash.Write(message)
    36  	digest := hash.Sum(nil)
    37  	hash.Reset()
    38  	emBits := N.BitLen() - 1
    39  	encodedMsg, err := emsaPSSEncode(digest[:], emBits, salt, hash)
    40  	return encodedMsg, err
    41  }
    42  
    43  // GenerateBlindingFactor generates a blinding factor and its multiplicative inverse
    44  // to use for RSA blinding.
    45  func GenerateBlindingFactor(random io.Reader, N *big.Int) (*big.Int, *big.Int, error) {
    46  	randReader := random
    47  	if randReader == nil {
    48  		randReader = rand.Reader
    49  	}
    50  	r, err := rand.Int(randReader, N)
    51  	if err != nil {
    52  		return nil, nil, err
    53  	}
    54  
    55  	if r.Sign() == 0 {
    56  		r.SetInt64(1)
    57  	}
    58  	rInv := new(big.Int).ModInverse(r, N)
    59  	if rInv == nil {
    60  		return nil, nil, ErrInvalidBlind
    61  	}
    62  
    63  	return r, rInv, nil
    64  }
    65  
    66  // VerifyMessageSignature verifies the input message signature against the expected public key
    67  func VerifyMessageSignature(message, signature []byte, saltLength int, pk *keys.BigPublicKey, hash crypto.Hash) error {
    68  	h := ConvertHashFunction(hash)
    69  	h.Write(message)
    70  	digest := h.Sum(nil)
    71  
    72  	err := verifyPSS(pk, hash, digest, signature, &rsa.PSSOptions{
    73  		Hash:       hash,
    74  		SaltLength: saltLength,
    75  	})
    76  	return err
    77  }
    78  
    79  // DecryptAndCheck checks that the private key operation is consistent (fault attack detection).
    80  func DecryptAndCheck(random io.Reader, priv *keys.BigPrivateKey, c *big.Int) (m *big.Int, err error) {
    81  	m, err = decrypt(random, priv, c)
    82  	if err != nil {
    83  		return nil, err
    84  	}
    85  
    86  	// In order to defend against errors in the CRT computation, m^e is
    87  	// calculated, which should match the original ciphertext.
    88  	check := encrypt(new(big.Int), priv.Pk.N, priv.Pk.E, m)
    89  	if c.Cmp(check) != 0 {
    90  		return nil, errors.New("rsa: internal error")
    91  	}
    92  	return m, nil
    93  }
    94  
    95  // VerifyBlindSignature verifies the signature of the hashed and encoded message against the input public key.
    96  func VerifyBlindSignature(pub *keys.BigPublicKey, hashed, sig []byte) error {
    97  	m := new(big.Int).SetBytes(hashed)
    98  	bigSig := new(big.Int).SetBytes(sig)
    99  
   100  	c := encrypt(new(big.Int), pub.N, pub.E, bigSig)
   101  	if subtle.ConstantTimeCompare(m.Bytes(), c.Bytes()) == 1 {
   102  		return nil
   103  	} else {
   104  		return rsa.ErrVerification
   105  	}
   106  }
   107  
   108  func saltLength(opts *rsa.PSSOptions) int {
   109  	if opts == nil {
   110  		return rsa.PSSSaltLengthAuto
   111  	}
   112  	return opts.SaltLength
   113  }
   114  
   115  func verifyPSS(pub *keys.BigPublicKey, hash crypto.Hash, digest []byte, sig []byte, opts *rsa.PSSOptions) error {
   116  	if len(sig) != pub.Size() {
   117  		return rsa.ErrVerification
   118  	}
   119  	s := new(big.Int).SetBytes(sig)
   120  	m := encrypt(new(big.Int), pub.N, pub.E, s)
   121  	emBits := pub.N.BitLen() - 1
   122  	emLen := (emBits + 7) / 8
   123  	if m.BitLen() > emLen*8 {
   124  		return rsa.ErrVerification
   125  	}
   126  	em := m.FillBytes(make([]byte, emLen))
   127  	return emsaPSSVerify(digest, em, emBits, saltLength(opts), hash.New())
   128  }
   129  
   130  var (
   131  	// ErrInvalidVariant is the error used if the variant request does not exist.
   132  	ErrInvalidVariant = errors.New("blindsign/blindrsa: invalid variant requested")
   133  
   134  	// ErrUnexpectedSize is the error used if the size of a parameter does not match its expected value.
   135  	ErrUnexpectedSize = errors.New("blindsign/blindrsa: unexpected input size")
   136  
   137  	// ErrInvalidMessageLength is the error used if the size of a protocol message does not match its expected value.
   138  	ErrInvalidMessageLength = errors.New("blindsign/blindrsa: invalid message length")
   139  
   140  	// ErrInvalidBlind is the error used if the blind generated by the Verifier fails.
   141  	ErrInvalidBlind = errors.New("blindsign/blindrsa: invalid blind")
   142  
   143  	// ErrInvalidRandomness is the error used if caller did not provide randomness to the Blind() function.
   144  	ErrInvalidRandomness = errors.New("blindsign/blindrsa: invalid random parameter")
   145  
   146  	// ErrUnsupportedHashFunction is the error used if the specified hash is not supported.
   147  	ErrUnsupportedHashFunction = errors.New("blindsign/blindrsa: unsupported hash function")
   148  )