vitess.io/vitess@v0.16.2/go/mysql/auth_server_static.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package mysql
    18  
    19  import (
    20  	"bytes"
    21  	"crypto/subtle"
    22  	"encoding/json"
    23  	"net"
    24  	"os"
    25  	"os/signal"
    26  	"sync"
    27  	"syscall"
    28  	"time"
    29  
    30  	"github.com/spf13/pflag"
    31  
    32  	"vitess.io/vitess/go/vt/log"
    33  	"vitess.io/vitess/go/vt/servenv"
    34  	"vitess.io/vitess/go/vt/vterrors"
    35  
    36  	querypb "vitess.io/vitess/go/vt/proto/query"
    37  	"vitess.io/vitess/go/vt/proto/vtrpc"
    38  )
    39  
    40  var (
    41  	mysqlAuthServerStaticFile           string
    42  	mysqlAuthServerStaticString         string
    43  	mysqlAuthServerStaticReloadInterval time.Duration
    44  	mysqlServerFlushDelay               = 100 * time.Millisecond
    45  )
    46  
    47  func init() {
    48  	servenv.OnParseFor("vtgate", func(fs *pflag.FlagSet) {
    49  		fs.StringVar(&mysqlAuthServerStaticFile, "mysql_auth_server_static_file", "", "JSON File to read the users/passwords from.")
    50  		fs.StringVar(&mysqlAuthServerStaticString, "mysql_auth_server_static_string", "", "JSON representation of the users/passwords config.")
    51  		fs.DurationVar(&mysqlAuthServerStaticReloadInterval, "mysql_auth_static_reload_interval", 0, "Ticker to reload credentials")
    52  		fs.DurationVar(&mysqlServerFlushDelay, "mysql_server_flush_delay", mysqlServerFlushDelay, "Delay after which buffered response will be flushed to the client.")
    53  	})
    54  }
    55  
    56  const (
    57  	localhostName = "localhost"
    58  )
    59  
    60  // AuthServerStatic implements AuthServer using a static configuration.
    61  type AuthServerStatic struct {
    62  	methods          []AuthMethod
    63  	file, jsonConfig string
    64  	reloadInterval   time.Duration
    65  	// This mutex helps us prevent data races between the multiple updates of entries.
    66  	mu sync.Mutex
    67  	// entries contains the users, passwords and user data.
    68  	entries map[string][]*AuthServerStaticEntry
    69  
    70  	sigChan chan os.Signal
    71  	ticker  *time.Ticker
    72  }
    73  
    74  // AuthServerStaticEntry stores the values for a given user.
    75  type AuthServerStaticEntry struct {
    76  	// MysqlNativePassword is generated by password hashing methods in MySQL.
    77  	// These changes are illustrated by changes in the result from the PASSWORD() function
    78  	// that computes password hash values and in the structure of the user table where passwords are stored.
    79  	// mysql> SELECT PASSWORD('mypass');
    80  	// +-------------------------------------------+
    81  	// | PASSWORD('mypass')                        |
    82  	// +-------------------------------------------+
    83  	// | *6C8989366EAF75BB670AD8EA7A7FC1176A95CEF4 |
    84  	// +-------------------------------------------+
    85  	// MysqlNativePassword's format looks like "*6C8989366EAF75BB670AD8EA7A7FC1176A95CEF4", it store a hashing value.
    86  	// Use MysqlNativePassword in auth config, maybe more secure. After all, it is cryptographic storage.
    87  	MysqlNativePassword string
    88  	Password            string
    89  	UserData            string
    90  	SourceHost          string
    91  	Groups              []string
    92  }
    93  
    94  // InitAuthServerStatic Handles initializing the AuthServerStatic if necessary.
    95  func InitAuthServerStatic() {
    96  	// Check parameters.
    97  	if mysqlAuthServerStaticFile == "" && mysqlAuthServerStaticString == "" {
    98  		// Not configured, nothing to do.
    99  		log.Infof("Not configuring AuthServerStatic, as mysql_auth_server_static_file and mysql_auth_server_static_string are empty")
   100  		return
   101  	}
   102  	if mysqlAuthServerStaticFile != "" && mysqlAuthServerStaticString != "" {
   103  		// Both parameters specified, can only use one.
   104  		log.Exitf("Both mysql_auth_server_static_file and mysql_auth_server_static_string specified, can only use one.")
   105  	}
   106  
   107  	// Create and register auth server.
   108  	RegisterAuthServerStaticFromParams(mysqlAuthServerStaticFile, mysqlAuthServerStaticString, mysqlAuthServerStaticReloadInterval)
   109  }
   110  
   111  // RegisterAuthServerStaticFromParams creates and registers a new
   112  // AuthServerStatic, loaded for a JSON file or string. If file is set,
   113  // it uses file. Otherwise, load the string. It log.Exits out in case
   114  // of error.
   115  func RegisterAuthServerStaticFromParams(file, jsonConfig string, reloadInterval time.Duration) {
   116  	authServerStatic := NewAuthServerStatic(file, jsonConfig, reloadInterval)
   117  	if len(authServerStatic.entries) <= 0 {
   118  		log.Exitf("Failed to populate entries from file: %v", file)
   119  	}
   120  	RegisterAuthServer("static", authServerStatic)
   121  }
   122  
   123  // NewAuthServerStatic returns a new empty AuthServerStatic.
   124  func NewAuthServerStatic(file, jsonConfig string, reloadInterval time.Duration) *AuthServerStatic {
   125  	a := &AuthServerStatic{
   126  		file:           file,
   127  		jsonConfig:     jsonConfig,
   128  		reloadInterval: reloadInterval,
   129  		entries:        make(map[string][]*AuthServerStaticEntry),
   130  	}
   131  
   132  	a.methods = []AuthMethod{NewMysqlNativeAuthMethod(a, a)}
   133  
   134  	a.reload()
   135  	a.installSignalHandlers()
   136  	return a
   137  }
   138  
   139  // NewAuthServerStaticWithAuthMethodDescription returns a new empty AuthServerStatic
   140  // but with support for a different auth method. Mostly used for testing purposes.
   141  func NewAuthServerStaticWithAuthMethodDescription(file, jsonConfig string, reloadInterval time.Duration, authMethodDescription AuthMethodDescription) *AuthServerStatic {
   142  	a := &AuthServerStatic{
   143  		file:           file,
   144  		jsonConfig:     jsonConfig,
   145  		reloadInterval: reloadInterval,
   146  		entries:        make(map[string][]*AuthServerStaticEntry),
   147  	}
   148  
   149  	var authMethod AuthMethod
   150  	switch authMethodDescription {
   151  	case CachingSha2Password:
   152  		authMethod = NewSha2CachingAuthMethod(a, a, a)
   153  	case MysqlNativePassword:
   154  		authMethod = NewMysqlNativeAuthMethod(a, a)
   155  	case MysqlClearPassword:
   156  		authMethod = NewMysqlClearAuthMethod(a, a)
   157  	case MysqlDialog:
   158  		authMethod = NewMysqlDialogAuthMethod(a, a, "")
   159  	}
   160  
   161  	a.methods = []AuthMethod{authMethod}
   162  
   163  	a.reload()
   164  	a.installSignalHandlers()
   165  	return a
   166  }
   167  
   168  // HandleUser is part of the Validator interface. We
   169  // handle any user here since we don't check up front.
   170  func (a *AuthServerStatic) HandleUser(user string) bool {
   171  	return true
   172  }
   173  
   174  // UserEntryWithPassword implements password lookup based on a plain
   175  // text password that is negotiated with the client.
   176  func (a *AuthServerStatic) UserEntryWithPassword(conn *Conn, user string, password string, remoteAddr net.Addr) (Getter, error) {
   177  	a.mu.Lock()
   178  	entries, ok := a.entries[user]
   179  	a.mu.Unlock()
   180  
   181  	if !ok {
   182  		return &StaticUserData{}, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user)
   183  	}
   184  
   185  	for _, entry := range entries {
   186  		// Validate the password.
   187  		if MatchSourceHost(remoteAddr, entry.SourceHost) && subtle.ConstantTimeCompare([]byte(password), []byte(entry.Password)) == 1 {
   188  			return &StaticUserData{entry.UserData, entry.Groups}, nil
   189  		}
   190  	}
   191  	return &StaticUserData{}, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user)
   192  }
   193  
   194  // UserEntryWithHash implements password lookup based on a
   195  // mysql_native_password hash that is negotiated with the client.
   196  func (a *AuthServerStatic) UserEntryWithHash(conn *Conn, salt []byte, user string, authResponse []byte, remoteAddr net.Addr) (Getter, error) {
   197  	a.mu.Lock()
   198  	entries, ok := a.entries[user]
   199  	a.mu.Unlock()
   200  
   201  	if !ok {
   202  		return &StaticUserData{}, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user)
   203  	}
   204  
   205  	for _, entry := range entries {
   206  		if entry.MysqlNativePassword != "" {
   207  			hash, err := DecodeMysqlNativePasswordHex(entry.MysqlNativePassword)
   208  			if err != nil {
   209  				return &StaticUserData{entry.UserData, entry.Groups}, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user)
   210  			}
   211  
   212  			isPass := VerifyHashedMysqlNativePassword(authResponse, salt, hash)
   213  			if MatchSourceHost(remoteAddr, entry.SourceHost) && isPass {
   214  				return &StaticUserData{entry.UserData, entry.Groups}, nil
   215  			}
   216  		} else {
   217  			computedAuthResponse := ScrambleMysqlNativePassword(salt, []byte(entry.Password))
   218  			// Validate the password.
   219  			if MatchSourceHost(remoteAddr, entry.SourceHost) && subtle.ConstantTimeCompare(authResponse, computedAuthResponse) == 1 {
   220  				return &StaticUserData{entry.UserData, entry.Groups}, nil
   221  			}
   222  		}
   223  	}
   224  	return &StaticUserData{}, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user)
   225  }
   226  
   227  // UserEntryWithCacheHash implements password lookup based on a
   228  // caching_sha2_password hash that is negotiated with the client.
   229  func (a *AuthServerStatic) UserEntryWithCacheHash(conn *Conn, salt []byte, user string, authResponse []byte, remoteAddr net.Addr) (Getter, CacheState, error) {
   230  	a.mu.Lock()
   231  	entries, ok := a.entries[user]
   232  	a.mu.Unlock()
   233  
   234  	if !ok {
   235  		return &StaticUserData{}, AuthRejected, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user)
   236  	}
   237  
   238  	for _, entry := range entries {
   239  		computedAuthResponse := ScrambleCachingSha2Password(salt, []byte(entry.Password))
   240  
   241  		// Validate the password.
   242  		if MatchSourceHost(remoteAddr, entry.SourceHost) && subtle.ConstantTimeCompare(authResponse, computedAuthResponse) == 1 {
   243  			return &StaticUserData{entry.UserData, entry.Groups}, AuthAccepted, nil
   244  		}
   245  	}
   246  	return &StaticUserData{}, AuthRejected, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user)
   247  }
   248  
   249  // AuthMethods returns the AuthMethod instances this auth server can handle.
   250  func (a *AuthServerStatic) AuthMethods() []AuthMethod {
   251  	return a.methods
   252  }
   253  
   254  // DefaultAuthMethodDescription returns the default auth method in the handshake which
   255  // is MysqlNativePassword for this auth server.
   256  func (a *AuthServerStatic) DefaultAuthMethodDescription() AuthMethodDescription {
   257  	return MysqlNativePassword
   258  }
   259  
   260  func (a *AuthServerStatic) reload() {
   261  	jsonBytes := []byte(a.jsonConfig)
   262  	if a.file != "" {
   263  		data, err := os.ReadFile(a.file)
   264  		if err != nil {
   265  			log.Errorf("Failed to read mysql_auth_server_static_file file: %v", err)
   266  			return
   267  		}
   268  		jsonBytes = data
   269  	}
   270  
   271  	entries := make(map[string][]*AuthServerStaticEntry)
   272  	if err := ParseConfig(jsonBytes, &entries); err != nil {
   273  		log.Errorf("Error parsing auth server config: %v", err)
   274  		return
   275  	}
   276  
   277  	a.mu.Lock()
   278  	a.entries = entries
   279  	a.mu.Unlock()
   280  }
   281  
   282  func (a *AuthServerStatic) installSignalHandlers() {
   283  	if a.file == "" {
   284  		return
   285  	}
   286  
   287  	a.sigChan = make(chan os.Signal, 1)
   288  	signal.Notify(a.sigChan, syscall.SIGHUP)
   289  	go func() {
   290  		for range a.sigChan {
   291  			a.reload()
   292  		}
   293  	}()
   294  
   295  	// If duration is set, it will reload configuration every interval
   296  	if a.reloadInterval > 0 {
   297  		a.ticker = time.NewTicker(a.reloadInterval)
   298  		go func() {
   299  			for range a.ticker.C {
   300  				a.sigChan <- syscall.SIGHUP
   301  			}
   302  		}()
   303  	}
   304  }
   305  
   306  func (a *AuthServerStatic) close() {
   307  	if a.ticker != nil {
   308  		a.ticker.Stop()
   309  	}
   310  	if a.sigChan != nil {
   311  		signal.Stop(a.sigChan)
   312  	}
   313  }
   314  
   315  // ParseConfig takes a JSON MySQL static config and converts to a validated map
   316  func ParseConfig(jsonBytes []byte, config *map[string][]*AuthServerStaticEntry) error {
   317  	decoder := json.NewDecoder(bytes.NewReader(jsonBytes))
   318  	decoder.DisallowUnknownFields()
   319  	if err := decoder.Decode(config); err != nil {
   320  		// Couldn't parse, will try to parse with legacy config
   321  		return parseLegacyConfig(jsonBytes, config)
   322  	}
   323  	return validateConfig(*config)
   324  }
   325  
   326  func parseLegacyConfig(jsonBytes []byte, config *map[string][]*AuthServerStaticEntry) error {
   327  	// legacy config doesn't have an array
   328  	legacyConfig := make(map[string]*AuthServerStaticEntry)
   329  	decoder := json.NewDecoder(bytes.NewReader(jsonBytes))
   330  	decoder.DisallowUnknownFields()
   331  	if err := decoder.Decode(&legacyConfig); err != nil {
   332  		return err
   333  	}
   334  	log.Warningf("Config parsed using legacy configuration. Please update to the latest format: {\"user\":[{\"Password\": \"xxx\"}, ...]}")
   335  	for key, value := range legacyConfig {
   336  		(*config)[key] = append((*config)[key], value)
   337  	}
   338  	return nil
   339  }
   340  
   341  func validateConfig(config map[string][]*AuthServerStaticEntry) error {
   342  	for _, entries := range config {
   343  		for _, entry := range entries {
   344  			if entry.SourceHost != "" && entry.SourceHost != localhostName {
   345  				return vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "invalid SourceHost found (only localhost is supported): %v", entry.SourceHost)
   346  			}
   347  		}
   348  	}
   349  	return nil
   350  }
   351  
   352  // MatchSourceHost validates host entry in auth configuration
   353  func MatchSourceHost(remoteAddr net.Addr, targetSourceHost string) bool {
   354  	// Legacy support, there was not matcher defined default to true
   355  	if targetSourceHost == "" {
   356  		return true
   357  	}
   358  	switch remoteAddr.(type) {
   359  	case *net.UnixAddr:
   360  		if targetSourceHost == localhostName {
   361  			return true
   362  		}
   363  	}
   364  	return false
   365  }
   366  
   367  // StaticUserData holds the username and groups
   368  type StaticUserData struct {
   369  	Username string
   370  	Groups   []string
   371  }
   372  
   373  // Get returns the wrapped username and groups
   374  func (sud *StaticUserData) Get() *querypb.VTGateCallerID {
   375  	return &querypb.VTGateCallerID{Username: sud.Username, Groups: sud.Groups}
   376  }