vitess.io/vitess@v0.16.2/go/mysql/ldapauthserver/auth_server_ldap.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package ldapauthserver
    18  
    19  import (
    20  	"encoding/json"
    21  	"fmt"
    22  	"net"
    23  	"os"
    24  	"sync"
    25  	"time"
    26  
    27  	"github.com/spf13/pflag"
    28  	ldap "gopkg.in/ldap.v2"
    29  
    30  	"vitess.io/vitess/go/mysql"
    31  	"vitess.io/vitess/go/netutil"
    32  	"vitess.io/vitess/go/vt/log"
    33  	"vitess.io/vitess/go/vt/servenv"
    34  	"vitess.io/vitess/go/vt/vttls"
    35  
    36  	querypb "vitess.io/vitess/go/vt/proto/query"
    37  )
    38  
    39  var (
    40  	ldapAuthConfigFile   string
    41  	ldapAuthConfigString string
    42  	ldapAuthMethod       string
    43  )
    44  
    45  func init() {
    46  	servenv.OnParseFor("vtgate", func(fs *pflag.FlagSet) {
    47  		fs.StringVar(&ldapAuthConfigFile, "mysql_ldap_auth_config_file", "", "JSON File from which to read LDAP server config.")
    48  		fs.StringVar(&ldapAuthConfigString, "mysql_ldap_auth_config_string", "", "JSON representation of LDAP server config.")
    49  		fs.StringVar(&ldapAuthMethod, "mysql_ldap_auth_method", string(mysql.MysqlClearPassword), "client-side authentication method to use. Supported values: mysql_clear_password, dialog.")
    50  	})
    51  }
    52  
