github.com/crewjam/saml@v0.4.14/xmlenc/gcm.go (about)

     1  package xmlenc
     2  
     3  import (
     4  	"crypto/aes"
     5  	"crypto/cipher"
     6  	"crypto/rand"
     7  	"encoding/base64"
     8  	"fmt"
     9  	"io"
    10  
    11  	"github.com/beevik/etree"
    12  )
    13  
// GCM implements Decrypter and Encrypter for block ciphers in GCM mode.
//
// keySize is the key length in bytes that Encrypt and Decrypt require,
// algorithm is the identifier written to the Algorithm attribute of
// xenc:EncryptionMethod, and cipher builds the cipher.Block (from the raw key
// bytes) that is then wrapped with cipher.NewGCM.
type GCM struct {
	keySize   int
	algorithm string
	cipher    func([]byte) (cipher.Block, error)
}
    20  
// KeySize returns the length of the key required, in bytes. Encrypt and
// Decrypt reject []byte keys of any other length.
func (e GCM) KeySize() int {
	return e.keySize
}
    25  
// Algorithm returns the name of the algorithm, as will be found
// in an xenc:EncryptionMethod element. Encrypt writes this value to that
// element's Algorithm attribute.
func (e GCM) Algorithm() string {
	return e.algorithm
}
    31  
    32  // Encrypt encrypts plaintext with key and nonce
    33  func (e GCM) Encrypt(key interface{}, plaintext []byte, nonce []byte) (*etree.Element, error) {
    34  	keyBuf, ok := key.([]byte)
    35  	if !ok {
    36  		return nil, ErrIncorrectKeyType("[]byte")
    37  	}
    38  	if len(keyBuf) != e.keySize {
    39  		return nil, ErrIncorrectKeyLength(e.keySize)
    40  	}
    41  
    42  	block, err := e.cipher(keyBuf)
    43  	if err != nil {
    44  		return nil, err
    45  	}
    46  
    47  	encryptedDataEl := etree.NewElement("xenc:EncryptedData")
    48  	encryptedDataEl.CreateAttr("xmlns:xenc", "http://www.w3.org/2001/04/xmlenc#")
    49  	{
    50  		randBuf := make([]byte, 16)
    51  		if _, err := RandReader.Read(randBuf); err != nil {
    52  			return nil, err
    53  		}
    54  		encryptedDataEl.CreateAttr("Id", fmt.Sprintf("_%x", randBuf))
    55  	}
    56  
    57  	em := encryptedDataEl.CreateElement("xenc:EncryptionMethod")
    58  	em.CreateAttr("Algorithm", e.algorithm)
    59  	em.CreateAttr("xmlns:xenc", "http://www.w3.org/2001/04/xmlenc#")
    60  
    61  	plaintext = appendPadding(plaintext, block.BlockSize())
    62  
    63  	aesgcm, err := cipher.NewGCM(block)
    64  	if err != nil {
    65  		return nil, err
    66  	}
    67  
    68  	if nonce == nil {
    69  		// generate random nonce when it's nil
    70  		nonce := make([]byte, aesgcm.NonceSize())
    71  		if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
    72  			panic(err.Error())
    73  		}
    74  	}
    75  
    76  	ciphertext := make([]byte, len(plaintext))
    77  	text := aesgcm.Seal(nil, nonce, ciphertext, nil)
    78  
    79  	cd := encryptedDataEl.CreateElement("xenc:CipherData")
    80  	cd.CreateAttr("xmlns:xenc", "http://www.w3.org/2001/04/xmlenc#")
    81  	cd.CreateElement("xenc:CipherValue").SetText(base64.StdEncoding.EncodeToString(text))
    82  	return encryptedDataEl, nil
    83  }
    84  
    85  // Decrypt decrypts an encrypted element with key. If the ciphertext contains an
    86  // EncryptedKey element, then the type of `key` is determined by the registered
    87  // Decryptor for the EncryptedKey element. Otherwise, `key` must be a []byte of
    88  // length KeySize().
    89  func (e GCM) Decrypt(key interface{}, ciphertextEl *etree.Element) ([]byte, error) {
    90  	if encryptedKeyEl := ciphertextEl.FindElement("./KeyInfo/EncryptedKey"); encryptedKeyEl != nil {
    91  		var err error
    92  		key, err = Decrypt(key, encryptedKeyEl)
    93  		if err != nil {
    94  			return nil, err
    95  		}
    96  	}
    97  
    98  	keyBuf, ok := key.([]byte)
    99  
   100  	if !ok {
   101  		return nil, ErrIncorrectKeyType("[]byte")
   102  	}
   103  	if len(keyBuf) != e.KeySize() {
   104  		return nil, ErrIncorrectKeyLength(e.KeySize())
   105  	}
   106  
   107  	block, err := e.cipher(keyBuf)
   108  	if err != nil {
   109  		return nil, err
   110  	}
   111  
   112  	aesgcm, err := cipher.NewGCM(block)
   113  	if err != nil {
   114  		return nil, err
   115  	}
   116  
   117  	ciphertext, err := getCiphertext(ciphertextEl)
   118  	if err != nil {
   119  		return nil, err
   120  	}
   121  
   122  	nonce := ciphertext[:aesgcm.NonceSize()]
   123  	text := ciphertext[aesgcm.NonceSize():]
   124  
   125  	plainText, err := aesgcm.Open(nil, nonce, text, nil)
   126  	if err != nil {
   127  		return nil, err
   128  	}
   129  	return plainText, nil
   130  }
   131  
var (
	// AES128GCM implements AES128-GCM mode for encryption and decryption.
	// It uses a 16-byte (128-bit) key, builds its block cipher with
	// aes.NewCipher, and is identified by the xmlenc11 aes128-gcm
	// algorithm URI.
	AES128GCM BlockCipher = GCM{
		keySize:   16,
		algorithm: "http://www.w3.org/2009/xmlenc11#aes128-gcm",
		cipher:    aes.NewCipher,
	}
)
   140  
// init passes AES128GCM to RegisterDecrypter at package load time.
// NOTE(review): presumably this lets the package-level Decrypt select it by
// algorithm URI; RegisterDecrypter is defined elsewhere — confirm.
func init() {
	RegisterDecrypter(AES128GCM)
}