github.com/criteo/command-launcher@v0.0.0-20230407142452-fb616f546e98/internal/gvault/file-vault.go (about)

     1  package vault
     2  
     3  import (
     4  	"crypto/aes"
     5  	"crypto/cipher"
     6  	"crypto/rand"
     7  	"crypto/sha256"
     8  	"encoding/json"
     9  	"fmt"
    10  	"io"
    11  	"io/ioutil"
    12  	"os"
    13  	"path/filepath"
    14  )
    15  
    16  type Dico map[string]string
    17  
    18  type FileVault struct {
    19  	Name string
    20  	hash []byte
    21  }
    22  
    23  func (fv *FileVault) Write(key string, value string) error {
    24  	vaultDir, err := maybeCreateDir()
    25  	if err != nil {
    26  		return err
    27  	}
    28  
    29  	dico, err := fv.readFile()
    30  	if err != nil {
    31  		return err
    32  	}
    33  
    34  	dico[key] = value
    35  	data, err := json.Marshal(dico)
    36  	if err != nil {
    37  		return err
    38  	}
    39  
    40  	encrypted, err := fv.encrypt(data)
    41  	if err != nil {
    42  		return err
    43  	}
    44  
    45  	return ioutil.WriteFile(filepath.Join(vaultDir, fv.Name), encrypted, 0600)
    46  }
    47  
    48  func (fv *FileVault) Read(key string) (string, error) {
    49  	dico, err := fv.readFile()
    50  	if err != nil {
    51  		return "", err
    52  	}
    53  
    54  	if len(dico) == 0 {
    55  		return "", fmt.Errorf("vault %s is empty", fv.Name)
    56  	}
    57  
    58  	return dico[key], nil
    59  }
    60  
    61  func (fv *FileVault) readFile() (Dico, error) {
    62  	dico := make(Dico)
    63  	vaultDir, err := maybeCreateDir()
    64  	if err != nil {
    65  		return dico, err
    66  	}
    67  
    68  	encrypted, err := ioutil.ReadFile(filepath.Join(vaultDir, fv.Name))
    69  	if err != nil {
    70  		return dico, err
    71  	}
    72  
    73  	if len(encrypted) == 0 {
    74  		return dico, err
    75  	}
    76  
    77  	data, err := fv.decrypt(encrypted)
    78  	if err != nil {
    79  		return dico, err
    80  	}
    81  
    82  	err = json.Unmarshal(data, &dico)
    83  	if err != nil {
    84  		return dico, err
    85  	}
    86  
    87  	return dico, nil
    88  }
    89  
    90  func (fv *FileVault) init() error {
    91  	hash, err := readSecret()
    92  	if err != nil {
    93  		return err
    94  	}
    95  	fv.hash = hash
    96  
    97  	dirVault, err := maybeCreateDir()
    98  	if err != nil {
    99  		return err
   100  	}
   101  
   102  	_, err = os.Stat(filepath.Join(dirVault, fv.Name))
   103  	if os.IsNotExist(err) {
   104  		_, err = os.OpenFile(filepath.Join(dirVault, fv.Name), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
   105  		if err != nil {
   106  			return err
   107  		}
   108  	}
   109  
   110  	return nil
   111  }
   112  
   113  func (fv *FileVault) decrypt(encrypted []byte) ([]byte, error) {
   114  	cphr, err := aes.NewCipher(fv.hash)
   115  	if err != nil {
   116  		return nil, err
   117  	}
   118  
   119  	gcm, err := cipher.NewGCM(cphr)
   120  	if err != nil {
   121  		return nil, err
   122  	}
   123  
   124  	nonceSize := gcm.NonceSize()
   125  	nonce, msg := encrypted[:nonceSize], encrypted[nonceSize:]
   126  	data, err := gcm.Open(nil, nonce, msg, nil)
   127  	if err != nil {
   128  		return nil, err
   129  	}
   130  
   131  	return data, nil
   132  }
   133  
   134  func (fv *FileVault) encrypt(data []byte) ([]byte, error) {
   135  	cphr, err := aes.NewCipher(fv.hash)
   136  	if err != nil {
   137  		return nil, err
   138  	}
   139  
   140  	gcm, err := cipher.NewGCM(cphr)
   141  	if err != nil {
   142  		return nil, err
   143  	}
   144  
   145  	nonce := make([]byte, gcm.NonceSize())
   146  	if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
   147  		return nil, err
   148  	}
   149  
   150  	encrypted := gcm.Seal(nonce, nonce, data, nil)
   151  
   152  	return encrypted, nil
   153  }
   154  
   155  func readSecret() ([]byte, error) {
   156  	// first get the secret from environment variable
   157  	secret := os.Getenv("CDT_VAULT_SECRET")
   158  	if secret != "" {
   159  		hash := sha256.Sum256([]byte(secret))
   160  		return hash[:], nil
   161  	}
   162  
   163  	empty := []byte{}
   164  	homedir, err := os.UserHomeDir()
   165  	if err != nil {
   166  		return empty, err
   167  	}
   168  
   169  	sshDir := filepath.Join(homedir, ".ssh")
   170  	_, err = os.Stat(sshDir)
   171  	if err != nil {
   172  		return empty, err
   173  	}
   174  
   175  	// get the secret file from environment variable
   176  	secretFile := os.Getenv("CDT_VAULT_SECRET_FILE")
   177  	if secretFile == "" {
   178  		secretFile = filepath.Join(sshDir, "id_rsa")
   179  	}
   180  
   181  	// in case environment variable missing, fallback to default
   182  	data, err := ioutil.ReadFile(secretFile)
   183  	if err != nil {
   184  		return empty, err
   185  	}
   186  
   187  	hash := sha256.Sum256(data)
   188  	return hash[:], nil
   189  }
   190  
   191  func maybeCreateDir() (string, error) {
   192  	homedir, err := os.UserHomeDir()
   193  	if err != nil {
   194  		return "", err
   195  	}
   196  
   197  	dirVault := filepath.Join(homedir, ".file-vault")
   198  	_, err = os.Stat(dirVault)
   199  	if os.IsNotExist(err) {
   200  		err := os.MkdirAll(dirVault, 0700)
   201  		if err != nil {
   202  			return "", err
   203  		}
   204  	}
   205  
   206  	return dirVault, nil
   207  }