github.com/unionj-cloud/go-doudou/v2@v2.3.5/toolkit/memberlist/security.go (about)

     1  package memberlist
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/aes"
     6  	"crypto/cipher"
     7  	"crypto/rand"
     8  	"fmt"
     9  	"io"
    10  )
    11  
/*

Encrypted messages are prefixed with an encryptionVersion byte that tells
the receiver how the rest of the payload was encoded, so that it can be
decoded correctly. We currently support the following versions:

 0 - AES-GCM, plaintext PKCS7-padded to the AES block size
 1 - AES-GCM, no padding. GCM does not need padding; it only caused bloat.

In both versions the AES key size (AES-128, AES-192 or AES-256) is selected
by the length of the supplied key, as aes.NewCipher does.

*/
type encryptionVersion uint8

// Supported range of encryptionVersion values. decryptPayload rejects any
// payload whose version byte is above maxEncryptionVersion.
const (
	minEncryptionVersion encryptionVersion = 0
	maxEncryptionVersion encryptionVersion = 1
)
    28  
// Sizes, in bytes, of the pieces that make up an encrypted payload.
const (
	versionSize    = 1             // leading encryptionVersion byte
	nonceSize      = 12            // random AES-GCM nonce written after the version byte
	tagSize        = 16            // GCM authentication tag appended by Seal
	maxPadOverhead = 16            // most bytes PKCS7 padding can add (version 0 only)
	blockSize      = aes.BlockSize // AES block size that version 0 pads to
)
    36  
    37  // pkcs7encode is used to pad a byte buffer to a specific block size using
    38  // the PKCS7 algorithm. "Ignores" some bytes to compensate for IV
    39  func pkcs7encode(buf *bytes.Buffer, ignore, blockSize int) {
    40  	n := buf.Len() - ignore
    41  	more := blockSize - (n % blockSize)
    42  	for i := 0; i < more; i++ {
    43  		buf.WriteByte(byte(more))
    44  	}
    45  }
    46  
    47  // pkcs7decode is used to decode a buffer that has been padded
    48  func pkcs7decode(buf []byte, blockSize int) []byte {
    49  	if len(buf) == 0 {
    50  		panic("Cannot decode a PKCS7 buffer of zero length")
    51  	}
    52  	n := len(buf)
    53  	last := buf[n-1]
    54  	n -= int(last)
    55  	return buf[:n]
    56  }
    57  
    58  // encryptOverhead returns the maximum possible overhead of encryption by version
    59  func encryptOverhead(vsn encryptionVersion) int {
    60  	switch vsn {
    61  	case 0:
    62  		return 45 // Version: 1, IV: 12, Padding: 16, Tag: 16
    63  	case 1:
    64  		return 29 // Version: 1, IV: 12, Tag: 16
    65  	default:
    66  		panic("unsupported version")
    67  	}
    68  }
    69  
    70  // encryptedLength is used to compute the buffer size needed
    71  // for a message of given length
    72  func encryptedLength(vsn encryptionVersion, inp int) int {
    73  	// If we are on version 1, there is no padding
    74  	if vsn >= 1 {
    75  		return versionSize + nonceSize + inp + tagSize
    76  	}
    77  
    78  	// Determine the padding size
    79  	padding := blockSize - (inp % blockSize)
    80  
    81  	// Sum the extra parts to get total size
    82  	return versionSize + nonceSize + inp + padding + tagSize
    83  }
    84  
    85  // encryptPayload is used to encrypt a message with a given key.
    86  // We make use of AES-128 in GCM mode. New byte buffer is the version,
    87  // nonce, ciphertext and tag
    88  func encryptPayload(vsn encryptionVersion, key []byte, msg []byte, data []byte, dst *bytes.Buffer) error {
    89  	// Get the AES block cipher
    90  	aesBlock, err := aes.NewCipher(key)
    91  	if err != nil {
    92  		return err
    93  	}
    94  
    95  	// Get the GCM cipher mode
    96  	gcm, err := cipher.NewGCM(aesBlock)
    97  	if err != nil {
    98  		return err
    99  	}
   100  
   101  	// Grow the buffer to make room for everything
   102  	offset := dst.Len()
   103  	dst.Grow(encryptedLength(vsn, len(msg)))
   104  
   105  	// Write the encryption version
   106  	dst.WriteByte(byte(vsn))
   107  
   108  	// Add a random nonce
   109  	_, err = io.CopyN(dst, rand.Reader, nonceSize)
   110  	if err != nil {
   111  		return err
   112  	}
   113  	afterNonce := dst.Len()
   114  
   115  	// Ensure we are correctly padded (only version 0)
   116  	if vsn == 0 {
   117  		io.Copy(dst, bytes.NewReader(msg))
   118  		pkcs7encode(dst, offset+versionSize+nonceSize, aes.BlockSize)
   119  	}
   120  
   121  	// Encrypt message using GCM
   122  	slice := dst.Bytes()[offset:]
   123  	nonce := slice[versionSize : versionSize+nonceSize]
   124  
   125  	// Message source depends on the encryption version.
   126  	// Version 0 uses padding, version 1 does not
   127  	var src []byte
   128  	if vsn == 0 {
   129  		src = slice[versionSize+nonceSize:]
   130  	} else {
   131  		src = msg
   132  	}
   133  	out := gcm.Seal(nil, nonce, src, data)
   134  
   135  	// Truncate the plaintext, and write the cipher text
   136  	dst.Truncate(afterNonce)
   137  	dst.Write(out)
   138  	return nil
   139  }
   140  
   141  // decryptMessage performs the actual decryption of ciphertext. This is in its
   142  // own function to allow it to be called on all keys easily.
   143  func decryptMessage(key, msg []byte, data []byte) ([]byte, error) {
   144  	// Get the AES block cipher
   145  	aesBlock, err := aes.NewCipher(key)
   146  	if err != nil {
   147  		return nil, err
   148  	}
   149  
   150  	// Get the GCM cipher mode
   151  	gcm, err := cipher.NewGCM(aesBlock)
   152  	if err != nil {
   153  		return nil, err
   154  	}
   155  
   156  	// Decrypt the message
   157  	nonce := msg[versionSize : versionSize+nonceSize]
   158  	ciphertext := msg[versionSize+nonceSize:]
   159  	plain, err := gcm.Open(nil, nonce, ciphertext, data)
   160  	if err != nil {
   161  		return nil, err
   162  	}
   163  
   164  	// Success!
   165  	return plain, nil
   166  }
   167  
   168  // decryptPayload is used to decrypt a message with a given key,
   169  // and verify it's contents. Any padding will be removed, and a
   170  // slice to the plaintext is returned. Decryption is done IN PLACE!
   171  func decryptPayload(keys [][]byte, msg []byte, data []byte) ([]byte, error) {
   172  	// Ensure we have at least one byte
   173  	if len(msg) == 0 {
   174  		return nil, fmt.Errorf("Cannot decrypt empty payload")
   175  	}
   176  
   177  	// Verify the version
   178  	vsn := encryptionVersion(msg[0])
   179  	if vsn > maxEncryptionVersion {
   180  		return nil, fmt.Errorf("Unsupported encryption version %d", msg[0])
   181  	}
   182  
   183  	// Ensure the length is sane
   184  	if len(msg) < encryptedLength(vsn, 0) {
   185  		return nil, fmt.Errorf("Payload is too small to decrypt: %d", len(msg))
   186  	}
   187  
   188  	for _, key := range keys {
   189  		plain, err := decryptMessage(key, msg, data)
   190  		if err == nil {
   191  			// Remove the PKCS7 padding for vsn 0
   192  			if vsn == 0 {
   193  				return pkcs7decode(plain, aes.BlockSize), nil
   194  			} else {
   195  				return plain, nil
   196  			}
   197  		}
   198  	}
   199  
   200  	return nil, fmt.Errorf("No installed keys could decrypt the message")
   201  }