github.com/crewjam/saml@v0.4.14/xmlenc/pubkey.go (about)

     1  package xmlenc
     2  
import (
	"crypto/rsa"
	"crypto/x509"
	"encoding/base64"
	"fmt"
	"io"

	"github.com/beevik/etree"
)
    11  
    12  // RSA implements Encrypter and Decrypter using RSA public key encryption.
    13  //
    14  // Use function like OAEP(), or PKCS1v15() to get an instance of this type ready
    15  // to use.
    16  type RSA struct {
    17  	BlockCipher  BlockCipher
    18  	DigestMethod DigestMethod // only for OAEP
    19  
    20  	algorithm    string
    21  	keyEncrypter func(e RSA, pubKey *rsa.PublicKey, plaintext []byte) ([]byte, error)
    22  	keyDecrypter func(e RSA, privKey *rsa.PrivateKey, ciphertext []byte) ([]byte, error)
    23  }
    24  
    25  // Algorithm returns the name of the algorithm
    26  func (e RSA) Algorithm() string {
    27  	return e.algorithm
    28  }
    29  
    30  // Encrypt implements encrypter. certificate must be a []byte containing the ASN.1 bytes
    31  // of certificate containing an RSA public key.
    32  func (e RSA) Encrypt(certificate interface{}, plaintext []byte, nonce []byte) (*etree.Element, error) {
    33  	cert, ok := certificate.(*x509.Certificate)
    34  	if !ok {
    35  		return nil, ErrIncorrectKeyType("*x.509 certificate")
    36  	}
    37  
    38  	pubKey, ok := cert.PublicKey.(*rsa.PublicKey)
    39  	if !ok {
    40  		return nil, ErrIncorrectKeyType("x.509 certificate with an RSA public key")
    41  	}
    42  
    43  	// generate a key
    44  	key := make([]byte, e.BlockCipher.KeySize())
    45  	if _, err := RandReader.Read(key); err != nil {
    46  		return nil, err
    47  	}
    48  
    49  	keyInfoEl := etree.NewElement("ds:KeyInfo")
    50  	keyInfoEl.CreateAttr("xmlns:ds", "http://www.w3.org/2000/09/xmldsig#")
    51  
    52  	encryptedKey := keyInfoEl.CreateElement("xenc:EncryptedKey")
    53  	{
    54  		randBuf := make([]byte, 16)
    55  		if _, err := RandReader.Read(randBuf); err != nil {
    56  			return nil, err
    57  		}
    58  		encryptedKey.CreateAttr("Id", fmt.Sprintf("_%x", randBuf))
    59  	}
    60  	encryptedKey.CreateAttr("xmlns:xenc", "http://www.w3.org/2001/04/xmlenc#")
    61  
    62  	encryptionMethodEl := encryptedKey.CreateElement("xenc:EncryptionMethod")
    63  	encryptionMethodEl.CreateAttr("Algorithm", e.algorithm)
    64  	encryptionMethodEl.CreateAttr("xmlns:xenc", "http://www.w3.org/2001/04/xmlenc#")
    65  	if e.DigestMethod != nil {
    66  		dm := encryptionMethodEl.CreateElement("ds:DigestMethod")
    67  		dm.CreateAttr("Algorithm", e.DigestMethod.Algorithm())
    68  		dm.CreateAttr("xmlns:ds", "http://www.w3.org/2000/09/xmldsig#")
    69  	}
    70  	{
    71  		innerKeyInfoEl := encryptedKey.CreateElement("ds:KeyInfo")
    72  		x509data := innerKeyInfoEl.CreateElement("ds:X509Data")
    73  		x509data.CreateElement("ds:X509Certificate").SetText(
    74  			base64.StdEncoding.EncodeToString(cert.Raw),
    75  		)
    76  	}
    77  
    78  	buf, err := e.keyEncrypter(e, pubKey, key)
    79  	if err != nil {
    80  		return nil, err
    81  	}
    82  
    83  	cd := encryptedKey.CreateElement("xenc:CipherData")
    84  	cd.CreateAttr("xmlns:xenc", "http://www.w3.org/2001/04/xmlenc#")
    85  	cd.CreateElement("xenc:CipherValue").SetText(base64.StdEncoding.EncodeToString(buf))
    86  	encryptedDataEl, err := e.BlockCipher.Encrypt(key, plaintext, nonce)
    87  	if err != nil {
    88  		return nil, err
    89  	}
    90  	encryptedDataEl.InsertChildAt(encryptedDataEl.FindElement("./CipherData").Index(), keyInfoEl)
    91  
    92  	return encryptedDataEl, nil
    93  }
    94  
    95  // Decrypt implements Decryptor. `key` must be an *rsa.PrivateKey.
    96  func (e RSA) Decrypt(key interface{}, ciphertextEl *etree.Element) ([]byte, error) {
    97  	rsaKey, err := validateRSAKeyIfPresent(key, ciphertextEl)
    98  	if err != nil {
    99  		return nil, err
   100  	}
   101  
   102  	ciphertext, err := getCiphertext(ciphertextEl)
   103  	if err != nil {
   104  		return nil, err
   105  	}
   106  
   107  	{
   108  		digestMethodEl := ciphertextEl.FindElement("./EncryptionMethod/DigestMethod")
   109  		if digestMethodEl == nil {
   110  			e.DigestMethod = SHA1
   111  		} else {
   112  			hashAlgorithmStr := digestMethodEl.SelectAttrValue("Algorithm", "")
   113  			digestMethod, ok := digestMethods[hashAlgorithmStr]
   114  			if !ok {
   115  				return nil, ErrAlgorithmNotImplemented(hashAlgorithmStr)
   116  			}
   117  			e.DigestMethod = digestMethod
   118  		}
   119  	}
   120  
   121  	return e.keyDecrypter(e, rsaKey, ciphertext)
   122  }
   123  
   124  // OAEP returns a version of RSA that implements RSA in OAEP-MGF1P mode. By default
   125  // the block cipher used is AES-256 CBC and the digest method is SHA-256. You can
   126  // specify other ciphers and digest methods by assigning to BlockCipher or
   127  // DigestMethod.
   128  func OAEP() RSA {
   129  	return RSA{
   130  		BlockCipher:  AES256CBC,
   131  		DigestMethod: SHA256,
   132  		algorithm:    "http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p",
   133  		keyEncrypter: func(e RSA, pubKey *rsa.PublicKey, plaintext []byte) ([]byte, error) {
   134  			return rsa.EncryptOAEP(e.DigestMethod.Hash(), RandReader, pubKey, plaintext, nil)
   135  		},
   136  		keyDecrypter: func(e RSA, privKey *rsa.PrivateKey, ciphertext []byte) ([]byte, error) {
   137  			return rsa.DecryptOAEP(e.DigestMethod.Hash(), RandReader, privKey, ciphertext, nil)
   138  		},
   139  	}
   140  }
   141  
   142  // PKCS1v15 returns a version of RSA that implements RSA in PKCS1v15 mode. By default
   143  // the block cipher used is AES-256 CBC. The DigestMethod field is ignored because PKCS1v15
   144  // does not use a digest function.
   145  func PKCS1v15() RSA {
   146  	return RSA{
   147  		BlockCipher:  AES256CBC,
   148  		DigestMethod: nil,
   149  		algorithm:    "http://www.w3.org/2001/04/xmlenc#rsa-1_5",
   150  		keyEncrypter: func(e RSA, pubKey *rsa.PublicKey, plaintext []byte) ([]byte, error) {
   151  			return rsa.EncryptPKCS1v15(RandReader, pubKey, plaintext)
   152  		},
   153  		keyDecrypter: func(e RSA, privKey *rsa.PrivateKey, ciphertext []byte) ([]byte, error) {
   154  			return rsa.DecryptPKCS1v15(RandReader, privKey, ciphertext)
   155  		},
   156  	}
   157  }
   158  
   159  func init() {
   160  	RegisterDecrypter(OAEP())
   161  	RegisterDecrypter(PKCS1v15())
   162  }