// AuthServerLdap implements AuthServer with an LDAP backend.
//
// Init fills it by JSON-decoding the configured mysql_ldap_auth_config_*
// value straight into this struct. The exported fields, including the
// promoted fields of the embedded ServerConfig, therefore form the config
// schema.
type AuthServerLdap struct {
	// Client performs the LDAP connect/bind/search calls. Init sets it to
	// &ClientImpl{}; it is an interface so tests can mock it.
	Client

	// ServerConfig holds the LDAP server address and TLS settings.
	ServerConfig

	// User and Password are the service-account credentials that getGroups
	// binds with before searching for a user's groups.
	User     string
	Password string

	// GroupQuery is the base DN of the group search done by getGroups.
	GroupQuery string

	// UserDnPattern is a fmt format string; validate substitutes the login
	// username into it to build the DN for the user's own bind.
	UserDnPattern string

	// RefreshSeconds is how old LdapUserData's cached groups may get before
	// Get schedules a background refresh.
	RefreshSeconds int64

	// methods is returned by AuthMethods; Init sets it.
	methods []mysql.AuthMethod
}
    64  
    65  // Init is public so it can be called from plugin_auth_ldap.go (go/cmd/vtgate)
    66  func Init() {
    67  	if ldapAuthConfigFile == "" && ldapAuthConfigString == "" {
    68  		log.Infof("Not configuring AuthServerLdap because mysql_ldap_auth_config_file and mysql_ldap_auth_config_string are empty")
    69  		return
    70  	}
    71  	if ldapAuthConfigFile != "" && ldapAuthConfigString != "" {
    72  		log.Infof("Both mysql_ldap_auth_config_file and mysql_ldap_auth_config_string are non-empty, can only use one.")
    73  		return
    74  	}
    75  
    76  	if ldapAuthMethod != string(mysql.MysqlClearPassword) && ldapAuthMethod != string(mysql.MysqlDialog) {
    77  		log.Exitf("Invalid mysql_ldap_auth_method value: only support mysql_clear_password or dialog")
    78  	}
    79  	ldapAuthServer := &AuthServerLdap{
    80  		Client:       &ClientImpl{},
    81  		ServerConfig: ServerConfig{},
    82  	}
    83  
    84  	data := []byte(ldapAuthConfigString)
    85  	if ldapAuthConfigFile != "" {
    86  		var err error
    87  		data, err = os.ReadFile(ldapAuthConfigFile)
    88  		if err != nil {
    89  			log.Exitf("Failed to read mysql_ldap_auth_config_file: %v", err)
    90  		}
    91  	}
    92  	if err := json.Unmarshal(data, ldapAuthServer); err != nil {
    93  		log.Exitf("Error parsing AuthServerLdap config: %v", err)
    94  	}
    95  
    96  	var authMethod mysql.AuthMethod
    97  	switch mysql.AuthMethodDescription(ldapAuthMethod) {
    98  	case mysql.MysqlClearPassword:
    99  		authMethod = mysql.NewMysqlClearAuthMethod(ldapAuthServer, ldapAuthServer)
   100  	case mysql.MysqlDialog:
   101  		authMethod = mysql.NewMysqlDialogAuthMethod(ldapAuthServer, ldapAuthServer, "")
   102  	default:
   103  		log.Exitf("Invalid mysql_ldap_auth_method value: only support mysql_clear_password or dialog")
   104  	}
   105  
   106  	ldapAuthServer.methods = []mysql.AuthMethod{authMethod}
   107  	mysql.RegisterAuthServer("ldap", ldapAuthServer)
   108  }
   109  
   110  // AuthMethods returns the list of registered auth methods
   111  // implemented by this auth server.
   112  func (asl *AuthServerLdap) AuthMethods() []mysql.AuthMethod {
   113  	return asl.methods
   114  }
   115  
   116  // DefaultAuthMethodDescription returns MysqlNativePassword as the default
   117  // authentication method for the auth server implementation.
   118  func (asl *AuthServerLdap) DefaultAuthMethodDescription() mysql.AuthMethodDescription {
   119  	return mysql.MysqlNativePassword
   120  }
   121  
   122  // HandleUser is part of the Validator interface. We
   123  // handle any user here since we don't check up front.
   124  func (asl *AuthServerLdap) HandleUser(user string) bool {
   125  	return true
   126  }
   127  
   128  // UserEntryWithPassword is part of the PlaintextStorage interface
   129  // and called after the password is sent by the client.
   130  func (asl *AuthServerLdap) UserEntryWithPassword(conn *mysql.Conn, user string, password string, remoteAddr net.Addr) (mysql.Getter, error) {
   131  	return asl.validate(user, password)
   132  }
   133  
   134  func (asl *AuthServerLdap) validate(username, password string) (mysql.Getter, error) {
   135  	if err := asl.Client.Connect("tcp", &asl.ServerConfig); err != nil {
   136  		return nil, err
   137  	}
   138  	defer asl.Client.Close()
   139  	if err := asl.Client.Bind(fmt.Sprintf(asl.UserDnPattern, username), password); err != nil {
   140  		return nil, err
   141  	}
   142  	groups, err := asl.getGroups(username)
   143  	if err != nil {
   144  		return nil, err
   145  	}
   146  	return &LdapUserData{asl: asl, groups: groups, username: username, lastUpdated: time.Now(), updating: false}, nil
   147  }
   148  
   149  // this needs to be passed an already connected client...should check for this
   150  func (asl *AuthServerLdap) getGroups(username string) ([]string, error) {
   151  	err := asl.Client.Bind(asl.User, asl.Password)
   152  	if err != nil {
   153  		return nil, err
   154  	}
   155  	req := ldap.NewSearchRequest(
   156  		asl.GroupQuery,
   157  		ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
   158  		fmt.Sprintf("(memberUid=%s)", username),
   159  		[]string{"cn"},
   160  		nil,
   161  	)
   162  	res, err := asl.Client.Search(req)
   163  	if err != nil {
   164  		return nil, err
   165  	}
   166  	var groups []string
   167  	for _, entry := range res.Entries {
   168  		for _, attr := range entry.Attributes {
   169  			groups = append(groups, attr.Values[0])
   170  		}
   171  	}
   172  	return groups, nil
   173  }
   174  
