github.com/JimmyHuang454/JLS-go@v0.0.0-20230831150107-90d536585ba0/boring/rsa.go (about)

     1  // Copyright 2017 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build boringcrypto && linux && (amd64 || arm64) && !android && !cmd_go_bootstrap && !msan
     6  
     7  package boring
     8  
     9  // #include "goboringcrypto.h"
    10  import "C"
    11  import (
    12  	"crypto"
    13  	"crypto/subtle"
    14  	"errors"
    15  	"hash"
    16  	"runtime"
    17  	"strconv"
    18  	"unsafe"
    19  )
    20  
    21  func GenerateKeyRSA(bits int) (N, E, D, P, Q, Dp, Dq, Qinv BigInt, err error) {
    22  	bad := func(e error) (N, E, D, P, Q, Dp, Dq, Qinv BigInt, err error) {
    23  		return nil, nil, nil, nil, nil, nil, nil, nil, e
    24  	}
    25  
    26  	key := C._goboringcrypto_RSA_new()
    27  	if key == nil {
    28  		return bad(fail("RSA_new"))
    29  	}
    30  	defer C._goboringcrypto_RSA_free(key)
    31  
    32  	if C._goboringcrypto_RSA_generate_key_fips(key, C.int(bits), nil) == 0 {
    33  		return bad(fail("RSA_generate_key_fips"))
    34  	}
    35  
    36  	var n, e, d, p, q, dp, dq, qinv *C.GO_BIGNUM
    37  	C._goboringcrypto_RSA_get0_key(key, &n, &e, &d)
    38  	C._goboringcrypto_RSA_get0_factors(key, &p, &q)
    39  	C._goboringcrypto_RSA_get0_crt_params(key, &dp, &dq, &qinv)
    40  	return bnToBig(n), bnToBig(e), bnToBig(d), bnToBig(p), bnToBig(q), bnToBig(dp), bnToBig(dq), bnToBig(qinv), nil
    41  }
    42  
    43  type PublicKeyRSA struct {
    44  	// _key MUST NOT be accessed directly. Instead, use the withKey method.
    45  	_key *C.GO_RSA
    46  }
    47  
    48  func NewPublicKeyRSA(N, E BigInt) (*PublicKeyRSA, error) {
    49  	key := C._goboringcrypto_RSA_new()
    50  	if key == nil {
    51  		return nil, fail("RSA_new")
    52  	}
    53  	if !bigToBn(&key.n, N) ||
    54  		!bigToBn(&key.e, E) {
    55  		return nil, fail("BN_bin2bn")
    56  	}
    57  	k := &PublicKeyRSA{_key: key}
    58  	runtime.SetFinalizer(k, (*PublicKeyRSA).finalize)
    59  	return k, nil
    60  }
    61  
    62  func (k *PublicKeyRSA) finalize() {
    63  	C._goboringcrypto_RSA_free(k._key)
    64  }
    65  
    66  func (k *PublicKeyRSA) withKey(f func(*C.GO_RSA) C.int) C.int {
    67  	// Because of the finalizer, any time _key is passed to cgo, that call must
    68  	// be followed by a call to runtime.KeepAlive, to make sure k is not
    69  	// collected (and finalized) before the cgo call returns.
    70  	defer runtime.KeepAlive(k)
    71  	return f(k._key)
    72  }
    73  
    74  type PrivateKeyRSA struct {
    75  	// _key MUST NOT be accessed directly. Instead, use the withKey method.
    76  	_key *C.GO_RSA
    77  }
    78  
    79  func NewPrivateKeyRSA(N, E, D, P, Q, Dp, Dq, Qinv BigInt) (*PrivateKeyRSA, error) {
    80  	key := C._goboringcrypto_RSA_new()
    81  	if key == nil {
    82  		return nil, fail("RSA_new")
    83  	}
    84  	if !bigToBn(&key.n, N) ||
    85  		!bigToBn(&key.e, E) ||
    86  		!bigToBn(&key.d, D) ||
    87  		!bigToBn(&key.p, P) ||
    88  		!bigToBn(&key.q, Q) ||
    89  		!bigToBn(&key.dmp1, Dp) ||
    90  		!bigToBn(&key.dmq1, Dq) ||
    91  		!bigToBn(&key.iqmp, Qinv) {
    92  		return nil, fail("BN_bin2bn")
    93  	}
    94  	k := &PrivateKeyRSA{_key: key}
    95  	runtime.SetFinalizer(k, (*PrivateKeyRSA).finalize)
    96  	return k, nil
    97  }
    98  
    99  func (k *PrivateKeyRSA) finalize() {
   100  	C._goboringcrypto_RSA_free(k._key)
   101  }
   102  
   103  func (k *PrivateKeyRSA) withKey(f func(*C.GO_RSA) C.int) C.int {
   104  	// Because of the finalizer, any time _key is passed to cgo, that call must
   105  	// be followed by a call to runtime.KeepAlive, to make sure k is not
   106  	// collected (and finalized) before the cgo call returns.
   107  	defer runtime.KeepAlive(k)
   108  	return f(k._key)
   109  }
   110  
   111  func setupRSA(withKey func(func(*C.GO_RSA) C.int) C.int,
   112  	padding C.int, h, mgfHash hash.Hash, label []byte, saltLen int, ch crypto.Hash,
   113  	init func(*C.GO_EVP_PKEY_CTX) C.int) (pkey *C.GO_EVP_PKEY, ctx *C.GO_EVP_PKEY_CTX, err error) {
   114  	defer func() {
   115  		if err != nil {
   116  			if pkey != nil {
   117  				C._goboringcrypto_EVP_PKEY_free(pkey)
   118  				pkey = nil
   119  			}
   120  			if ctx != nil {
   121  				C._goboringcrypto_EVP_PKEY_CTX_free(ctx)
   122  				ctx = nil
   123  			}
   124  		}
   125  	}()
   126  
   127  	pkey = C._goboringcrypto_EVP_PKEY_new()
   128  	if pkey == nil {
   129  		return nil, nil, fail("EVP_PKEY_new")
   130  	}
   131  	if withKey(func(key *C.GO_RSA) C.int {
   132  		return C._goboringcrypto_EVP_PKEY_set1_RSA(pkey, key)
   133  	}) == 0 {
   134  		return nil, nil, fail("EVP_PKEY_set1_RSA")
   135  	}
   136  	ctx = C._goboringcrypto_EVP_PKEY_CTX_new(pkey, nil)
   137  	if ctx == nil {
   138  		return nil, nil, fail("EVP_PKEY_CTX_new")
   139  	}
   140  	if init(ctx) == 0 {
   141  		return nil, nil, fail("EVP_PKEY_operation_init")
   142  	}
   143  	if C._goboringcrypto_EVP_PKEY_CTX_set_rsa_padding(ctx, padding) == 0 {
   144  		return nil, nil, fail("EVP_PKEY_CTX_set_rsa_padding")
   145  	}
   146  	if padding == C.GO_RSA_PKCS1_OAEP_PADDING {
   147  		md := hashToMD(h)
   148  		if md == nil {
   149  			return nil, nil, errors.New("crypto/rsa: unsupported hash function")
   150  		}
   151  		mgfMD := hashToMD(mgfHash)
   152  		if mgfMD == nil {
   153  			return nil, nil, errors.New("crypto/rsa: unsupported hash function")
   154  		}
   155  		if C._goboringcrypto_EVP_PKEY_CTX_set_rsa_oaep_md(ctx, md) == 0 {
   156  			return nil, nil, fail("EVP_PKEY_set_rsa_oaep_md")
   157  		}
   158  		if C._goboringcrypto_EVP_PKEY_CTX_set_rsa_mgf1_md(ctx, mgfMD) == 0 {
   159  			return nil, nil, fail("EVP_PKEY_set_rsa_mgf1_md")
   160  		}
   161  		// ctx takes ownership of label, so malloc a copy for BoringCrypto to free.
   162  		clabel := (*C.uint8_t)(C._goboringcrypto_OPENSSL_malloc(C.size_t(len(label))))
   163  		if clabel == nil {
   164  			return nil, nil, fail("OPENSSL_malloc")
   165  		}
   166  		copy((*[1 << 30]byte)(unsafe.Pointer(clabel))[:len(label)], label)
   167  		if C._goboringcrypto_EVP_PKEY_CTX_set0_rsa_oaep_label(ctx, clabel, C.size_t(len(label))) == 0 {
   168  			return nil, nil, fail("EVP_PKEY_CTX_set0_rsa_oaep_label")
   169  		}
   170  	}
   171  	if padding == C.GO_RSA_PKCS1_PSS_PADDING {
   172  		if saltLen != 0 {
   173  			if C._goboringcrypto_EVP_PKEY_CTX_set_rsa_pss_saltlen(ctx, C.int(saltLen)) == 0 {
   174  				return nil, nil, fail("EVP_PKEY_set_rsa_pss_saltlen")
   175  			}
   176  		}
   177  		md := cryptoHashToMD(ch)
   178  		if md == nil {
   179  			return nil, nil, errors.New("crypto/rsa: unsupported hash function")
   180  		}
   181  		if C._goboringcrypto_EVP_PKEY_CTX_set_rsa_mgf1_md(ctx, md) == 0 {
   182  			return nil, nil, fail("EVP_PKEY_set_rsa_mgf1_md")
   183  		}
   184  	}
   185  
   186  	return pkey, ctx, nil
   187  }
   188  
   189  func cryptRSA(withKey func(func(*C.GO_RSA) C.int) C.int,
   190  	padding C.int, h, mgfHash hash.Hash, label []byte, saltLen int, ch crypto.Hash,
   191  	init func(*C.GO_EVP_PKEY_CTX) C.int,
   192  	crypt func(*C.GO_EVP_PKEY_CTX, *C.uint8_t, *C.size_t, *C.uint8_t, C.size_t) C.int,
   193  	in []byte) ([]byte, error) {
   194  
   195  	pkey, ctx, err := setupRSA(withKey, padding, h, mgfHash, label, saltLen, ch, init)
   196  	if err != nil {
   197  		return nil, err
   198  	}
   199  	defer C._goboringcrypto_EVP_PKEY_free(pkey)
   200  	defer C._goboringcrypto_EVP_PKEY_CTX_free(ctx)
   201  
   202  	var outLen C.size_t
   203  	if crypt(ctx, nil, &outLen, base(in), C.size_t(len(in))) == 0 {
   204  		return nil, fail("EVP_PKEY_decrypt/encrypt")
   205  	}
   206  	out := make([]byte, outLen)
   207  	if crypt(ctx, base(out), &outLen, base(in), C.size_t(len(in))) == 0 {
   208  		return nil, fail("EVP_PKEY_decrypt/encrypt")
   209  	}
   210  	return out[:outLen], nil
   211  }
   212  
   213  func DecryptRSAOAEP(h, mgfHash hash.Hash, priv *PrivateKeyRSA, ciphertext, label []byte) ([]byte, error) {
   214  	return cryptRSA(priv.withKey, C.GO_RSA_PKCS1_OAEP_PADDING, h, mgfHash, label, 0, 0, decryptInit, decrypt, ciphertext)
   215  }
   216  
   217  func EncryptRSAOAEP(h, mgfHash hash.Hash, pub *PublicKeyRSA, msg, label []byte) ([]byte, error) {
   218  	return cryptRSA(pub.withKey, C.GO_RSA_PKCS1_OAEP_PADDING, h, mgfHash, label, 0, 0, encryptInit, encrypt, msg)
   219  }
   220  
   221  func DecryptRSAPKCS1(priv *PrivateKeyRSA, ciphertext []byte) ([]byte, error) {
   222  	return cryptRSA(priv.withKey, C.GO_RSA_PKCS1_PADDING, nil, nil, nil, 0, 0, decryptInit, decrypt, ciphertext)
   223  }
   224  
   225  func EncryptRSAPKCS1(pub *PublicKeyRSA, msg []byte) ([]byte, error) {
   226  	return cryptRSA(pub.withKey, C.GO_RSA_PKCS1_PADDING, nil, nil, nil, 0, 0, encryptInit, encrypt, msg)
   227  }
   228  
   229  func DecryptRSANoPadding(priv *PrivateKeyRSA, ciphertext []byte) ([]byte, error) {
   230  	return cryptRSA(priv.withKey, C.GO_RSA_NO_PADDING, nil, nil, nil, 0, 0, decryptInit, decrypt, ciphertext)
   231  }
   232  
   233  func EncryptRSANoPadding(pub *PublicKeyRSA, msg []byte) ([]byte, error) {
   234  	return cryptRSA(pub.withKey, C.GO_RSA_NO_PADDING, nil, nil, nil, 0, 0, encryptInit, encrypt, msg)
   235  }
   236  
