github.com/cs3org/reva/v2@v2.27.7/pkg/utils/ldap/reconnect.go (about)

     1  // Copyright 2022 CERN
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  //
    15  // In applying this license, CERN does not waive the privileges and immunities
    16  // granted to it by virtue of its status as an Intergovernmental Organization
    17  // or submit itself to any jurisdiction.
    18  
    19  package ldap
    20  
    21  // LDAP automatic reconnection mechanism, inspired by:
    22  // https://gist.github.com/emsearcy/cba3295d1a06d4c432ab4f6173b65e4f#file-ldap_snippet-go
    23  
    24  import (
    25  	"context"
    26  	"crypto/tls"
    27  	"errors"
    28  	"fmt"
    29  	"time"
    30  
    31  	"github.com/go-ldap/ldap/v3"
    32  	"github.com/rs/zerolog"
    33  )
    34  
    35  var (
    36  	defaultRetries = 1
    37  	errMaxRetries  = errors.New("max retries")
    38  )
    39  
// ldapConnection is the value handed out over ConnWithReconnect.conn: the
// connection currently held by the ldapAutoConnect goroutine together with the
// error of the connect attempt that produced it.
type ldapConnection struct {
	Conn  *ldap.Conn // nil until a connect attempt succeeded, and after a failed one
	Error error      // error returned by the last ldapConnect call, nil on success
}
    44  
// ConnWithReconnect maintains an LDAP Connection that automatically reconnects after network errors.
//
// The connection itself is owned by the ldapAutoConnect goroutine; the methods
// of this type talk to that goroutine through the two channels below.
type ConnWithReconnect struct {
	conn    chan ldapConnection // ldapAutoConnect offers the current connection/error here
	reset   chan *ldap.Conn     // callers send the connection they consider broken to request a reconnect
	retries int                 // extra attempts retry makes after a network error
	logger  *zerolog.Logger     // debug/error logging; a no-op logger unless SetLogger is called
}
    52  