// LdapUserData holds username and LDAP groups as well as enough data to
// intelligently update itself.
//
// update holds the embedded Mutex while it reads or writes groups,
// lastUpdated and updating. update runs on a goroutine started by Get.
type LdapUserData struct {
	// asl supplies the LDAP client, server config and refresh interval.
	asl *AuthServerLdap

	// groups is the cached result of getGroups for username.
	groups []string

	// username is the login name whose groups are cached.
	username string

	// lastUpdated is when groups was last fetched successfully.
	lastUpdated time.Time

	// updating is set while a refresh is in flight, so only one runs at a time.
	updating bool

	sync.Mutex
}
   185  
   186  func (lud *LdapUserData) update() {
   187  	lud.Lock()
   188  	if lud.updating {
   189  		lud.Unlock()
   190  		return
   191  	}
   192  	lud.updating = true
   193  	lud.Unlock()
   194  	err := lud.asl.Client.Connect("tcp", &lud.asl.ServerConfig)
   195  	if err != nil {
   196  		log.Errorf("Error updating LDAP user data: %v", err)
   197  		return
   198  	}
   199  	defer lud.asl.Client.Close() //after the error check
   200  	groups, err := lud.asl.getGroups(lud.username)
   201  	if err != nil {
   202  		log.Errorf("Error updating LDAP user data: %v", err)
   203  		return
   204  	}
   205  	lud.Lock()
   206  	lud.groups = groups
   207  	lud.lastUpdated = time.Now()
   208  	lud.updating = false
   209  	lud.Unlock()
   210  }
   211  
   212  // Get returns wrapped username and LDAP groups and possibly updates the cache
   213  func (lud *LdapUserData) Get() *querypb.VTGateCallerID {
   214  	if int64(time.Since(lud.lastUpdated).Seconds()) > lud.asl.RefreshSeconds {
   215  		go lud.update()
   216  	}
   217  	return &querypb.VTGateCallerID{Username: lud.username, Groups: lud.groups}
   218  }
   219  
// ServerConfig holds the config for an LDAP server.
//
// LdapServer is the address to dial and must include the port, e.g.
// "ldap.example.com:389". Its host part is also used as the expected TLS
// server name. LdapTLSMinVersion is converted with vttls.TLSVersionToNumber.
// LdapCert, LdapKey, LdapCA and LdapCRL are passed to vttls.ClientConfig when
// ClientImpl.Connect upgrades the connection with StartTLS.
type ServerConfig struct {
	LdapServer        string
	LdapCert          string
	LdapKey           string
	LdapCA            string
	LdapCRL           string
	LdapTLSMinVersion string
}
   230  
   231  // Client provides an interface we can mock
   232  type Client interface {
   233  	Connect(network string, config *ServerConfig) error
   234  	Close()
   235  	Bind(string, string) error
   236  	Search(*ldap.SearchRequest) (*ldap.SearchResult, error)
   237  }
   238  
// ClientImpl is the real implementation of Client, backed by an *ldap.Conn.
//
// Only Connect is defined directly on ClientImpl. Close, Bind and Search are
// promoted from the embedded *ldap.Conn, which Connect sets, so Connect must
// succeed before those methods are used.
type ClientImpl struct {
	*ldap.Conn
}
   243  
   244  // Connect calls ldap.Dial and then upgrades the connection to TLS
   245  // This must be called before any other methods
   246  func (lci *ClientImpl) Connect(network string, config *ServerConfig) error {
   247  	conn, err := ldap.Dial(network, config.LdapServer)
   248  	if err != nil {
   249  		return err
   250  	}
   251  	lci.Conn = conn
   252  	// Reconnect with TLS ... why don't we simply DialTLS directly?
   253  	serverName, _, err := netutil.SplitHostPort(config.LdapServer)
   254  	if err != nil {
   255  		return err
   256  	}
   257  
   258  	tlsVersion, err := vttls.TLSVersionToNumber(config.LdapTLSMinVersion)
   259  	if err != nil {
   260  		return err
   261  	}
   262  
   263  	tlsConfig, err := vttls.ClientConfig(vttls.VerifyIdentity, config.LdapCert, config.LdapKey, config.LdapCA, config.LdapCRL, serverName, tlsVersion)
   264  	if err != nil {
   265  		return err
   266  	}
   267  	err = conn.StartTLS(tlsConfig)
   268  	if err != nil {
   269  		return err
   270  	}
   271  	return nil
   272  }