// These thin wrappers exist because cgo functions cannot be used directly as function values.
   238  
   239  func decryptInit(ctx *C.GO_EVP_PKEY_CTX) C.int {
   240  	return C._goboringcrypto_EVP_PKEY_decrypt_init(ctx)
   241  }
   242  
   243  func decrypt(ctx *C.GO_EVP_PKEY_CTX, out *C.uint8_t, outLen *C.size_t, in *C.uint8_t, inLen C.size_t) C.int {
   244  	return C._goboringcrypto_EVP_PKEY_decrypt(ctx, out, outLen, in, inLen)
   245  }
   246  
   247  func encryptInit(ctx *C.GO_EVP_PKEY_CTX) C.int {
   248  	return C._goboringcrypto_EVP_PKEY_encrypt_init(ctx)
   249  }
   250  
   251  func encrypt(ctx *C.GO_EVP_PKEY_CTX, out *C.uint8_t, outLen *C.size_t, in *C.uint8_t, inLen C.size_t) C.int {
   252  	return C._goboringcrypto_EVP_PKEY_encrypt(ctx, out, outLen, in, inLen)
   253  }
   254  
   255  var invalidSaltLenErr = errors.New("crypto/rsa: PSSOptions.SaltLength cannot be negative")
   256  
   257  func SignRSAPSS(priv *PrivateKeyRSA, h crypto.Hash, hashed []byte, saltLen int) ([]byte, error) {
   258  	md := cryptoHashToMD(h)
   259  	if md == nil {
   260  		return nil, errors.New("crypto/rsa: unsupported hash function")
   261  	}
   262  
   263  	// A salt length of -2 is valid in BoringSSL, but not in crypto/rsa, so reject
   264  	// it, and lengths < -2, before we convert to the BoringSSL sentinel values.
   265  	if saltLen <= -2 {
   266  		return nil, invalidSaltLenErr
   267  	}
   268  
   269  	// BoringSSL uses sentinel salt length values like we do, but the values don't
   270  	// fully match what we use. We both use -1 for salt length equal to hash length,
   271  	// but BoringSSL uses -2 to mean maximal size where we use 0. In the latter
   272  	// case convert to the BoringSSL version.
   273  	if saltLen == 0 {
   274  		saltLen = -2
   275  	}
   276  
   277  	var out []byte
   278  	var outLen C.size_t
   279  	if priv.withKey(func(key *C.GO_RSA) C.int {
   280  		out = make([]byte, C._goboringcrypto_RSA_size(key))
   281  		return C._goboringcrypto_RSA_sign_pss_mgf1(key, &outLen, base(out), C.size_t(len(out)),
   282  			base(hashed), C.size_t(len(hashed)), md, nil, C.int(saltLen))
   283  	}) == 0 {
   284  		return nil, fail("RSA_sign_pss_mgf1")
   285  	}
   286  
   287  	return out[:outLen], nil
   288  }
   289  
   290  func VerifyRSAPSS(pub *PublicKeyRSA, h crypto.Hash, hashed, sig []byte, saltLen int) error {
   291  	md := cryptoHashToMD(h)
   292  	if md == nil {
   293  		return errors.New("crypto/rsa: unsupported hash function")
   294  	}
   295  
   296  	// A salt length of -2 is valid in BoringSSL, but not in crypto/rsa, so reject
   297  	// it, and lengths < -2, before we convert to the BoringSSL sentinel values.
   298  	if saltLen <= -2 {
   299  		return invalidSaltLenErr
   300  	}
   301  
   302  	// BoringSSL uses sentinel salt length values like we do, but the values don't
   303  	// fully match what we use. We both use -1 for salt length equal to hash length,
   304  	// but BoringSSL uses -2 to mean maximal size where we use 0. In the latter
   305  	// case convert to the BoringSSL version.
   306  	if saltLen == 0 {
   307  		saltLen = -2
   308  	}
   309  
   310  	if pub.withKey(func(key *C.GO_RSA) C.int {
   311  		return C._goboringcrypto_RSA_verify_pss_mgf1(key, base(hashed), C.size_t(len(hashed)),
   312  			md, nil, C.int(saltLen), base(sig), C.size_t(len(sig)))
   313  	}) == 0 {
   314  		return fail("RSA_verify_pss_mgf1")
   315  	}
   316  	return nil
   317  }
   318  
   319  func SignRSAPKCS1v15(priv *PrivateKeyRSA, h crypto.Hash, hashed []byte) ([]byte, error) {
   320  	if h == 0 {
   321  		// No hashing.
   322  		var out []byte
   323  		var outLen C.size_t
   324  		if priv.withKey(func(key *C.GO_RSA) C.int {
   325  			out = make([]byte, C._goboringcrypto_RSA_size(key))
   326  			return C._goboringcrypto_RSA_sign_raw(key, &outLen, base(out), C.size_t(len(out)),
   327  				base(hashed), C.size_t(len(hashed)), C.GO_RSA_PKCS1_PADDING)
   328  		}) == 0 {
   329  			return nil, fail("RSA_sign_raw")
   330  		}
   331  		return out[:outLen], nil
   332  	}
   333  
   334  	md := cryptoHashToMD(h)
   335  	if md == nil {
   336  		return nil, errors.New("crypto/rsa: unsupported hash function: " + strconv.Itoa(int(h)))
   337  	}
   338  	nid := C._goboringcrypto_EVP_MD_type(md)
   339  	var out []byte
   340  	var outLen C.uint
   341  	if priv.withKey(func(key *C.GO_RSA) C.int {
   342  		out = make([]byte, C._goboringcrypto_RSA_size(key))
   343  		return C._goboringcrypto_RSA_sign(nid, base(hashed), C.uint(len(hashed)),
   344  			base(out), &outLen, key)
   345  	}) == 0 {
   346  		return nil, fail("RSA_sign")
   347  	}
   348  	return out[:outLen], nil
   349  }
   350  
   351  func VerifyRSAPKCS1v15(pub *PublicKeyRSA, h crypto.Hash, hashed, sig []byte) error {
   352  	if h == 0 {
   353  		var out []byte
   354  		var outLen C.size_t
   355  		if pub.withKey(func(key *C.GO_RSA) C.int {
   356  			out = make([]byte, C._goboringcrypto_RSA_size(key))
   357  			return C._goboringcrypto_RSA_verify_raw(key, &outLen, base(out),
   358  				C.size_t(len(out)), base(sig), C.size_t(len(sig)), C.GO_RSA_PKCS1_PADDING)
   359  		}) == 0 {
   360  			return fail("RSA_verify")
   361  		}
   362  		if subtle.ConstantTimeCompare(hashed, out[:outLen]) != 1 {
   363  			return fail("RSA_verify")
   364  		}
   365  		return nil
   366  	}
   367  	md := cryptoHashToMD(h)
   368  	if md == nil {
   369  		return errors.New("crypto/rsa: unsupported hash function")
   370  	}
   371  	nid := C._goboringcrypto_EVP_MD_type(md)
   372  	if pub.withKey(func(key *C.GO_RSA) C.int {
   373  		return C._goboringcrypto_RSA_verify(nid, base(hashed), C.size_t(len(hashed)),
   374  			base(sig), C.size_t(len(sig)), key)
   375  	}) == 0 {
   376  		return fail("RSA_verify")
   377  	}
   378  	return nil
   379  }