// Config holds the basic configuration of the LDAP Connection
type Config struct {
	URI          string      // address passed to ldap.DialURL
	BindDN       string      // when non-empty, a bind with this DN is performed after connecting
	BindPassword string      // password used together with BindDN
	TLSConfig    *tls.Config // when non-nil, passed to the dialer via ldap.DialWithTLSConfig
}
    60  
    61  // NewLDAPWithReconnect Returns a new ConnWithReconnect initialized from config
    62  func NewLDAPWithReconnect(config Config) *ConnWithReconnect {
    63  	conn := ConnWithReconnect{
    64  		conn:    make(chan ldapConnection),
    65  		reset:   make(chan *ldap.Conn),
    66  		retries: defaultRetries,
    67  	}
    68  	logger := zerolog.Nop()
    69  	conn.logger = &logger
    70  	go conn.ldapAutoConnect(config)
    71  	return &conn
    72  }
    73  
    74  // SetLogger sets the logger for the current instance
    75  func (c *ConnWithReconnect) SetLogger(logger *zerolog.Logger) {
    76  	c.logger = logger
    77  }
    78  
    79  func (c *ConnWithReconnect) retry(fn func(c ldap.Client) error) error {
    80  	conn, err := c.getConnection()
    81  
    82  	if err != nil {
    83  		return err
    84  	}
    85  
    86  	for try := 0; try <= c.retries; try++ {
    87  		if try > 0 {
    88  			c.logger.Debug().Msgf("retrying attempt %d", try)
    89  			conn, err = c.reconnect(conn)
    90  			if err != nil {
    91  				// reconnection failed stop this attempt
    92  				return err
    93  			}
    94  		}
    95  		if err = fn(conn); err == nil {
    96  			// function succeed no need to retry
    97  			return nil
    98  		}
    99  		if !ldap.IsErrorWithCode(err, ldap.ErrorNetwork) {
   100  			// non network error, stop retrying
   101  			return err
   102  		}
   103  	}
   104  	return ldap.NewError(ldap.ErrorNetwork, errMaxRetries)
   105  }
   106  
   107  // Search implements the ldap.Client interface
   108  func (c *ConnWithReconnect) Search(sr *ldap.SearchRequest) (*ldap.SearchResult, error) {
   109  	var err error
   110  	var res *ldap.SearchResult
   111  
   112  	retryErr := c.retry(func(c ldap.Client) error {
   113  		res, err = c.Search(sr)
   114  		return err
   115  	})
   116  
   117  	return res, retryErr
   118  
   119  }
   120  
   121  // Add implements the ldap.Client interface
   122  func (c *ConnWithReconnect) Add(a *ldap.AddRequest) error {
   123  	err := c.retry(func(c ldap.Client) error {
   124  		return c.Add(a)
   125  	})
   126  
   127  	return err
   128  }
   129  
   130  // Del implements the ldap.Client interface
   131  func (c *ConnWithReconnect) Del(d *ldap.DelRequest) error {
   132  	err := c.retry(func(c ldap.Client) error {
   133  		return c.Del(d)
   134  	})
   135  
   136  	return err
   137  }
   138  
   139  // Modify implements the ldap.Client interface
   140  func (c *ConnWithReconnect) Modify(m *ldap.ModifyRequest) error {
   141  	err := c.retry(func(c ldap.Client) error {
   142  		return c.Modify(m)
   143  	})
   144  
   145  	return err
   146  }
   147  
   148  // ModifyDN implements the ldap.Client interface
   149  func (c *ConnWithReconnect) ModifyDN(m *ldap.ModifyDNRequest) error {
   150  	err := c.retry(func(c ldap.Client) error {
   151  		return c.ModifyDN(m)
   152  	})
   153  
   154  	return err
   155  }
   156  
   157  func (c *ConnWithReconnect) getConnection() (*ldap.Conn, error) {
   158  	conn := <-c.conn
   159  	if conn.Conn != nil && !ldap.IsErrorWithCode(conn.Error, ldap.ErrorNetwork) {
   160  		c.logger.Debug().Msg("using existing Connection")
   161  		return conn.Conn, conn.Error
   162  	}
   163  	return c.reconnect(conn.Conn)
   164  }
   165  
   166  func (c *ConnWithReconnect) ldapAutoConnect(config Config) {
   167  	var (
   168  		l   *ldap.Conn
   169  		err error
   170  	)
   171  
   172  	for {
   173  		select {
   174  		case resConn := <-c.reset:
   175  			// Only close the connection and reconnect if the current
   176  			// connection, matches the one we got via the reset channel.
   177  			// If they differ we already reconnected
   178  			switch {
   179  			case l == nil:
   180  				c.logger.Debug().Msg("reconnecting to LDAP")
   181  				l, err = c.ldapConnect(config)
   182  			case l != resConn:
   183  				c.logger.Debug().Msg("already reconnected")
   184  				continue
   185  			default:
   186  				c.logger.Debug().Msg("closing and reconnecting to LDAP")
   187  				l.Close()
   188  				l, err = c.ldapConnect(config)
   189  			}
   190  		case c.conn <- ldapConnection{l, err}:
   191  		}
   192  	}
   193  }
   194  
   195  func (c *ConnWithReconnect) ldapConnect(config Config) (*ldap.Conn, error) {
   196  	c.logger.Debug().Msgf("Connecting to %s", config.URI)
   197  
   198  	var err error
   199  	var l *ldap.Conn
   200  	if config.TLSConfig != nil {
   201  		l, err = ldap.DialURL(config.URI, ldap.DialWithTLSConfig(config.TLSConfig))
   202  	} else {
   203  		l, err = ldap.DialURL(config.URI)
   204  	}
   205  
   206  	if err != nil {
   207  		c.logger.Error().Err(err).Msg("could not get ldap Connection")
   208  		return nil, err
   209  	}
   210  	c.logger.Debug().Msg("LDAP Connected")
   211  	if config.BindDN != "" {
   212  		c.logger.Debug().Msgf("Binding as %s", config.BindDN)
   213  		err = l.Bind(config.BindDN, config.BindPassword)
   214  		if err != nil {
   215  			c.logger.Debug().Err(err).Msg("Bind failed")
   216  			l.Close()
   217  			return nil, err
   218  		}
   219  
   220  	}
   221  	return l, err
   222  
   223  }
   224  
   225  func (c *ConnWithReconnect) reconnect(resetConn *ldap.Conn) (*ldap.Conn, error) {
   226  	c.logger.Debug().Msg("LDAP connection reset")
   227  	c.reset <- resetConn
   228  	c.logger.Debug().Msg("Waiting for new connection")
   229  	result := <-c.conn
   230  	return result.Conn, result.Error
   231  }
   232  
   233  // Remaining methods to fulfill ldap.Client interface
   234  
// Start implements the ldap.Client interface. It is a no-op here.
func (c *ConnWithReconnect) Start() {}
   237  
   238  // StartTLS implements the ldap.Client interface
   239  func (c *ConnWithReconnect) StartTLS(*tls.Config) error {
   240  	return ldap.NewError(ldap.LDAPResultNotSupported, fmt.Errorf("not implemented"))
   241  }
   242  
   243  // Close implements the ldap.Client interface
   244  func (c *ConnWithReconnect) Close() (err error) {
   245  	conn, err := c.getConnection()
   246  
   247  	if err != nil {
   248  		return err
   249  	}
   250  	return conn.Close()
   251  }
   252  
   253  func (c *ConnWithReconnect) GetLastError() error {
   254  	conn, err := c.getConnection()
   255  
   256  	if err != nil {
   257  		return err
   258  	}
   259  	return conn.GetLastError()
   260  }
   261  
// IsClosing implements the ldap.Client interface. It always reports false; the
// state of the underlying connection is not consulted.
func (c *ConnWithReconnect) IsClosing() bool {
	return false
}
   266  
