github.com/crewjam/saml@v0.4.14/xmlenc/cbc.go (about)

     1  package xmlenc
     2  
import (
	"crypto/aes"
	"crypto/cipher"
	"crypto/des" // nolint: gas
	"encoding/base64"
	"errors"
	"fmt"
	"io"

	"github.com/beevik/etree"
)
    13  
    14  // CBC implements Decrypter and Encrypter for block ciphers in CBC mode
    15  type CBC struct {
    16  	keySize   int
    17  	algorithm string
    18  	cipher    func([]byte) (cipher.Block, error)
    19  }
    20  
    21  // KeySize returns the length of the key required.
    22  func (e CBC) KeySize() int {
    23  	return e.keySize
    24  }
    25  
    26  // Algorithm returns the name of the algorithm, as will be found
    27  // in an xenc:EncryptionMethod element.
    28  func (e CBC) Algorithm() string {
    29  	return e.algorithm
    30  }
    31  
    32  // Encrypt encrypts plaintext with key, which should be a []byte of length KeySize().
    33  // It returns an xenc:EncryptedData element.
    34  func (e CBC) Encrypt(key interface{}, plaintext []byte, _ []byte) (*etree.Element, error) {
    35  	keyBuf, ok := key.([]byte)
    36  	if !ok {
    37  		return nil, ErrIncorrectKeyType("[]byte")
    38  	}
    39  	if len(keyBuf) != e.keySize {
    40  		return nil, ErrIncorrectKeyLength(e.keySize)
    41  	}
    42  
    43  	block, err := e.cipher(keyBuf)
    44  	if err != nil {
    45  		return nil, err
    46  	}
    47  
    48  	encryptedDataEl := etree.NewElement("xenc:EncryptedData")
    49  	encryptedDataEl.CreateAttr("xmlns:xenc", "http://www.w3.org/2001/04/xmlenc#")
    50  	{
    51  		randBuf := make([]byte, 16)
    52  		if _, err := RandReader.Read(randBuf); err != nil {
    53  			return nil, err
    54  		}
    55  		encryptedDataEl.CreateAttr("Id", fmt.Sprintf("_%x", randBuf))
    56  	}
    57  
    58  	em := encryptedDataEl.CreateElement("xenc:EncryptionMethod")
    59  	em.CreateAttr("Algorithm", e.algorithm)
    60  	em.CreateAttr("xmlns:xenc", "http://www.w3.org/2001/04/xmlenc#")
    61  
    62  	plaintext = appendPadding(plaintext, block.BlockSize())
    63  
    64  	iv := make([]byte, block.BlockSize())
    65  	if _, err := RandReader.Read(iv); err != nil {
    66  		return nil, err
    67  	}
    68  
    69  	mode := cipher.NewCBCEncrypter(block, iv)
    70  	ciphertext := make([]byte, len(plaintext))
    71  	mode.CryptBlocks(ciphertext, plaintext)
    72  	ciphertext = append(iv, ciphertext...)
    73  
    74  	cd := encryptedDataEl.CreateElement("xenc:CipherData")
    75  	cd.CreateAttr("xmlns:xenc", "http://www.w3.org/2001/04/xmlenc#")
    76  	cd.CreateElement("xenc:CipherValue").SetText(base64.StdEncoding.EncodeToString(ciphertext))
    77  	return encryptedDataEl, nil
    78  }
    79  
    80  // Decrypt decrypts an encrypted element with key. If the ciphertext contains an
    81  // EncryptedKey element, then the type of `key` is determined by the registered
    82  // Decryptor for the EncryptedKey element. Otherwise, `key` must be a []byte of
    83  // length KeySize().
    84  func (e CBC) Decrypt(key interface{}, ciphertextEl *etree.Element) ([]byte, error) {
    85  	// If the key is encrypted, decrypt it.
    86  	if encryptedKeyEl := ciphertextEl.FindElement("./KeyInfo/EncryptedKey"); encryptedKeyEl != nil {
    87  		var err error
    88  		key, err = Decrypt(key, encryptedKeyEl)
    89  		if err != nil {
    90  			return nil, err
    91  		}
    92  	}
    93  
    94  	keyBuf, ok := key.([]byte)
    95  	if !ok {
    96  		return nil, ErrIncorrectKeyType("[]byte")
    97  	}
    98  	if len(keyBuf) != e.KeySize() {
    99  		return nil, ErrIncorrectKeyLength(e.KeySize())
   100  	}
   101  
   102  	block, err := e.cipher(keyBuf)
   103  	if err != nil {
   104  		return nil, err
   105  	}
   106  
   107  	ciphertext, err := getCiphertext(ciphertextEl)
   108  	if err != nil {
   109  		return nil, err
   110  	}
   111  
   112  	if len(ciphertext) < block.BlockSize() {
   113  		return nil, errors.New("ciphertext too short")
   114  	}
   115  
   116  	iv := ciphertext[:aes.BlockSize]
   117  	ciphertext = ciphertext[aes.BlockSize:]
   118  
   119  	mode := cipher.NewCBCDecrypter(block, iv)
   120  	plaintext := make([]byte, len(ciphertext))
   121  	mode.CryptBlocks(plaintext, ciphertext) // decrypt in place
   122  
   123  	plaintext, err = stripPadding(plaintext)
   124  	if err != nil {
   125  		return nil, err
   126  	}
   127  
   128  	return plaintext, nil
   129  }
   130  
   131  var (
   132  	// AES128CBC implements AES128-CBC symetric key mode for encryption and decryption
   133  	AES128CBC BlockCipher = CBC{
   134  		keySize:   16,
   135  		algorithm: "http://www.w3.org/2001/04/xmlenc#aes128-cbc",
   136  		cipher:    aes.NewCipher,
   137  	}
   138  
   139  	// AES192CBC implements AES192-CBC symetric key mode for encryption and decryption
   140  	AES192CBC BlockCipher = CBC{
   141  		keySize:   24,
   142  		algorithm: "http://www.w3.org/2001/04/xmlenc#aes192-cbc",
   143  		cipher:    aes.NewCipher,
   144  	}
   145  
   146  	// AES256CBC implements AES256-CBC symetric key mode for encryption and decryption
   147  	AES256CBC BlockCipher = CBC{
   148  		keySize:   32,
   149  		algorithm: "http://www.w3.org/2001/04/xmlenc#aes256-cbc",
   150  		cipher:    aes.NewCipher,
   151  	}
   152  
   153  	// TripleDES implements 3DES in CBC mode for encryption and decryption
   154  	TripleDES BlockCipher = CBC{
   155  		keySize:   8,
   156  		algorithm: "http://www.w3.org/2001/04/xmlenc#tripledes-cbc",
   157  		cipher:    des.NewCipher,
   158  	}
   159  )
   160  
   161  func init() {
   162  	RegisterDecrypter(AES128CBC)
   163  	RegisterDecrypter(AES192CBC)
   164  	RegisterDecrypter(AES256CBC)
   165  	RegisterDecrypter(TripleDES)
   166  }
   167  
   168  func appendPadding(buf []byte, blockSize int) []byte {
   169  	paddingBytes := blockSize - (len(buf) % blockSize)
   170  	padding := make([]byte, paddingBytes)
   171  	padding[len(padding)-1] = byte(paddingBytes)
   172  	return append(buf, padding...)
   173  }
   174  
   175  func stripPadding(buf []byte) ([]byte, error) {
   176  	if len(buf) < 1 {
   177  		return nil, errors.New("buffer is too short for padding")
   178  	}
   179  	paddingBytes := int(buf[len(buf)-1])
   180  	if paddingBytes > len(buf)-1 {
   181  		return nil, errors.New("buffer is too short for padding")
   182  	}
   183  	if paddingBytes < 1 {
   184  		return nil, errors.New("padding must be at least one byte")
   185  	}
   186  	return buf[:len(buf)-paddingBytes], nil
   187  }