vitess.io/vitess@v0.16.2/go/mysql/vault/auth_server_vault.go (about)

     1  /*
     2  Copyright 2020 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package vault
    18  
    19  import (
    20  	"crypto/subtle"
    21  	"fmt"
    22  	"net"
    23  	"os"
    24  	"os/signal"
    25  	"strings"
    26  	"sync"
    27  	"syscall"
    28  	"time"
    29  
    30  	vaultapi "github.com/aquarapid/vaultlib"
    31  	"github.com/spf13/pflag"
    32  
    33  	"vitess.io/vitess/go/mysql"
    34  	"vitess.io/vitess/go/vt/log"
    35  	"vitess.io/vitess/go/vt/servenv"
    36  )
    37  
    38  var (
    39  	vaultAddr             string
    40  	vaultTimeout          time.Duration
    41  	vaultCACert           string
    42  	vaultPath             string
    43  	vaultCacheTTL         time.Duration
    44  	vaultTokenFile        string
    45  	vaultRoleID           string
    46  	vaultRoleSecretIDFile string
    47  	vaultRoleMountPoint   string
    48  )
    49  
    50  func init() {
    51  	servenv.OnParseFor("vtgate", func(fs *pflag.FlagSet) {
    52  		fs.StringVar(&vaultAddr, "mysql_auth_vault_addr", "", "URL to Vault server")
    53  		fs.DurationVar(&vaultTimeout, "mysql_auth_vault_timeout", 10*time.Second, "Timeout for vault API operations")
    54  		fs.StringVar(&vaultCACert, "mysql_auth_vault_tls_ca", "", "Path to CA PEM for validating Vault server certificate")
    55  		fs.StringVar(&vaultPath, "mysql_auth_vault_path", "", "Vault path to vtgate credentials JSON blob, e.g.: secret/data/prod/vtgatecreds")
    56  		fs.DurationVar(&vaultCacheTTL, "mysql_auth_vault_ttl", 30*time.Minute, "How long to cache vtgate credentials from the Vault server")
    57  		fs.StringVar(&vaultTokenFile, "mysql_auth_vault_tokenfile", "", "Path to file containing Vault auth token; token can also be passed using VAULT_TOKEN environment variable")
    58  		fs.StringVar(&vaultRoleID, "mysql_auth_vault_roleid", "", "Vault AppRole id; can also be passed using VAULT_ROLEID environment variable")
    59  		fs.StringVar(&vaultRoleSecretIDFile, "mysql_auth_vault_role_secretidfile", "", "Path to file containing Vault AppRole secret_id; can also be passed using VAULT_SECRETID environment variable")
    60  		fs.StringVar(&vaultRoleMountPoint, "mysql_auth_vault_role_mountpoint", "approle", "Vault AppRole mountpoint; can also be passed using VAULT_MOUNTPOINT environment variable")
    61  	})
    62  }
    63  
    64  // AuthServerVault implements AuthServer with a config loaded from Vault.
    65  type AuthServerVault struct {
    66  	methods []mysql.AuthMethod
    67  	mu      sync.Mutex
    68  	// users, passwords and user data
    69  	// We use the same JSON format as for --mysql_auth_server_static
    70  	// Acts as a cache for the in-Vault data
    71  	entries                map[string][]*mysql.AuthServerStaticEntry
    72  	vaultCacheExpireTicker *time.Ticker
    73  	vaultClient            *vaultapi.Client
    74  	vaultPath              string
    75  	vaultTTL               time.Duration
    76  
    77  	sigChan chan os.Signal
    78  }
    79  
    80  // InitAuthServerVault - entrypoint for initialization of Vault AuthServer implementation
    81  func InitAuthServerVault() {
    82  	// Check critical parameters.
    83  	if vaultAddr == "" {
    84  		log.Infof("Not configuring AuthServerVault, as --mysql_auth_vault_addr is empty.")
    85  		return
    86  	}
    87  	if vaultPath == "" {
    88  		log.Exitf("If using Vault auth server, --mysql_auth_vault_path is required.")
    89  	}
    90  
    91  	registerAuthServerVault(vaultAddr, vaultTimeout, vaultCACert, vaultPath, vaultCacheTTL, vaultTokenFile, vaultRoleID, vaultRoleSecretIDFile, vaultRoleMountPoint)
    92  }
    93  
    94  func registerAuthServerVault(addr string, timeout time.Duration, caCertPath string, path string, ttl time.Duration, tokenFilePath string, roleID string, secretIDPath string, roleMountPoint string) {
    95  	authServerVault, err := newAuthServerVault(addr, timeout, caCertPath, path, ttl, tokenFilePath, roleID, secretIDPath, roleMountPoint)
    96  	if err != nil {
    97  		log.Exitf("%s", err)
    98  	}
    99  	mysql.RegisterAuthServer("vault", authServerVault)
   100  }
   101  
   102  func newAuthServerVault(addr string, timeout time.Duration, caCertPath string, path string, ttl time.Duration, tokenFilePath string, roleID string, secretIDPath string, roleMountPoint string) (*AuthServerVault, error) {
   103  	// Validate more parameters
   104  	token, err := readFromFile(tokenFilePath)
   105  	if err != nil {
   106  		return nil, fmt.Errorf("No Vault token in provided filename for --mysql_auth_vault_tokenfile")
   107  	}
   108  	secretID, err := readFromFile(secretIDPath)
   109  	if err != nil {
   110  		return nil, fmt.Errorf("No Vault secret_id in provided filename for --mysql_auth_vault_role_secretidfile")
   111  	}
   112  
   113  	config := vaultapi.NewConfig()
   114  
   115  	// All these can be overriden by environment
   116  	//   so we need to check if they have been set by NewConfig
   117  	if config.Address == "" {
   118  		config.Address = addr
   119  	}
   120  	if config.Timeout == (0 * time.Second) {
   121  		config.Timeout = timeout
   122  	}
   123  	if config.CACert == "" {
   124  		config.CACert = caCertPath
   125  	}
   126  	if config.Token == "" {
   127  		config.Token = token
   128  	}
   129  	if config.AppRoleCredentials.RoleID == "" {
   130  		config.AppRoleCredentials.RoleID = roleID
   131  	}
   132  	if config.AppRoleCredentials.SecretID == "" {
   133  		config.AppRoleCredentials.SecretID = secretID
   134  	}
   135  	if config.AppRoleCredentials.MountPoint == "" {
   136  		config.AppRoleCredentials.MountPoint = roleMountPoint
   137  	}
   138  
   139  	if config.CACert != "" {
   140  		// If we provide a CA, ensure we actually use it
   141  		config.InsecureSSL = false
   142  	}
   143  
   144  	client, err := vaultapi.NewClient(config)
   145  	if err != nil || client == nil {
   146  		log.Errorf("Error in vault client initialization, will retry: %v", err)
   147  	}
   148  
   149  	a := &AuthServerVault{
   150  		vaultClient: client,
   151  		vaultPath:   path,
   152  		vaultTTL:    ttl,
   153  		entries:     make(map[string][]*mysql.AuthServerStaticEntry),
   154  	}
   155  
   156  	authMethodNative := mysql.NewMysqlNativeAuthMethod(a, a)
   157  	a.methods = []mysql.AuthMethod{authMethodNative}
   158  
   159  	a.reloadVault()
   160  	a.installSignalHandlers()
   161  	return a, nil
   162  }
   163  
   164  // AuthMethods returns the list of registered auth methods
   165  // implemented by this auth server.
   166  func (a *AuthServerVault) AuthMethods() []mysql.AuthMethod {
   167  	return a.methods
   168  }
   169  
   170  // DefaultAuthMethodDescription returns MysqlNativePassword as the default
   171  // authentication method for the auth server implementation.
   172  func (a *AuthServerVault) DefaultAuthMethodDescription() mysql.AuthMethodDescription {
   173  	return mysql.MysqlNativePassword
   174  }
   175  
   176  // HandleUser is part of the Validator interface. We
   177  // handle any user here since we don't check up front.
   178  func (a *AuthServerVault) HandleUser(user string) bool {
   179  	return true
   180  }
   181  
   182  // UserEntryWithHash is called when mysql_native_password is used.
   183  func (a *AuthServerVault) UserEntryWithHash(conn *mysql.Conn, salt []byte, user string, authResponse []byte, remoteAddr net.Addr) (mysql.Getter, error) {
   184  	a.mu.Lock()
   185  	userEntries, ok := a.entries[user]
   186  	a.mu.Unlock()
   187  
   188  	if !ok {
   189  		return &mysql.StaticUserData{}, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", user)
   190  	}
   191  
   192  	for _, entry := range userEntries {
   193  		if entry.MysqlNativePassword != "" {
   194  			hash, err := mysql.DecodeMysqlNativePasswordHex(entry.MysqlNativePassword)
   195  			if err != nil {
   196  				return &mysql.StaticUserData{Username: entry.UserData, Groups: entry.Groups}, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", user)
   197  			}
   198  			isPass := mysql.VerifyHashedMysqlNativePassword(authResponse, salt, hash)
   199  			if mysql.MatchSourceHost(remoteAddr, entry.SourceHost) && isPass {
   200  				return &mysql.StaticUserData{Username: entry.UserData, Groups: entry.Groups}, nil
   201  			}
   202  		} else {
   203  			computedAuthResponse := mysql.ScrambleMysqlNativePassword(salt, []byte(entry.Password))
   204  			// Validate the password.
   205  			if mysql.MatchSourceHost(remoteAddr, entry.SourceHost) && subtle.ConstantTimeCompare(authResponse, computedAuthResponse) == 1 {
   206  				return &mysql.StaticUserData{Username: entry.UserData, Groups: entry.Groups}, nil
   207  			}
   208  		}
   209  	}
   210  	return &mysql.StaticUserData{}, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", user)
   211  }
   212  
   213  func (a *AuthServerVault) setTTLTicker(ttl time.Duration) {
   214  	a.mu.Lock()
   215  	defer a.mu.Unlock()
   216  	if a.vaultCacheExpireTicker == nil {
   217  		a.vaultCacheExpireTicker = time.NewTicker(ttl)
   218  		go func() {
   219  			for range a.vaultCacheExpireTicker.C {
   220  				a.sigChan <- syscall.SIGHUP
   221  			}
   222  		}()
   223  	} else {
   224  		a.vaultCacheExpireTicker.Reset(ttl)
   225  	}
   226  }
   227  
   228  // Reload JSON auth key from Vault. Return true if successful, false if not
   229  func (a *AuthServerVault) reloadVault() error {
   230  	a.mu.Lock()
   231  	secret, err := a.vaultClient.GetSecret(a.vaultPath)
   232  	a.mu.Unlock()
   233  	a.setTTLTicker(10 * time.Second) // Reload frequently on error
   234  
   235  	if err != nil {
   236  		return fmt.Errorf("Error in vtgate Vault auth server params: %v", err)
   237  	}
   238  
   239  	if secret.JSONSecret == nil {
   240  		return fmt.Errorf("Empty vtgate credentials retrieved from Vault server")
   241  	}
   242  
   243  	entries := make(map[string][]*mysql.AuthServerStaticEntry)
   244  	if err := mysql.ParseConfig(secret.JSONSecret, &entries); err != nil {
   245  		return fmt.Errorf("Error parsing vtgate Vault auth server config: %v", err)
   246  	}
   247  	if len(entries) == 0 {
   248  		return fmt.Errorf("vtgate credentials from Vault empty! Not updating previously cached values")
   249  	}
   250  
   251  	log.Infof("reloadVault(): success. Client status: %s", a.vaultClient.GetStatus())
   252  	a.mu.Lock()
   253  	a.entries = entries
   254  	a.mu.Unlock()
   255  	a.setTTLTicker(a.vaultTTL)
   256  	return nil
   257  }
   258  
   259  func (a *AuthServerVault) installSignalHandlers() {
   260  	a.mu.Lock()
   261  	defer a.mu.Unlock()
   262  
   263  	a.sigChan = make(chan os.Signal, 1)
   264  	signal.Notify(a.sigChan, syscall.SIGHUP)
   265  	go func() {
   266  		for range a.sigChan {
   267  			err := a.reloadVault()
   268  			if err != nil {
   269  				log.Errorf("%s", err)
   270  			}
   271  
   272  		}
   273  	}()
   274  }
   275  
   276  func (a *AuthServerVault) close() {
   277  	log.Warningf("Closing AuthServerVault instance.")
   278  	a.mu.Lock()
   279  	defer a.mu.Unlock()
   280  	if a.vaultCacheExpireTicker != nil {
   281  		a.vaultCacheExpireTicker.Stop()
   282  	}
   283  	if a.sigChan != nil {
   284  		signal.Stop(a.sigChan)
   285  	}
   286  }
   287  
   288  // We ignore most errors here, to allow us to retry cleanly
   289  //
   290  //	or ignore the cases where the input is not passed by file, but via env
   291  func readFromFile(filePath string) (string, error) {
   292  	if filePath == "" {
   293  		return "", nil
   294  	}
   295  	fileBytes, err := os.ReadFile(filePath)
   296  	if err != nil {
   297  		log.Errorf("Could not read file: %s", filePath)
   298  		return "", err
   299  	}
   300  	return strings.TrimSpace(string(fileBytes)), nil
   301  }