github.com/criteo/command-launcher@v0.0.0-20230407142452-fb616f546e98/internal/gvault/file-vault.go (about) 1 package vault 2 3 import ( 4 "crypto/aes" 5 "crypto/cipher" 6 "crypto/rand" 7 "crypto/sha256" 8 "encoding/json" 9 "fmt" 10 "io" 11 "io/ioutil" 12 "os" 13 "path/filepath" 14 ) 15 16 type Dico map[string]string 17 18 type FileVault struct { 19 Name string 20 hash []byte 21 } 22 23 func (fv *FileVault) Write(key string, value string) error { 24 vaultDir, err := maybeCreateDir() 25 if err != nil { 26 return err 27 } 28 29 dico, err := fv.readFile() 30 if err != nil { 31 return err 32 } 33 34 dico[key] = value 35 data, err := json.Marshal(dico) 36 if err != nil { 37 return err 38 } 39 40 encrypted, err := fv.encrypt(data) 41 if err != nil { 42 return err 43 } 44 45 return ioutil.WriteFile(filepath.Join(vaultDir, fv.Name), encrypted, 0600) 46 } 47 48 func (fv *FileVault) Read(key string) (string, error) { 49 dico, err := fv.readFile() 50 if err != nil { 51 return "", err 52 } 53 54 if len(dico) == 0 { 55 return "", fmt.Errorf("vault %s is empty", fv.Name) 56 } 57 58 return dico[key], nil 59 } 60 61 func (fv *FileVault) readFile() (Dico, error) { 62 dico := make(Dico) 63 vaultDir, err := maybeCreateDir() 64 if err != nil { 65 return dico, err 66 } 67 68 encrypted, err := ioutil.ReadFile(filepath.Join(vaultDir, fv.Name)) 69 if err != nil { 70 return dico, err 71 } 72 73 if len(encrypted) == 0 { 74 return dico, err 75 } 76 77 data, err := fv.decrypt(encrypted) 78 if err != nil { 79 return dico, err 80 } 81 82 err = json.Unmarshal(data, &dico) 83 if err != nil { 84 return dico, err 85 } 86 87 return dico, nil 88 } 89 90 func (fv *FileVault) init() error { 91 hash, err := readSecret() 92 if err != nil { 93 return err 94 } 95 fv.hash = hash 96 97 dirVault, err := maybeCreateDir() 98 if err != nil { 99 return err 100 } 101 102 _, err = os.Stat(filepath.Join(dirVault, fv.Name)) 103 if os.IsNotExist(err) { 104 _, err = os.OpenFile(filepath.Join(dirVault, fv.Name), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) 105 if err != nil { 106 return err 107 } 108 } 109 110 return nil 111 } 112 113 func (fv *FileVault) decrypt(encrypted []byte) ([]byte, error) { 114 cphr, err := aes.NewCipher(fv.hash) 115 if err != nil { 116 return nil, err 117 } 118 119 gcm, err := cipher.NewGCM(cphr) 120 if err != nil { 121 return nil, err 122 } 123 124 nonceSize := gcm.NonceSize() 125 nonce, msg := encrypted[:nonceSize], encrypted[nonceSize:] 126 data, err := gcm.Open(nil, nonce, msg, nil) 127 if err != nil { 128 return nil, err 129 } 130 131 return data, nil 132 } 133 134 func (fv *FileVault) encrypt(data []byte) ([]byte, error) { 135 cphr, err := aes.NewCipher(fv.hash) 136 if err != nil { 137 return nil, err 138 } 139 140 gcm, err := cipher.NewGCM(cphr) 141 if err != nil { 142 return nil, err 143 } 144 145 nonce := make([]byte, gcm.NonceSize()) 146 if _, err = io.ReadFull(rand.Reader, nonce); err != nil { 147 return nil, err 148 } 149 150 encrypted := gcm.Seal(nonce, nonce, data, nil) 151 152 return encrypted, nil 153 } 154 155 func readSecret() ([]byte, error) { 156 // first get the secret from environment variable 157 secret := os.Getenv("CDT_VAULT_SECRET") 158 if secret != "" { 159 hash := sha256.Sum256([]byte(secret)) 160 return hash[:], nil 161 } 162 163 empty := []byte{} 164 homedir, err := os.UserHomeDir() 165 if err != nil { 166 return empty, err 167 } 168 169 sshDir := filepath.Join(homedir, ".ssh") 170 _, err = os.Stat(sshDir) 171 if err != nil { 172 return empty, err 173 } 174 175 // get the secret file from environment variable 176 secretFile := os.Getenv("CDT_VAULT_SECRET_FILE") 177 if secretFile == "" { 178 secretFile = filepath.Join(sshDir, "id_rsa") 179 } 180 181 // in case environment variable missing, fallback to default 182 data, err := ioutil.ReadFile(secretFile) 183 if err != nil { 184 return empty, err 185 } 186 187 hash := sha256.Sum256(data) 188 return hash[:], nil 189 } 190 191 func maybeCreateDir() (string, error) { 192 homedir, err := os.UserHomeDir() 193 if err != nil { 194 return "", err 195 } 196 197 dirVault := filepath.Join(homedir, ".file-vault") 198 _, err = os.Stat(dirVault) 199 if os.IsNotExist(err) { 200 err := os.MkdirAll(dirVault, 0700) 201 if err != nil { 202 return "", err 203 } 204 } 205 206 return dirVault, nil 207 }