vitess.io/vitess@v0.16.2/go/mysql/ldapauthserver/auth_server_ldap.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package ldapauthserver 18 19 import ( 20 "encoding/json" 21 "fmt" 22 "net" 23 "os" 24 "sync" 25 "time" 26 27 "github.com/spf13/pflag" 28 ldap "gopkg.in/ldap.v2" 29 30 "vitess.io/vitess/go/mysql" 31 "vitess.io/vitess/go/netutil" 32 "vitess.io/vitess/go/vt/log" 33 "vitess.io/vitess/go/vt/servenv" 34 "vitess.io/vitess/go/vt/vttls" 35 36 querypb "vitess.io/vitess/go/vt/proto/query" 37 ) 38 39 var ( 40 ldapAuthConfigFile string 41 ldapAuthConfigString string 42 ldapAuthMethod string 43 ) 44 45 func init() { 46 servenv.OnParseFor("vtgate", func(fs *pflag.FlagSet) { 47 fs.StringVar(&ldapAuthConfigFile, "mysql_ldap_auth_config_file", "", "JSON File from which to read LDAP server config.") 48 fs.StringVar(&ldapAuthConfigString, "mysql_ldap_auth_config_string", "", "JSON representation of LDAP server config.") 49 fs.StringVar(&ldapAuthMethod, "mysql_ldap_auth_method", string(mysql.MysqlClearPassword), "client-side authentication method to use. Supported values: mysql_clear_password, dialog.") 50 }) 51 } 52 53 // AuthServerLdap implements AuthServer with an LDAP backend 54 type AuthServerLdap struct { 55 Client 56 ServerConfig 57 User string 58 Password string 59 GroupQuery string 60 UserDnPattern string 61 RefreshSeconds int64 62 methods []mysql.AuthMethod 63 } 64 65 // Init is public so it can be called from plugin_auth_ldap.go (go/cmd/vtgate) 66 func Init() { 67 if ldapAuthConfigFile == "" && ldapAuthConfigString == "" { 68 log.Infof("Not configuring AuthServerLdap because mysql_ldap_auth_config_file and mysql_ldap_auth_config_string are empty") 69 return 70 } 71 if ldapAuthConfigFile != "" && ldapAuthConfigString != "" { 72 log.Infof("Both mysql_ldap_auth_config_file and mysql_ldap_auth_config_string are non-empty, can only use one.") 73 return 74 } 75 76 if ldapAuthMethod != string(mysql.MysqlClearPassword) && ldapAuthMethod != string(mysql.MysqlDialog) { 77 log.Exitf("Invalid mysql_ldap_auth_method value: only support mysql_clear_password or dialog") 78 } 79 ldapAuthServer := &AuthServerLdap{ 80 Client: &ClientImpl{}, 81 ServerConfig: ServerConfig{}, 82 } 83 84 data := []byte(ldapAuthConfigString) 85 if ldapAuthConfigFile != "" { 86 var err error 87 data, err = os.ReadFile(ldapAuthConfigFile) 88 if err != nil { 89 log.Exitf("Failed to read mysql_ldap_auth_config_file: %v", err) 90 } 91 } 92 if err := json.Unmarshal(data, ldapAuthServer); err != nil { 93 log.Exitf("Error parsing AuthServerLdap config: %v", err) 94 } 95 96 var authMethod mysql.AuthMethod 97 switch mysql.AuthMethodDescription(ldapAuthMethod) { 98 case mysql.MysqlClearPassword: 99 authMethod = mysql.NewMysqlClearAuthMethod(ldapAuthServer, ldapAuthServer) 100 case mysql.MysqlDialog: 101 authMethod = mysql.NewMysqlDialogAuthMethod(ldapAuthServer, ldapAuthServer, "") 102 default: 103 log.Exitf("Invalid mysql_ldap_auth_method value: only support mysql_clear_password or dialog") 104 } 105 106 ldapAuthServer.methods = []mysql.AuthMethod{authMethod} 107 mysql.RegisterAuthServer("ldap", ldapAuthServer) 108 } 109 110 // AuthMethods returns the list of registered auth methods 111 // implemented by this auth server. 112 func (asl *AuthServerLdap) AuthMethods() []mysql.AuthMethod { 113 return asl.methods 114 } 115 116 // DefaultAuthMethodDescription returns MysqlNativePassword as the default 117 // authentication method for the auth server implementation. 118 func (asl *AuthServerLdap) DefaultAuthMethodDescription() mysql.AuthMethodDescription { 119 return mysql.MysqlNativePassword 120 } 121 122 // HandleUser is part of the Validator interface. We 123 // handle any user here since we don't check up front. 124 func (asl *AuthServerLdap) HandleUser(user string) bool { 125 return true 126 } 127 128 // UserEntryWithPassword is part of the PlaintextStorage interface 129 // and called after the password is sent by the client. 130 func (asl *AuthServerLdap) UserEntryWithPassword(conn *mysql.Conn, user string, password string, remoteAddr net.Addr) (mysql.Getter, error) { 131 return asl.validate(user, password) 132 } 133 134 func (asl *AuthServerLdap) validate(username, password string) (mysql.Getter, error) { 135 if err := asl.Client.Connect("tcp", &asl.ServerConfig); err != nil { 136 return nil, err 137 } 138 defer asl.Client.Close() 139 if err := asl.Client.Bind(fmt.Sprintf(asl.UserDnPattern, username), password); err != nil { 140 return nil, err 141 } 142 groups, err := asl.getGroups(username) 143 if err != nil { 144 return nil, err 145 } 146 return &LdapUserData{asl: asl, groups: groups, username: username, lastUpdated: time.Now(), updating: false}, nil 147 } 148 149 // this needs to be passed an already connected client...should check for this 150 func (asl *AuthServerLdap) getGroups(username string) ([]string, error) { 151 err := asl.Client.Bind(asl.User, asl.Password) 152 if err != nil { 153 return nil, err 154 } 155 req := ldap.NewSearchRequest( 156 asl.GroupQuery, 157 ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false, 158 fmt.Sprintf("(memberUid=%s)", username), 159 []string{"cn"}, 160 nil, 161 ) 162 res, err := asl.Client.Search(req) 163 if err != nil { 164 return nil, err 165 } 166 var groups []string 167 for _, entry := range res.Entries { 168 for _, attr := range entry.Attributes { 169 groups = append(groups, attr.Values[0]) 170 } 171 } 172 return groups, nil 173 } 174 175 // LdapUserData holds username and LDAP groups as well as enough data to 176 // intelligently update itself. 177 type LdapUserData struct { 178 asl *AuthServerLdap 179 groups []string 180 username string 181 lastUpdated time.Time 182 updating bool 183 sync.Mutex 184 } 185 186 func (lud *LdapUserData) update() { 187 lud.Lock() 188 if lud.updating { 189 lud.Unlock() 190 return 191 } 192 lud.updating = true 193 lud.Unlock() 194 err := lud.asl.Client.Connect("tcp", &lud.asl.ServerConfig) 195 if err != nil { 196 log.Errorf("Error updating LDAP user data: %v", err) 197 return 198 } 199 defer lud.asl.Client.Close() //after the error check 200 groups, err := lud.asl.getGroups(lud.username) 201 if err != nil { 202 log.Errorf("Error updating LDAP user data: %v", err) 203 return 204 } 205 lud.Lock() 206 lud.groups = groups 207 lud.lastUpdated = time.Now() 208 lud.updating = false 209 lud.Unlock() 210 } 211 212 // Get returns wrapped username and LDAP groups and possibly updates the cache 213 func (lud *LdapUserData) Get() *querypb.VTGateCallerID { 214 if int64(time.Since(lud.lastUpdated).Seconds()) > lud.asl.RefreshSeconds { 215 go lud.update() 216 } 217 return &querypb.VTGateCallerID{Username: lud.username, Groups: lud.groups} 218 } 219 220 // ServerConfig holds the config for and LDAP server 221 // * include port in ldapServer, "ldap.example.com:386" 222 type ServerConfig struct { 223 LdapServer string 224 LdapCert string 225 LdapKey string 226 LdapCA string 227 LdapCRL string 228 LdapTLSMinVersion string 229 } 230 231 // Client provides an interface we can mock 232 type Client interface { 233 Connect(network string, config *ServerConfig) error 234 Close() 235 Bind(string, string) error 236 Search(*ldap.SearchRequest) (*ldap.SearchResult, error) 237 } 238 239 // ClientImpl is the real implementation of LdapClient 240 type ClientImpl struct { 241 *ldap.Conn 242 } 243 244 // Connect calls ldap.Dial and then upgrades the connection to TLS 245 // This must be called before any other methods 246 func (lci *ClientImpl) Connect(network string, config *ServerConfig) error { 247 conn, err := ldap.Dial(network, config.LdapServer) 248 if err != nil { 249 return err 250 } 251 lci.Conn = conn 252 // Reconnect with TLS ... why don't we simply DialTLS directly? 253 serverName, _, err := netutil.SplitHostPort(config.LdapServer) 254 if err != nil { 255 return err 256 } 257 258 tlsVersion, err := vttls.TLSVersionToNumber(config.LdapTLSMinVersion) 259 if err != nil { 260 return err 261 } 262 263 tlsConfig, err := vttls.ClientConfig(vttls.VerifyIdentity, config.LdapCert, config.LdapKey, config.LdapCA, config.LdapCRL, serverName, tlsVersion) 264 if err != nil { 265 return err 266 } 267 err = conn.StartTLS(tlsConfig) 268 if err != nil { 269 return err 270 } 271 return nil 272 }