// SetTimeout implements the ldap.Client interface. It is a no-op here: the
// given duration is ignored.
func (c *ConnWithReconnect) SetTimeout(time.Duration) {}
   269  
   270  // Bind implements the ldap.Client interface
   271  func (c *ConnWithReconnect) Bind(username, password string) error {
   272  	return ldap.NewError(ldap.LDAPResultNotSupported, fmt.Errorf("not implemented"))
   273  }
   274  
   275  // UnauthenticatedBind implements the ldap.Client interface
   276  func (c *ConnWithReconnect) UnauthenticatedBind(username string) error {
   277  	return ldap.NewError(ldap.LDAPResultNotSupported, fmt.Errorf("not implemented"))
   278  }
   279  
   280  // SimpleBind implements the ldap.Client interface
   281  func (c *ConnWithReconnect) SimpleBind(*ldap.SimpleBindRequest) (*ldap.SimpleBindResult, error) {
   282  	return nil, ldap.NewError(ldap.LDAPResultNotSupported, fmt.Errorf("not implemented"))
   283  }
   284  
   285  // ExternalBind implements the ldap.Client interface
   286  func (c *ConnWithReconnect) ExternalBind() error {
   287  	return ldap.NewError(ldap.LDAPResultNotSupported, fmt.Errorf("not implemented"))
   288  }
   289  
   290  // ModifyWithResult implements the ldap.Client interface
   291  func (c *ConnWithReconnect) ModifyWithResult(m *ldap.ModifyRequest) (*ldap.ModifyResult, error) {
   292  	return nil, ldap.NewError(ldap.LDAPResultNotSupported, fmt.Errorf("not implemented"))
   293  }
   294  
   295  // Compare implements the ldap.Client interface
   296  func (c *ConnWithReconnect) Compare(dn, attribute, value string) (bool, error) {
   297  	return false, ldap.NewError(ldap.LDAPResultNotSupported, fmt.Errorf("not implemented"))
   298  }
   299  
   300  // PasswordModify implements the ldap.Client interface
   301  func (c *ConnWithReconnect) PasswordModify(*ldap.PasswordModifyRequest) (*ldap.PasswordModifyResult, error) {
   302  	return nil, ldap.NewError(ldap.LDAPResultNotSupported, fmt.Errorf("not implemented"))
   303  }
   304  
   305  // SearchWithPaging implements the ldap.Client interface
   306  func (c *ConnWithReconnect) SearchWithPaging(searchRequest *ldap.SearchRequest, pagingSize uint32) (*ldap.SearchResult, error) {
   307  	return nil, ldap.NewError(ldap.LDAPResultNotSupported, fmt.Errorf("not implemented"))
   308  }
   309  
// SearchAsync implements the ldap.Client interface.
//
// It is not implemented: no search is started and the returned ldap.Response
// is nil, so callers must check the result for nil before using it.
func (c *ConnWithReconnect) SearchAsync(ctx context.Context, searchRequest *ldap.SearchRequest, bufferSize int) ldap.Response {
	// unimplemented
	return nil
}
   315  
   316  // NTLMUnauthenticatedBind implements the ldap.Client interface
   317  func (c *ConnWithReconnect) NTLMUnauthenticatedBind(domain, username string) error {
   318  	return ldap.NewError(ldap.LDAPResultNotSupported, fmt.Errorf("not implemented"))
   319  }
   320  
// TLSConnectionState implements the ldap.Client interface. It always returns a
// zero tls.ConnectionState and false; the underlying connection is not
// consulted.
func (c *ConnWithReconnect) TLSConnectionState() (tls.ConnectionState, bool) {
	return tls.ConnectionState{}, false
}
   325  
   326  // Unbind implements the ldap.Client interface
   327  func (c *ConnWithReconnect) Unbind() error {
   328  	return ldap.NewError(ldap.LDAPResultNotSupported, fmt.Errorf("not implemented"))
   329  }
   330  
   331  // DirSync implements the ldap.Client interface
   332  func (c *ConnWithReconnect) DirSync(searchRequest *ldap.SearchRequest, flags, maxAttrCount int64, cookie []byte) (*ldap.SearchResult, error) {
   333  	return nil, ldap.NewError(ldap.LDAPResultNotSupported, fmt.Errorf("not implemented"))
   334  }
   335  
// DirSyncAsync implements the ldap.Client interface.
//
// It is not implemented: nothing is started and the returned ldap.Response is
// nil, so callers must check the result for nil before using it.
func (c *ConnWithReconnect) DirSyncAsync(ctx context.Context, searchRequest *ldap.SearchRequest, bufferSize int, flags, maxAttrCount int64, cookie []byte) ldap.Response {
	// unimplemented
	return nil
}
   341  
// Syncrepl implements the ldap.Client interface.
//
// It is not implemented: nothing is started and the returned ldap.Response is
// nil, so callers must check the result for nil before using it.
func (c *ConnWithReconnect) Syncrepl(ctx context.Context, searchRequest *ldap.SearchRequest, bufferSize int, mode ldap.ControlSyncRequestMode, cookie []byte, reloadHint bool) ldap.Response {
	// unimplemented
	return nil
}