github.com/gitbundle/modules@v0.0.0-20231025071548-85b91c5c3b01/encrypt/ssh.go (about)

     1  // Copyright 2023 The GitBundle Inc. All rights reserved.
     2  // Copyright 2017 The Gitea Authors. All rights reserved.
     3  // Use of this source code is governed by a MIT-style
     4  // license that can be found in the LICENSE file.
     5  
     6  package encrypt
     7  
     8  import (
     9  	"bytes"
    10  	"crypto/aes"
    11  	"crypto/cipher"
    12  	"crypto/rand"
    13  	"crypto/rsa"
    14  	"crypto/sha256"
    15  	"crypto/x509"
    16  	"encoding/base64"
    17  	"encoding/binary"
    18  	"encoding/pem"
    19  	"errors"
    20  	"fmt"
    21  
    22  	"golang.org/x/crypto/ssh"
    23  )
    24  
    25  func GenerateSshKeyPairs() (string, string, error) {
    26  	reader := rand.Reader
    27  	bitSize := 2048
    28  
    29  	key, err := rsa.GenerateKey(reader, bitSize)
    30  	if err != nil {
    31  		return "", "", fmt.Errorf("encrypt: failed to generate random key (%v)", err)
    32  	}
    33  
    34  	pub, err := ssh.NewPublicKey(key.Public())
    35  	if err != nil {
    36  		return "", "", fmt.Errorf("encrypt: failed to create public key (%v)", err)
    37  	}
    38  	pubKeyStr := string(ssh.MarshalAuthorizedKey(pub))
    39  	privKeyStr := marshalRSAPrivate(key)
    40  
    41  	return pubKeyStr, privKeyStr, nil
    42  }
    43  
    44  func EncryptWithSshKey(plainText, publicKey []byte) (string, error) {
    45  	parsed, _, _, _, err := ssh.ParseAuthorizedKey(publicKey)
    46  	if err != nil {
    47  		return "", fmt.Errorf("encrypt: failed to parse authorized key (%v)", err)
    48  	}
    49  	// To get back to an *rsa.PublicKey, we need to first upgrade to the
    50  	// ssh.CryptoPublicKey interface
    51  	parsedCryptoKey := parsed.(ssh.CryptoPublicKey)
    52  
    53  	// Then, we can call CryptoPublicKey() to get the actual crypto.PublicKey
    54  	pubCrypto := parsedCryptoKey.CryptoPublicKey()
    55  
    56  	// Finally, we can convert back to an *rsa.PublicKey
    57  	pub := pubCrypto.(*rsa.PublicKey)
    58  
    59  	if len(plainText) <= 256 {
    60  		// plainText is small enough to only use OAEP encryption; this will result in less bytes to transfer.
    61  		encryptedBytes, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, pub, plainText, nil)
    62  		if err != nil {
    63  			return "", fmt.Errorf("encrypt: failed to encrypt with OAEP method (%v)", err)
    64  		}
    65  		if len(encryptedBytes) != 256 {
    66  			return "", errors.New("encrypt: invalid encrypted data length with OAEP")
    67  		}
    68  		return base64.StdEncoding.EncodeToString(encryptedBytes), nil
    69  	}
    70  
    71  	// otherwise, encrypt using AES256
    72  	key, cipherText, err := encryptAES256(plainText)
    73  	if err != nil {
    74  		return "", err
    75  	}
    76  
    77  	encryptedBytes, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, pub, key, nil)
    78  	if err != nil {
    79  		return "", err
    80  	}
    81  	if len(encryptedBytes) != 256 {
    82  		return "", errors.New("encrypt: invalid encrypted data length")
    83  	}
    84  	return base64.StdEncoding.EncodeToString(append(encryptedBytes, cipherText...)), nil
    85  }
    86  
    87  func DecryptWithSshKey(cipherText, privateKey string) ([]byte, error) {
    88  	data, err := base64.StdEncoding.DecodeString(cipherText)
    89  	if err != nil {
    90  		return nil, err
    91  	}
    92  
    93  	if len(data) < 256 {
    94  		return nil, errors.New("encrypt: not enough data to decrypt")
    95  	}
    96  
    97  	block, _ := pem.Decode([]byte(privateKey))
    98  	key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
    99  	if err != nil {
   100  		return nil, err
   101  	}
   102  
   103  	aesData := data[256:]
   104  	payload, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, key, data[:256], nil)
   105  	if err != nil {
   106  		return nil, fmt.Errorf("encrypt: failed to decrypt with OAEP (%v)", err)
   107  	}
   108  
   109  	if len(aesData) == 0 {
   110  		return payload, nil
   111  	}
   112  
   113  	decryptedAESKey := payload
   114  	decrypted, err := decryptAES(decryptedAESKey, aesData)
   115  	if err != nil {
   116  		return nil, err
   117  	}
   118  
   119  	return decrypted, nil
   120  }
   121  
   122  func marshalRSAPrivate(priv *rsa.PrivateKey) string {
   123  	return string(pem.EncodeToMemory(&pem.Block{
   124  		Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv),
   125  	}))
   126  }
   127  
   128  // encryptAES256 returns a random passphrase and corresponding bytes encrypted with it
   129  func encryptAES256(data []byte) ([]byte, []byte, error) {
   130  	key := make([]byte, 32)
   131  	if _, err := rand.Read(key); err != nil {
   132  		return nil, nil, fmt.Errorf("encrypt: failed to generate random key (%v)", err)
   133  	}
   134  
   135  	n := len(data)
   136  	buf := new(bytes.Buffer)
   137  	if err := binary.Write(buf, binary.LittleEndian, uint64(n)); err != nil {
   138  		return nil, nil, fmt.Errorf("encrypt: failed to write binary data (%v)", err)
   139  	}
   140  	if _, err := buf.Write(data); err != nil {
   141  		return nil, nil, fmt.Errorf("encrypt: failed to write buffer data (%v)", err)
   142  	}
   143  
   144  	paddingN := aes.BlockSize - (buf.Len() % aes.BlockSize)
   145  	if paddingN > 0 {
   146  		padding := make([]byte, paddingN)
   147  		if _, err := rand.Read(padding); err != nil {
   148  			return nil, nil, fmt.Errorf("encrypt: failed to generate random key with padding (%v)", err)
   149  		}
   150  		if _, err := buf.Write(padding); err != nil {
   151  			return nil, nil, fmt.Errorf("encrypt: failed to write padding buffer (%v)", err)
   152  		}
   153  	}
   154  	plaintext := buf.Bytes()
   155  
   156  	sum := sha256.Sum256(plaintext)
   157  	plaintext = append(sum[:], plaintext...)
   158  
   159  	block, err := aes.NewCipher(key)
   160  	if err != nil {
   161  		return nil, nil, fmt.Errorf("encrypt: failed to create cipher (%v)", err)
   162  	}
   163  
   164  	cipherText := make([]byte, aes.BlockSize+len(plaintext))
   165  	iv := cipherText[:aes.BlockSize]
   166  	if _, err := rand.Read(iv); err != nil {
   167  		return nil, nil, fmt.Errorf("encrypt: failed to generate random key with iv (%v)", err)
   168  	}
   169  
   170  	mode := cipher.NewCBCEncrypter(block, iv)
   171  	mode.CryptBlocks(cipherText[aes.BlockSize:], plaintext)
   172  	return key, cipherText, nil
   173  }
   174  
   175  func decryptAES(key, cipherText []byte) ([]byte, error) {
   176  	block, err := aes.NewCipher(key)
   177  	if err != nil {
   178  		return nil, fmt.Errorf("encrypt: failed to create cipher with key (%v)", err)
   179  	}
   180  
   181  	if len(cipherText) < aes.BlockSize {
   182  		return nil, errors.New("encrypt: cipherText too short to decrypt")
   183  	}
   184  	iv := cipherText[:aes.BlockSize]
   185  	cipherText = cipherText[aes.BlockSize:]
   186  
   187  	if len(cipherText)%aes.BlockSize != 0 {
   188  		return nil, errors.New("encrypt: cipherText is not a multiple of the block size")
   189  	}
   190  
   191  	mode := cipher.NewCBCDecrypter(block, iv)
   192  
   193  	// TODO: works inplace when both args are the same
   194  	mode.CryptBlocks(cipherText, cipherText)
   195  
   196  	expectedSum := cipherText[:32]
   197  	actualSum := sha256.Sum256(cipherText[32:])
   198  	if !bytes.Equal(expectedSum, actualSum[:]) {
   199  		return nil, fmt.Errorf("encrypt: sha256 mismatch %v vs %v", expectedSum, actualSum)
   200  	}
   201  
   202  	buf := bytes.NewReader(cipherText[32:])
   203  	var n uint64
   204  	if err = binary.Read(buf, binary.LittleEndian, &n); err != nil {
   205  		return nil, fmt.Errorf("encrypt: failed to read binary data (%v)", err)
   206  	}
   207  	payload := make([]byte, n)
   208  	if _, err = buf.Read(payload); err != nil {
   209  		return nil, fmt.Errorf("encrypt: failed to read payload data (%v)", err)
   210  	}
   211  
   212  	return payload, nil
   213  }