github.com/infraboard/keyauth@v0.8.1/apps/provider/auth/ldap/ldap.go (about)

     1  package ldap
     2  
     3  import (
     4  	"crypto/tls"
     5  	"fmt"
     6  	"net/url"
     7  	"strings"
     8  
     9  	"github.com/go-ldap/ldap/v3"
    10  	"github.com/infraboard/mcube/exception"
    11  	"github.com/infraboard/mcube/logger"
    12  	"github.com/infraboard/mcube/logger/zap"
    13  )
    14  
// specialLDAPRunes lists the characters that ldapEscape escapes in user input
// on top of what ldap.EscapeFilter already handles.
// OWASP recommends escaping these special characters:
// https://github.com/OWASP/CheatSheetSeries/blob/master/cheatsheets/LDAP_Injection_Prevention_Cheat_Sheet.md
const specialLDAPRunes = ",#+<>;\"="
    18  
// UserProvider is the set of user operations offered by an LDAP-backed
// provider. Provider in this file implements it.
type UserProvider interface {
	// CheckConnect verifies that a connection to the LDAP server can be
	// established and bound.
	CheckConnect() error
	// CheckUserPassword reports whether password is valid for username.
	CheckUserPassword(username string, password string) (bool, error)
	// GetDetails returns the profile of username, including its groups.
	GetDetails(username string) (*UserProfile, error)
	// UpdatePassword sets the password of username to newPassword.
	UpdatePassword(username string, newPassword string) error
}
    26  
    27  // NewProvider todo
    28  func NewProvider(conf *Config) *Provider {
    29  	return &Provider{
    30  		conf: conf,
    31  		log:  zap.L().Named("LDAP"),
    32  	}
    33  }
    34  
