github.com/crewjam/saml@v0.4.14/xmlenc/decrypt.go (about)

     1  package xmlenc
     2  
     3  import (
     4  
     5  	// nolint: gas
     6  	"crypto/rsa"
     7  	"crypto/x509"
     8  	"encoding/base64"
     9  	"encoding/pem"
    10  	"errors"
    11  	"fmt"
    12  
    13  	"strings"
    14  
    15  	"github.com/beevik/etree"
    16  )
    17  
    18  // ErrAlgorithmNotImplemented is returned when encryption used is not
    19  // supported.
    20  type ErrAlgorithmNotImplemented string
    21  
    22  func (e ErrAlgorithmNotImplemented) Error() string {
    23  	return "algorithm is not implemented: " + string(e)
    24  }
    25  
    26  // ErrCannotFindRequiredElement is returned by Decrypt when a required
    27  // element cannot be found.
    28  type ErrCannotFindRequiredElement string
    29  
    30  func (e ErrCannotFindRequiredElement) Error() string {
    31  	return "cannot find required element: " + string(e)
    32  }
    33  
    34  // ErrIncorrectTag is returned when Decrypt is passed an element which
    35  // is neither an EncryptedType nor an EncryptedKey
    36  var ErrIncorrectTag = fmt.Errorf("tag must be an EncryptedType or EncryptedKey")
    37  
    38  // ErrIncorrectKeyLength is returned when the fixed length key is not
    39  // of the required length.
    40  type ErrIncorrectKeyLength int
    41  
    42  func (e ErrIncorrectKeyLength) Error() string {
    43  	return fmt.Sprintf("expected key to be %d bytes", int(e))
    44  }
    45  
    46  // ErrIncorrectKeyType is returned when the key is not the correct type
    47  type ErrIncorrectKeyType string
    48  
    49  func (e ErrIncorrectKeyType) Error() string {
    50  	return fmt.Sprintf("expected key to be %s", string(e))
    51  }
    52  
    53  // Decrypt decrypts the encrypted data using the provided key. If the
    54  // data are encrypted using AES or 3DEC, then the key should be a []byte.
    55  // If the data are encrypted with PKCS1v15 or RSA-OAEP-MGF1P then key should
    56  // be a *rsa.PrivateKey.
    57  func Decrypt(key interface{}, ciphertextEl *etree.Element) ([]byte, error) {
    58  	encryptionMethodEl := ciphertextEl.FindElement("./EncryptionMethod")
    59  	if encryptionMethodEl == nil {
    60  		return nil, ErrCannotFindRequiredElement("EncryptionMethod")
    61  	}
    62  	algorithm := encryptionMethodEl.SelectAttrValue("Algorithm", "")
    63  	decrypter, ok := decrypters[algorithm]
    64  	if !ok {
    65  		return nil, ErrAlgorithmNotImplemented(algorithm)
    66  	}
    67  	return decrypter.Decrypt(key, ciphertextEl)
    68  }
    69  
    70  func getCiphertext(encryptedKey *etree.Element) ([]byte, error) {
    71  	ciphertextEl := encryptedKey.FindElement("./CipherData/CipherValue")
    72  	if ciphertextEl == nil {
    73  		return nil, fmt.Errorf("cannot find CipherData element containing a CipherValue element")
    74  	}
    75  	ciphertext, err := base64.StdEncoding.DecodeString(strings.TrimSpace(ciphertextEl.Text()))
    76  	if err != nil {
    77  		return nil, err
    78  	}
    79  	return ciphertext, nil
    80  }
    81  
    82  func validateRSAKeyIfPresent(key interface{}, encryptedKey *etree.Element) (*rsa.PrivateKey, error) {
    83  	rsaKey, ok := key.(*rsa.PrivateKey)
    84  	if !ok {
    85  		return nil, errors.New("expected key to be a *rsa.PrivateKey")
    86  	}
    87  
    88  	// extract and verify that the public key matches the certificate
    89  	// this section is included to either let the service know up front
    90  	// if the key will work, or let the service provider know which key
    91  	// to use to decrypt the message. Either way, verification is not
    92  	// security-critical.
    93  	//nolint:revive,staticcheck // Keep the later empty branch so that we know to address this at a later date.
    94  	if el := encryptedKey.FindElement("./KeyInfo/X509Data/X509Certificate"); el != nil {
    95  		certPEMbuf := el.Text()
    96  		certPEMbuf = "-----BEGIN CERTIFICATE-----\n" + certPEMbuf + "\n-----END CERTIFICATE-----\n"
    97  		certPEM, _ := pem.Decode([]byte(certPEMbuf))
    98  		if certPEM == nil {
    99  			return nil, fmt.Errorf("invalid certificate")
   100  		}
   101  		cert, err := x509.ParseCertificate(certPEM.Bytes)
   102  		if err != nil {
   103  			return nil, err
   104  		}
   105  		pubKey, ok := cert.PublicKey.(*rsa.PublicKey)
   106  		if !ok {
   107  			return nil, fmt.Errorf("expected certificate to be an *rsa.PublicKey")
   108  		}
   109  		if rsaKey.N.Cmp(pubKey.N) != 0 || rsaKey.E != pubKey.E {
   110  			return nil, fmt.Errorf("certificate does not match provided key")
   111  		}
   112  	} else if el = encryptedKey.FindElement("./KeyInfo/X509Data/X509IssuerSerial"); el != nil {
   113  		// TODO: determine how to validate the issuer serial information
   114  	}
   115  	return rsaKey, nil
   116  }