github.com/infraboard/keyauth@v0.8.1/apps/provider/auth/ldap/ldap.go (about) 1 package ldap 2 3 import ( 4 "crypto/tls" 5 "fmt" 6 "net/url" 7 "strings" 8 9 "github.com/go-ldap/ldap/v3" 10 "github.com/infraboard/mcube/exception" 11 "github.com/infraboard/mcube/logger" 12 "github.com/infraboard/mcube/logger/zap" 13 ) 14 15 // OWASP recommends to escape some special characters. 16 // https://github.com/OWASP/CheatSheetSeries/blob/master/cheatsheets/LDAP_Injection_Prevention_Cheat_Sheet.md 17 const specialLDAPRunes = ",#+<>;\"=" 18 19 // UserProvider LDAP provider 20 type UserProvider interface { 21 CheckConnect() error 22 CheckUserPassword(username string, password string) (bool, error) 23 GetDetails(username string) (*UserProfile, error) 24 UpdatePassword(username string, newPassword string) error 25 } 26 27 // NewProvider todo 28 func NewProvider(conf *Config) *Provider { 29 return &Provider{ 30 conf: conf, 31 log: zap.L().Named("LDAP"), 32 } 33 } 34 35 // Provider todo 36 type Provider struct { 37 conf *Config 38 log logger.Logger 39 } 40 41 func (p *Provider) dialTLS(network, addr string, config *tls.Config) (Connection, error) { 42 conn, err := ldap.DialTLS(network, addr, config) 43 if err != nil { 44 return nil, err 45 } 46 47 return NewLDAPConnectionImpl(conn), nil 48 } 49 50 func (p *Provider) dial(network, addr string) (Connection, error) { 51 conn, err := ldap.Dial(network, addr) 52 if err != nil { 53 return nil, err 54 } 55 56 return NewLDAPConnectionImpl(conn), nil 57 } 58 59 func (p *Provider) connect(userDN string, password string) (Connection, error) { 60 var conn Connection 61 62 url, err := url.Parse(p.conf.URL) 63 if err != nil { 64 return nil, fmt.Errorf("unable to parse URL to LDAP: %s", url) 65 } 66 67 if url.Scheme == "ldaps" { 68 p.log.Debug("LDAP client starts a TLS session") 69 tlsConn, err := p.dialTLS("tcp", url.Host, &tls.Config{ 70 InsecureSkipVerify: p.conf.SkipVerify, 71 }) 72 if err != nil { 73 return nil, err 74 } 75 76 conn = tlsConn 77 } else { 78 p.log.Debug("LDAP client starts a session over raw TCP") 79 rawConn, err := p.dial("tcp", url.Host) 80 if err != nil { 81 return nil, err 82 } 83 conn = rawConn 84 } 85 86 if err := conn.Bind(userDN, password); err != nil { 87 return nil, err 88 } 89 90 return conn, nil 91 } 92 93 // CheckConnect todo 94 func (p *Provider) CheckConnect() error { 95 adminClient, err := p.connect(p.conf.User, p.conf.Password) 96 if err != nil { 97 return err 98 } 99 defer adminClient.Close() 100 101 return nil 102 } 103 104 // CheckUserPassword checks if provided password matches for the given user. 105 func (p *Provider) CheckUserPassword(inputUsername string, password string) (bool, error) { 106 adminClient, err := p.connect(p.conf.User, p.conf.Password) 107 if err != nil { 108 return false, err 109 } 110 defer adminClient.Close() 111 112 profile, err := p.getUserProfile(adminClient, inputUsername) 113 if err != nil { 114 return false, err 115 } 116 117 conn, err := p.connect(profile.DN, password) 118 if err != nil { 119 return false, fmt.Errorf("authentication of user %s failed. Cause: %s", inputUsername, err) 120 } 121 defer conn.Close() 122 123 return true, nil 124 } 125 126 func (p *Provider) ldapEscape(inputUsername string) string { 127 inputUsername = ldap.EscapeFilter(inputUsername) 128 for _, c := range specialLDAPRunes { 129 inputUsername = strings.ReplaceAll(inputUsername, string(c), fmt.Sprintf("\\%c", c)) 130 } 131 132 return inputUsername 133 } 134 135 func (p *Provider) resolveUsersFilter(userFilter string, inputUsername string) string { 136 inputUsername = p.ldapEscape(inputUsername) 137 138 // We temporarily keep placeholder {0} for backward compatibility. 139 userFilter = strings.ReplaceAll(userFilter, "{0}", inputUsername) 140 141 // The {username} placeholder is equivalent to {0}, it's the new way, a named placeholder. 142 userFilter = strings.ReplaceAll(userFilter, "{input}", inputUsername) 143 144 // {username_attribute} and {mail_attribute} are replaced by the content of the attribute defined 145 // in configuration. 146 userFilter = strings.ReplaceAll(userFilter, "{username_attribute}", p.conf.UsernameAttribute) 147 userFilter = strings.ReplaceAll(userFilter, "{mail_attribute}", p.conf.MailAttribute) 148 return userFilter 149 } 150 151 func (p *Provider) getUserProfile(conn Connection, inputUsername string) (*UserProfile, error) { 152 userFilter := p.resolveUsersFilter(p.conf.UsersFilter, inputUsername) 153 p.log.Debugf("Computed user filter is %s", userFilter) 154 155 baseDN := p.conf.BaseDN 156 if p.conf.AdditionalUsersDN != "" { 157 baseDN = p.conf.AdditionalUsersDN + "," + baseDN 158 } 159 160 attributes := []string{"dn", 161 p.conf.MailAttribute, 162 p.conf.UsernameAttribute, 163 p.conf.DisplayNameAttribute, 164 } 165 166 // Search for the given username. 167 searchRequest := ldap.NewSearchRequest( 168 baseDN, ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 169 1, 0, false, userFilter, attributes, nil, 170 ) 171 172 sr, err := conn.Search(searchRequest) 173 if err != nil { 174 return nil, fmt.Errorf("cannot find user DN of user %s. Cause: %s", inputUsername, err) 175 } 176 177 if len(sr.Entries) == 0 { 178 return nil, exception.NewNotFound("user not found") 179 } 180 181 if len(sr.Entries) > 1 { 182 return nil, fmt.Errorf("multiple users %s found", inputUsername) 183 } 184 185 userProfile := UserProfile{ 186 DN: sr.Entries[0].DN, 187 } 188 189 for _, attr := range sr.Entries[0].Attributes { 190 if attr.Name == p.conf.MailAttribute { 191 userProfile.Emails = attr.Values 192 } 193 194 if attr.Name == p.conf.UsernameAttribute { 195 if len(attr.Values) != 1 { 196 return nil, fmt.Errorf("user %s cannot have multiple value for attribute %s", 197 inputUsername, p.conf.UsernameAttribute) 198 } 199 200 userProfile.Username = attr.Values[0] 201 } 202 if attr.Name == p.conf.DisplayNameAttribute { 203 userProfile.DisplayName = attr.Values[0] 204 } 205 } 206 207 if userProfile.DN == "" { 208 return nil, fmt.Errorf("no DN has been found for user %s", inputUsername) 209 } 210 211 return &userProfile, nil 212 } 213 214 func (p *Provider) resolveGroupsFilter(inputUsername string, profile *UserProfile) (string, error) { //nolint:unparam 215 inputUsername = p.ldapEscape(inputUsername) 216 217 // We temporarily keep placeholder {0} for backward compatibility. 218 groupFilter := strings.ReplaceAll(p.conf.GroupsFilter, "{0}", inputUsername) 219 groupFilter = strings.ReplaceAll(groupFilter, "{input}", inputUsername) 220 221 if profile != nil { 222 // We temporarily keep placeholder {1} for backward compatibility. 223 groupFilter = strings.ReplaceAll(groupFilter, "{1}", ldap.EscapeFilter(profile.Username)) 224 groupFilter = strings.ReplaceAll(groupFilter, "{username}", ldap.EscapeFilter(profile.Username)) 225 groupFilter = strings.ReplaceAll(groupFilter, "{dn}", ldap.EscapeFilter(profile.DN)) 226 } 227 228 return groupFilter, nil 229 } 230 231 // GetDetails retrieve the groups a user belongs to. 232 func (p *Provider) GetDetails(inputUsername string) (*UserProfile, error) { 233 conn, err := p.connect(p.conf.User, p.conf.Password) 234 if err != nil { 235 return nil, err 236 } 237 defer conn.Close() 238 239 profile, err := p.getUserProfile(conn, inputUsername) 240 if err != nil { 241 return nil, err 242 } 243 244 groupsFilter, err := p.resolveGroupsFilter(inputUsername, profile) 245 if err != nil { 246 return nil, fmt.Errorf("unable to create group filter for user %s. Cause: %s", inputUsername, err) 247 } 248 249 p.log.Debugf("Computed groups filter is %s", groupsFilter) 250 251 groupBaseDN := p.conf.BaseDN 252 if p.conf.AdditionalGroupsDN != "" { 253 groupBaseDN = p.conf.AdditionalGroupsDN + "," + groupBaseDN 254 } 255 256 // Search for the given username. 257 searchGroupRequest := ldap.NewSearchRequest( 258 groupBaseDN, ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 259 0, 0, false, groupsFilter, []string{p.conf.GroupNameAttribute}, nil, 260 ) 261 262 sr, err := conn.Search(searchGroupRequest) 263 264 if err != nil { 265 return nil, fmt.Errorf("unable to retrieve groups of user %s. Cause: %s", inputUsername, err) 266 } 267 268 for _, res := range sr.Entries { 269 if len(res.Attributes) == 0 { 270 p.log.Warnf("No groups retrieved from LDAP for user %s", inputUsername) 271 break 272 } 273 // Append all values of the document. Normally there should be only one per document. 274 profile.Groups = append(profile.Groups, res.Attributes[0].Values...) 275 } 276 277 return profile, nil 278 } 279 280 // UpdatePassword update the password of the given user. 281 func (p *Provider) UpdatePassword(inputUsername string, newPassword string) error { 282 client, err := p.connect(p.conf.User, p.conf.Password) 283 284 if err != nil { 285 return fmt.Errorf("unable to update password. Cause: %s", err) 286 } 287 288 profile, err := p.getUserProfile(client, inputUsername) 289 290 if err != nil { 291 return fmt.Errorf("unable to update password. Cause: %s", err) 292 } 293 294 modifyRequest := ldap.NewModifyRequest(profile.DN, nil) 295 296 modifyRequest.Replace("userPassword", []string{newPassword}) 297 298 err = client.Modify(modifyRequest) 299 300 if err != nil { 301 return fmt.Errorf("unable to update password. Cause: %s", err) 302 } 303 304 return nil 305 }