// Provider performs user lookups, password checks and password updates
// against the LDAP server described by its configuration.
type Provider struct {
	// conf holds the LDAP settings (server URL, bind credentials, base DN,
	// filters and attribute names) read by the methods of Provider.
	conf *Config
	// log receives the debug and warning messages emitted by Provider.
	log logger.Logger
}
    40  
    41  func (p *Provider) dialTLS(network, addr string, config *tls.Config) (Connection, error) {
    42  	conn, err := ldap.DialTLS(network, addr, config)
    43  	if err != nil {
    44  		return nil, err
    45  	}
    46  
    47  	return NewLDAPConnectionImpl(conn), nil
    48  }
    49  
    50  func (p *Provider) dial(network, addr string) (Connection, error) {
    51  	conn, err := ldap.Dial(network, addr)
    52  	if err != nil {
    53  		return nil, err
    54  	}
    55  
    56  	return NewLDAPConnectionImpl(conn), nil
    57  }
    58  
    59  func (p *Provider) connect(userDN string, password string) (Connection, error) {
    60  	var conn Connection
    61  
    62  	url, err := url.Parse(p.conf.URL)
    63  	if err != nil {
    64  		return nil, fmt.Errorf("unable to parse URL to LDAP: %s", url)
    65  	}
    66  
    67  	if url.Scheme == "ldaps" {
    68  		p.log.Debug("LDAP client starts a TLS session")
    69  		tlsConn, err := p.dialTLS("tcp", url.Host, &tls.Config{
    70  			InsecureSkipVerify: p.conf.SkipVerify,
    71  		})
    72  		if err != nil {
    73  			return nil, err
    74  		}
    75  
    76  		conn = tlsConn
    77  	} else {
    78  		p.log.Debug("LDAP client starts a session over raw TCP")
    79  		rawConn, err := p.dial("tcp", url.Host)
    80  		if err != nil {
    81  			return nil, err
    82  		}
    83  		conn = rawConn
    84  	}
    85  
    86  	if err := conn.Bind(userDN, password); err != nil {
    87  		return nil, err
    88  	}
    89  
    90  	return conn, nil
    91  }
    92  
    93  // CheckConnect todo
    94  func (p *Provider) CheckConnect() error {
    95  	adminClient, err := p.connect(p.conf.User, p.conf.Password)
    96  	if err != nil {
    97  		return err
    98  	}
    99  	defer adminClient.Close()
   100  
   101  	return nil
   102  }
   103  
   104  // CheckUserPassword checks if provided password matches for the given user.
   105  func (p *Provider) CheckUserPassword(inputUsername string, password string) (bool, error) {
   106  	adminClient, err := p.connect(p.conf.User, p.conf.Password)
   107  	if err != nil {
   108  		return false, err
   109  	}
   110  	defer adminClient.Close()
   111  
   112  	profile, err := p.getUserProfile(adminClient, inputUsername)
   113  	if err != nil {
   114  		return false, err
   115  	}
   116  
   117  	conn, err := p.connect(profile.DN, password)
   118  	if err != nil {
   119  		return false, fmt.Errorf("authentication of user %s failed. Cause: %s", inputUsername, err)
   120  	}
   121  	defer conn.Close()
   122  
   123  	return true, nil
   124  }
   125  
   126  func (p *Provider) ldapEscape(inputUsername string) string {
   127  	inputUsername = ldap.EscapeFilter(inputUsername)
   128  	for _, c := range specialLDAPRunes {
   129  		inputUsername = strings.ReplaceAll(inputUsername, string(c), fmt.Sprintf("\\%c", c))
   130  	}
   131  
   132  	return inputUsername
   133  }
   134  
   135  func (p *Provider) resolveUsersFilter(userFilter string, inputUsername string) string {
   136  	inputUsername = p.ldapEscape(inputUsername)
   137  
   138  	// We temporarily keep placeholder {0} for backward compatibility.
   139  	userFilter = strings.ReplaceAll(userFilter, "{0}", inputUsername)
   140  
   141  	// The {username} placeholder is equivalent to {0}, it's the new way, a named placeholder.
   142  	userFilter = strings.ReplaceAll(userFilter, "{input}", inputUsername)
   143  
   144  	// {username_attribute} and {mail_attribute} are replaced by the content of the attribute defined
   145  	// in configuration.
   146  	userFilter = strings.ReplaceAll(userFilter, "{username_attribute}", p.conf.UsernameAttribute)
   147  	userFilter = strings.ReplaceAll(userFilter, "{mail_attribute}", p.conf.MailAttribute)
   148  	return userFilter
   149  }
   150  
   151  func (p *Provider) getUserProfile(conn Connection, inputUsername string) (*UserProfile, error) {
   152  	userFilter := p.resolveUsersFilter(p.conf.UsersFilter, inputUsername)
   153  	p.log.Debugf("Computed user filter is %s", userFilter)
   154  
   155  	baseDN := p.conf.BaseDN
   156  	if p.conf.AdditionalUsersDN != "" {
   157  		baseDN = p.conf.AdditionalUsersDN + "," + baseDN
   158  	}
   159  
   160  	attributes := []string{"dn",
   161  		p.conf.MailAttribute,
   162  		p.conf.UsernameAttribute,
   163  		p.conf.DisplayNameAttribute,
   164  	}
   165  
   166  	// Search for the given username.
   167  	searchRequest := ldap.NewSearchRequest(
   168  		baseDN, ldap.ScopeWholeSubtree, ldap.NeverDerefAliases,
   169  		1, 0, false, userFilter, attributes, nil,
   170  	)
   171  
   172  	sr, err := conn.Search(searchRequest)
   173  	if err != nil {
   174  		return nil, fmt.Errorf("cannot find user DN of user %s. Cause: %s", inputUsername, err)
   175  	}
   176  
   177  	if len(sr.Entries) == 0 {
   178  		return nil, exception.NewNotFound("user not found")
   179  	}
   180  
   181  	if len(sr.Entries) > 1 {
   182  		return nil, fmt.Errorf("multiple users %s found", inputUsername)
   183  	}
   184  
   185  	userProfile := UserProfile{
   186  		DN: sr.Entries[0].DN,
   187  	}
   188  
   189  	for _, attr := range sr.Entries[0].Attributes {
   190  		if attr.Name == p.conf.MailAttribute {
   191  			userProfile.Emails = attr.Values
   192  		}
   193  
   194  		if attr.Name == p.conf.UsernameAttribute {
   195  			if len(attr.Values) != 1 {
   196  				return nil, fmt.Errorf("user %s cannot have multiple value for attribute %s",
   197  					inputUsername, p.conf.UsernameAttribute)
   198  			}
   199  
   200  			userProfile.Username = attr.Values[0]
   201  		}
   202  		if attr.Name == p.conf.DisplayNameAttribute {
   203  			userProfile.DisplayName = attr.Values[0]
   204  		}
   205  	}
   206  
   207  	if userProfile.DN == "" {
   208  		return nil, fmt.Errorf("no DN has been found for user %s", inputUsername)
   209  	}
   210  
   211  	return &userProfile, nil
   212  }
   213  
   214  func (p *Provider) resolveGroupsFilter(inputUsername string, profile *UserProfile) (string, error) { //nolint:unparam
   215  	inputUsername = p.ldapEscape(inputUsername)
   216  
   217  	// We temporarily keep placeholder {0} for backward compatibility.
   218  	groupFilter := strings.ReplaceAll(p.conf.GroupsFilter, "{0}", inputUsername)
   219  	groupFilter = strings.ReplaceAll(groupFilter, "{input}", inputUsername)
   220  
   221  	if profile != nil {
   222  		// We temporarily keep placeholder {1} for backward compatibility.
   223  		groupFilter = strings.ReplaceAll(groupFilter, "{1}", ldap.EscapeFilter(profile.Username))
   224  		groupFilter = strings.ReplaceAll(groupFilter, "{username}", ldap.EscapeFilter(profile.Username))
   225  		groupFilter = strings.ReplaceAll(groupFilter, "{dn}", ldap.EscapeFilter(profile.DN))
   226  	}
   227  
   228  	return groupFilter, nil
   229  }
   230  
   231  // GetDetails retrieve the groups a user belongs to.
   232  func (p *Provider) GetDetails(inputUsername string) (*UserProfile, error) {
   233  	conn, err := p.connect(p.conf.User, p.conf.Password)
   234  	if err != nil {
   235  		return nil, err
   236  	}
   237  	defer conn.Close()
   238  
   239  	profile, err := p.getUserProfile(conn, inputUsername)
   240  	if err != nil {
   241  		return nil, err
   242  	}
   243  
   244  	groupsFilter, err := p.resolveGroupsFilter(inputUsername, profile)
   245  	if err != nil {
   246  		return nil, fmt.Errorf("unable to create group filter for user %s. Cause: %s", inputUsername, err)
   247  	}
   248  
   249  	p.log.Debugf("Computed groups filter is %s", groupsFilter)
   250  
   251  	groupBaseDN := p.conf.BaseDN
   252  	if p.conf.AdditionalGroupsDN != "" {
   253  		groupBaseDN = p.conf.AdditionalGroupsDN + "," + groupBaseDN
   254  	}
   255  
   256  	// Search for the given username.
   257  	searchGroupRequest := ldap.NewSearchRequest(
   258  		groupBaseDN, ldap.ScopeWholeSubtree, ldap.NeverDerefAliases,
   259  		0, 0, false, groupsFilter, []string{p.conf.GroupNameAttribute}, nil,
   260  	)
   261  
   262  	sr, err := conn.Search(searchGroupRequest)
   263  
   264  	if err != nil {
   265  		return nil, fmt.Errorf("unable to retrieve groups of user %s. Cause: %s", inputUsername, err)
   266  	}
   267  
   268  	for _, res := range sr.Entries {
   269  		if len(res.Attributes) == 0 {
   270  			p.log.Warnf("No groups retrieved from LDAP for user %s", inputUsername)
   271  			break
   272  		}
   273  		// Append all values of the document. Normally there should be only one per document.
   274  		profile.Groups = append(profile.Groups, res.Attributes[0].Values...)
   275  	}
   276  
   277  	return profile, nil
   278  }
   279  
   280  // UpdatePassword update the password of the given user.
   281  func (p *Provider) UpdatePassword(inputUsername string, newPassword string) error {
   282  	client, err := p.connect(p.conf.User, p.conf.Password)
   283  
   284  	if err != nil {
   285  		return fmt.Errorf("unable to update password. Cause: %s", err)
   286  	}
   287  
   288  	profile, err := p.getUserProfile(client, inputUsername)
   289  
   290  	if err != nil {
   291  		return fmt.Errorf("unable to update password. Cause: %s", err)
   292  	}
   293  
   294  	modifyRequest := ldap.NewModifyRequest(profile.DN, nil)
   295  
   296  	modifyRequest.Replace("userPassword", []string{newPassword})
   297  
   298  	err = client.Modify(modifyRequest)
   299  
   300  	if err != nil {
   301  		return fmt.Errorf("unable to update password. Cause: %s", err)
   302  	}
   303  
   304  	return nil
   305  }