vitess.io/vitess@v0.16.2/go/mysql/vault/auth_server_vault.go (about) 1 /* 2 Copyright 2020 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package vault 18 19 import ( 20 "crypto/subtle" 21 "fmt" 22 "net" 23 "os" 24 "os/signal" 25 "strings" 26 "sync" 27 "syscall" 28 "time" 29 30 vaultapi "github.com/aquarapid/vaultlib" 31 "github.com/spf13/pflag" 32 33 "vitess.io/vitess/go/mysql" 34 "vitess.io/vitess/go/vt/log" 35 "vitess.io/vitess/go/vt/servenv" 36 ) 37 38 var ( 39 vaultAddr string 40 vaultTimeout time.Duration 41 vaultCACert string 42 vaultPath string 43 vaultCacheTTL time.Duration 44 vaultTokenFile string 45 vaultRoleID string 46 vaultRoleSecretIDFile string 47 vaultRoleMountPoint string 48 ) 49 50 func init() { 51 servenv.OnParseFor("vtgate", func(fs *pflag.FlagSet) { 52 fs.StringVar(&vaultAddr, "mysql_auth_vault_addr", "", "URL to Vault server") 53 fs.DurationVar(&vaultTimeout, "mysql_auth_vault_timeout", 10*time.Second, "Timeout for vault API operations") 54 fs.StringVar(&vaultCACert, "mysql_auth_vault_tls_ca", "", "Path to CA PEM for validating Vault server certificate") 55 fs.StringVar(&vaultPath, "mysql_auth_vault_path", "", "Vault path to vtgate credentials JSON blob, e.g.: secret/data/prod/vtgatecreds") 56 fs.DurationVar(&vaultCacheTTL, "mysql_auth_vault_ttl", 30*time.Minute, "How long to cache vtgate credentials from the Vault server") 57 fs.StringVar(&vaultTokenFile, "mysql_auth_vault_tokenfile", "", "Path to file containing Vault auth token; token can also be passed using VAULT_TOKEN environment variable") 58 fs.StringVar(&vaultRoleID, "mysql_auth_vault_roleid", "", "Vault AppRole id; can also be passed using VAULT_ROLEID environment variable") 59 fs.StringVar(&vaultRoleSecretIDFile, "mysql_auth_vault_role_secretidfile", "", "Path to file containing Vault AppRole secret_id; can also be passed using VAULT_SECRETID environment variable") 60 fs.StringVar(&vaultRoleMountPoint, "mysql_auth_vault_role_mountpoint", "approle", "Vault AppRole mountpoint; can also be passed using VAULT_MOUNTPOINT environment variable") 61 }) 62 } 63 64 // AuthServerVault implements AuthServer with a config loaded from Vault. 65 type AuthServerVault struct { 66 methods []mysql.AuthMethod 67 mu sync.Mutex 68 // users, passwords and user data 69 // We use the same JSON format as for --mysql_auth_server_static 70 // Acts as a cache for the in-Vault data 71 entries map[string][]*mysql.AuthServerStaticEntry 72 vaultCacheExpireTicker *time.Ticker 73 vaultClient *vaultapi.Client 74 vaultPath string 75 vaultTTL time.Duration 76 77 sigChan chan os.Signal 78 } 79 80 // InitAuthServerVault - entrypoint for initialization of Vault AuthServer implementation 81 func InitAuthServerVault() { 82 // Check critical parameters. 83 if vaultAddr == "" { 84 log.Infof("Not configuring AuthServerVault, as --mysql_auth_vault_addr is empty.") 85 return 86 } 87 if vaultPath == "" { 88 log.Exitf("If using Vault auth server, --mysql_auth_vault_path is required.") 89 } 90 91 registerAuthServerVault(vaultAddr, vaultTimeout, vaultCACert, vaultPath, vaultCacheTTL, vaultTokenFile, vaultRoleID, vaultRoleSecretIDFile, vaultRoleMountPoint) 92 } 93 94 func registerAuthServerVault(addr string, timeout time.Duration, caCertPath string, path string, ttl time.Duration, tokenFilePath string, roleID string, secretIDPath string, roleMountPoint string) { 95 authServerVault, err := newAuthServerVault(addr, timeout, caCertPath, path, ttl, tokenFilePath, roleID, secretIDPath, roleMountPoint) 96 if err != nil { 97 log.Exitf("%s", err) 98 } 99 mysql.RegisterAuthServer("vault", authServerVault) 100 } 101 102 func newAuthServerVault(addr string, timeout time.Duration, caCertPath string, path string, ttl time.Duration, tokenFilePath string, roleID string, secretIDPath string, roleMountPoint string) (*AuthServerVault, error) { 103 // Validate more parameters 104 token, err := readFromFile(tokenFilePath) 105 if err != nil { 106 return nil, fmt.Errorf("No Vault token in provided filename for --mysql_auth_vault_tokenfile") 107 } 108 secretID, err := readFromFile(secretIDPath) 109 if err != nil { 110 return nil, fmt.Errorf("No Vault secret_id in provided filename for --mysql_auth_vault_role_secretidfile") 111 } 112 113 config := vaultapi.NewConfig() 114 115 // All these can be overriden by environment 116 // so we need to check if they have been set by NewConfig 117 if config.Address == "" { 118 config.Address = addr 119 } 120 if config.Timeout == (0 * time.Second) { 121 config.Timeout = timeout 122 } 123 if config.CACert == "" { 124 config.CACert = caCertPath 125 } 126 if config.Token == "" { 127 config.Token = token 128 } 129 if config.AppRoleCredentials.RoleID == "" { 130 config.AppRoleCredentials.RoleID = roleID 131 } 132 if config.AppRoleCredentials.SecretID == "" { 133 config.AppRoleCredentials.SecretID = secretID 134 } 135 if config.AppRoleCredentials.MountPoint == "" { 136 config.AppRoleCredentials.MountPoint = roleMountPoint 137 } 138 139 if config.CACert != "" { 140 // If we provide a CA, ensure we actually use it 141 config.InsecureSSL = false 142 } 143 144 client, err := vaultapi.NewClient(config) 145 if err != nil || client == nil { 146 log.Errorf("Error in vault client initialization, will retry: %v", err) 147 } 148 149 a := &AuthServerVault{ 150 vaultClient: client, 151 vaultPath: path, 152 vaultTTL: ttl, 153 entries: make(map[string][]*mysql.AuthServerStaticEntry), 154 } 155 156 authMethodNative := mysql.NewMysqlNativeAuthMethod(a, a) 157 a.methods = []mysql.AuthMethod{authMethodNative} 158 159 a.reloadVault() 160 a.installSignalHandlers() 161 return a, nil 162 } 163 164 // AuthMethods returns the list of registered auth methods 165 // implemented by this auth server. 166 func (a *AuthServerVault) AuthMethods() []mysql.AuthMethod { 167 return a.methods 168 } 169 170 // DefaultAuthMethodDescription returns MysqlNativePassword as the default 171 // authentication method for the auth server implementation. 172 func (a *AuthServerVault) DefaultAuthMethodDescription() mysql.AuthMethodDescription { 173 return mysql.MysqlNativePassword 174 } 175 176 // HandleUser is part of the Validator interface. We 177 // handle any user here since we don't check up front. 178 func (a *AuthServerVault) HandleUser(user string) bool { 179 return true 180 } 181 182 // UserEntryWithHash is called when mysql_native_password is used. 183 func (a *AuthServerVault) UserEntryWithHash(conn *mysql.Conn, salt []byte, user string, authResponse []byte, remoteAddr net.Addr) (mysql.Getter, error) { 184 a.mu.Lock() 185 userEntries, ok := a.entries[user] 186 a.mu.Unlock() 187 188 if !ok { 189 return &mysql.StaticUserData{}, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", user) 190 } 191 192 for _, entry := range userEntries { 193 if entry.MysqlNativePassword != "" { 194 hash, err := mysql.DecodeMysqlNativePasswordHex(entry.MysqlNativePassword) 195 if err != nil { 196 return &mysql.StaticUserData{Username: entry.UserData, Groups: entry.Groups}, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", user) 197 } 198 isPass := mysql.VerifyHashedMysqlNativePassword(authResponse, salt, hash) 199 if mysql.MatchSourceHost(remoteAddr, entry.SourceHost) && isPass { 200 return &mysql.StaticUserData{Username: entry.UserData, Groups: entry.Groups}, nil 201 } 202 } else { 203 computedAuthResponse := mysql.ScrambleMysqlNativePassword(salt, []byte(entry.Password)) 204 // Validate the password. 205 if mysql.MatchSourceHost(remoteAddr, entry.SourceHost) && subtle.ConstantTimeCompare(authResponse, computedAuthResponse) == 1 { 206 return &mysql.StaticUserData{Username: entry.UserData, Groups: entry.Groups}, nil 207 } 208 } 209 } 210 return &mysql.StaticUserData{}, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", user) 211 } 212 213 func (a *AuthServerVault) setTTLTicker(ttl time.Duration) { 214 a.mu.Lock() 215 defer a.mu.Unlock() 216 if a.vaultCacheExpireTicker == nil { 217 a.vaultCacheExpireTicker = time.NewTicker(ttl) 218 go func() { 219 for range a.vaultCacheExpireTicker.C { 220 a.sigChan <- syscall.SIGHUP 221 } 222 }() 223 } else { 224 a.vaultCacheExpireTicker.Reset(ttl) 225 } 226 } 227 228 // Reload JSON auth key from Vault. Return true if successful, false if not 229 func (a *AuthServerVault) reloadVault() error { 230 a.mu.Lock() 231 secret, err := a.vaultClient.GetSecret(a.vaultPath) 232 a.mu.Unlock() 233 a.setTTLTicker(10 * time.Second) // Reload frequently on error 234 235 if err != nil { 236 return fmt.Errorf("Error in vtgate Vault auth server params: %v", err) 237 } 238 239 if secret.JSONSecret == nil { 240 return fmt.Errorf("Empty vtgate credentials retrieved from Vault server") 241 } 242 243 entries := make(map[string][]*mysql.AuthServerStaticEntry) 244 if err := mysql.ParseConfig(secret.JSONSecret, &entries); err != nil { 245 return fmt.Errorf("Error parsing vtgate Vault auth server config: %v", err) 246 } 247 if len(entries) == 0 { 248 return fmt.Errorf("vtgate credentials from Vault empty! Not updating previously cached values") 249 } 250 251 log.Infof("reloadVault(): success. Client status: %s", a.vaultClient.GetStatus()) 252 a.mu.Lock() 253 a.entries = entries 254 a.mu.Unlock() 255 a.setTTLTicker(a.vaultTTL) 256 return nil 257 } 258 259 func (a *AuthServerVault) installSignalHandlers() { 260 a.mu.Lock() 261 defer a.mu.Unlock() 262 263 a.sigChan = make(chan os.Signal, 1) 264 signal.Notify(a.sigChan, syscall.SIGHUP) 265 go func() { 266 for range a.sigChan { 267 err := a.reloadVault() 268 if err != nil { 269 log.Errorf("%s", err) 270 } 271 272 } 273 }() 274 } 275 276 func (a *AuthServerVault) close() { 277 log.Warningf("Closing AuthServerVault instance.") 278 a.mu.Lock() 279 defer a.mu.Unlock() 280 if a.vaultCacheExpireTicker != nil { 281 a.vaultCacheExpireTicker.Stop() 282 } 283 if a.sigChan != nil { 284 signal.Stop(a.sigChan) 285 } 286 } 287 288 // We ignore most errors here, to allow us to retry cleanly 289 // 290 // or ignore the cases where the input is not passed by file, but via env 291 func readFromFile(filePath string) (string, error) { 292 if filePath == "" { 293 return "", nil 294 } 295 fileBytes, err := os.ReadFile(filePath) 296 if err != nil { 297 log.Errorf("Could not read file: %s", filePath) 298 return "", err 299 } 300 return strings.TrimSpace(string(fileBytes)), nil 301 }