github.com/hellobchain/third_party@v0.0.0-20230331131523-deb0478a2e52/ldap.v2/conn.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ldap
     6  
     7  import (
     8  	"errors"
     9  	"fmt"
    10  	"github.com/hellobchain/newcryptosm/tls"
    11  	ber "gopkg.in/asn1-ber.v1"
    12  	"log"
    13  	"net"
    14  	"sync"
    15  	"time"
    16  )
    17  
    18  const (
    19  	// MessageQuit causes the processMessages loop to exit
    20  	MessageQuit = 0
    21  	// MessageRequest sends a request to the server
    22  	MessageRequest = 1
    23  	// MessageResponse receives a response from the server
    24  	MessageResponse = 2
    25  	// MessageFinish indicates the client considers a particular message ID to be finished
    26  	MessageFinish = 3
    27  	// MessageTimeout indicates the client-specified timeout for a particular message ID has been reached
    28  	MessageTimeout = 4
    29  )
    30  
// PacketResponse contains the packet or error encountered reading a response.
// Use ReadPacket to unwrap it; ReadPacket also handles the case where neither
// field is set.
type PacketResponse struct {
	// Packet is the packet read from the server
	Packet *ber.Packet
	// Error is an error encountered while reading
	Error error
}
    38  
    39  // ReadPacket returns the packet or an error
    40  func (pr *PacketResponse) ReadPacket() (*ber.Packet, error) {
    41  	if (pr == nil) || (pr.Packet == nil && pr.Error == nil) {
    42  		return nil, NewError(ErrorNetwork, errors.New("ldap: could not retrieve response"))
    43  	}
    44  	return pr.Packet, pr.Error
    45  }
    46  
// messageContext ties an outstanding request to the channels used to deliver
// its responses and to signal that the requester has stopped listening.
type messageContext struct {
	// id is the LDAP message ID of the request this context belongs to.
	id int64
	// close(done) should only be called from finishMessage()
	done chan struct{}
	// close(responses) should only be called from processMessages(), and only sent to from sendResponse()
	responses chan *PacketResponse
}
    54  
    55  // sendResponse should only be called within the processMessages() loop which
    56  // is also responsible for closing the responses channel.
    57  func (msgCtx *messageContext) sendResponse(packet *PacketResponse) {
    58  	select {
    59  	case msgCtx.responses <- packet:
    60  		// Successfully sent packet to message handler.
    61  	case <-msgCtx.done:
    62  		// The request handler is done and will not receive more
    63  		// packets.
    64  	}
    65  }
    66  
// messagePacket is the unit of work sent over Conn.chanMessage to the
// processMessages loop.
type messagePacket struct {
	// Op selects the action; processMessages switches on the Message*
	// constants (MessageQuit, MessageRequest, ...).
	Op int
	// MessageID is the LDAP message ID the operation refers to.
	MessageID int64
	// Packet is the BER packet to write (MessageRequest) or the packet that
	// was read from the server (MessageResponse); other ops leave it nil.
	Packet *ber.Packet
	// Context is set for MessageRequest and is registered in
	// Conn.messageContexts once the request has been written.
	Context *messageContext
}
    73  
// sendMessageFlags is a bit set of options accepted by sendMessageWithFlags.
type sendMessageFlags uint

const (
	// startTLS marks the request as the StartTLS extended operation.
	// sendMessageWithFlags rejects it while other requests are outstanding
	// and sets Conn.isStartingTLS when it is accepted.
	startTLS sendMessageFlags = 1 << iota
)
    79  
// Conn represents an LDAP Connection
type Conn struct {
	// conn is the underlying network connection; StartTLS replaces it with
	// the TLS-wrapped connection after a successful handshake.
	conn                net.Conn
	// isTLS reports whether conn is already encrypted.
	isTLS               bool
	// isClosing is set by Close; senders check it before queueing messages.
	isClosing           bool
	// closeErr is the read error recorded by reader() when the connection
	// fails while not closing; processMessages forwards it to waiters.
	closeErr            error
	// isStartingTLS is true while a StartTLS request is in flight
	// (guarded by messageMutex).
	isStartingTLS       bool
	// Debug gates debug output (used as a boolean and via Debug.Printf).
	Debug               debugging
	// chanConfirm receives a value from processMessages when it exits;
	// Close waits on it.
	chanConfirm         chan bool
	// messageContexts maps message IDs to the contexts of requests that
	// were written successfully; it is accessed from processMessages.
	messageContexts     map[int64]*messageContext
	// chanMessage carries messagePackets to processMessages.
	chanMessage         chan *messagePacket
	// chanMessageID is fed increasing message IDs by processMessages.
	chanMessageID       chan int64
	// wgSender tracks sends in progress in sendProcessMessage; Close waits
	// for them before sending the quit message.
	wgSender            sync.WaitGroup
	// wgClose is incremented in Start and released at the end of Close.
	wgClose             sync.WaitGroup
	// once ensures the shutdown sequence in Close runs a single time.
	once                sync.Once
	// outstandingRequests counts requests sent but not yet finished
	// (guarded by messageMutex).
	outstandingRequests uint
	// messageMutex guards isStartingTLS and outstandingRequests.
	messageMutex        sync.Mutex
	// requestTimeout, when > 0, is the delay after which a MessageTimeout
	// is issued for a sent request (see SetTimeout).
	requestTimeout      time.Duration
}
    99  
   100  var _ Client = &Conn{}
   101  
// DefaultTimeout is a package-level variable that sets the timeout value
// passed to net.DialTimeout by the Dial and DialTLS functions.
//
// WARNING: since this is a package-level variable, setting this value from
// multiple places will probably result in undesired behaviour.
var DefaultTimeout = 60 * time.Second
   108  
   109  // Dial connects to the given address on the given network using net.Dial
   110  // and then returns a new Conn for the connection.
   111  func Dial(network, addr string) (*Conn, error) {
   112  	c, err := net.DialTimeout(network, addr, DefaultTimeout)
   113  	if err != nil {
   114  		return nil, NewError(ErrorNetwork, err)
   115  	}
   116  	conn := NewConn(c, false)
   117  	conn.Start()
   118  	return conn, nil
   119  }
   120  
   121  // DialTLS connects to the given address on the given network using tls.Dial
   122  // and then returns a new Conn for the connection.
   123  func DialTLS(network, addr string, config *tls.Config) (*Conn, error) {
   124  	dc, err := net.DialTimeout(network, addr, DefaultTimeout)
   125  	if err != nil {
   126  		return nil, NewError(ErrorNetwork, err)
   127  	}
   128  	c := tls.Client(dc, config)
   129  	err = c.Handshake()
   130  	if err != nil {
   131  		// Handshake error, close the established connection before we return an error
   132  		dc.Close()
   133  		return nil, NewError(ErrorNetwork, err)
   134  	}
   135  	conn := NewConn(c, true)
   136  	conn.Start()
   137  	return conn, nil
   138  }
   139  
   140  // NewConn returns a new Conn using conn for network I/O.
   141  func NewConn(conn net.Conn, isTLS bool) *Conn {
   142  	return &Conn{
   143  		conn:            conn,
   144  		chanConfirm:     make(chan bool),
   145  		chanMessageID:   make(chan int64),
   146  		chanMessage:     make(chan *messagePacket, 10),
   147  		messageContexts: map[int64]*messageContext{},
   148  		requestTimeout:  0,
   149  		isTLS:           isTLS,
   150  	}
   151  }
   152  
   153  // Start initializes goroutines to read responses and process messages
   154  func (l *Conn) Start() {
   155  	go l.reader()
   156  	go l.processMessages()
   157  	l.wgClose.Add(1)
   158  }
   159  
   160  // Close closes the connection.
   161  func (l *Conn) Close() {
   162  	l.once.Do(func() {
   163  		l.isClosing = true
   164  		l.wgSender.Wait()
   165  
   166  		l.Debug.Printf("Sending quit message and waiting for confirmation")
   167  		l.chanMessage <- &messagePacket{Op: MessageQuit}
   168  		<-l.chanConfirm
   169  		close(l.chanMessage)
   170  
   171  		l.Debug.Printf("Closing network connection")
   172  		if err := l.conn.Close(); err != nil {
   173  			log.Print(err)
   174  		}
   175  
   176  		l.wgClose.Done()
   177  	})
   178  	l.wgClose.Wait()
   179  }
   180  
   181  // SetTimeout sets the time after a request is sent that a MessageTimeout triggers
   182  func (l *Conn) SetTimeout(timeout time.Duration) {
   183  	if timeout > 0 {
   184  		l.requestTimeout = timeout
   185  	}
   186  }
   187  
   188  // Returns the next available messageID
   189  func (l *Conn) nextMessageID() int64 {
   190  	if l.chanMessageID != nil {
   191  		if messageID, ok := <-l.chanMessageID; ok {
   192  			return messageID
   193  		}
   194  	}
   195  	return 0
   196  }
   197  
   198  // StartTLS sends the command to start a TLS session and then creates a new TLS Client
   199  func (l *Conn) StartTLS(config *tls.Config) error {
   200  	if l.isTLS {
   201  		return NewError(ErrorNetwork, errors.New("ldap: already encrypted"))
   202  	}
   203  
   204  	packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
   205  	packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
   206  	request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Start TLS")
   207  	request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, "1.3.6.1.4.1.1466.20037", "TLS Extended Command"))
   208  	packet.AppendChild(request)
   209  	l.Debug.PrintPacket(packet)
   210  
   211  	msgCtx, err := l.sendMessageWithFlags(packet, startTLS)
   212  	if err != nil {
   213  		return err
   214  	}
   215  	defer l.finishMessage(msgCtx)
   216  
   217  	l.Debug.Printf("%d: waiting for response", msgCtx.id)
   218  
   219  	packetResponse, ok := <-msgCtx.responses
   220  	if !ok {
   221  		return NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
   222  	}
   223  	packet, err = packetResponse.ReadPacket()
   224  	l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
   225  	if err != nil {
   226  		return err
   227  	}
   228  
   229  	if l.Debug {
   230  		if err := addLDAPDescriptions(packet); err != nil {
   231  			l.Close()
   232  			return err
   233  		}
   234  		ber.PrintPacket(packet)
   235  	}
   236  
   237  	if resultCode, message := getLDAPResultCode(packet); resultCode == LDAPResultSuccess {
   238  		conn := tls.Client(l.conn, config)
   239  
   240  		if err := conn.Handshake(); err != nil {
   241  			l.Close()
   242  			return NewError(ErrorNetwork, fmt.Errorf("TLS handshake failed (%v)", err))
   243  		}
   244  
   245  		l.isTLS = true
   246  		l.conn = conn
   247  	} else {
   248  		return NewError(resultCode, fmt.Errorf("ldap: cannot StartTLS (%s)", message))
   249  	}
   250  	go l.reader()
   251  
   252  	return nil
   253  }
   254  
// sendMessage queues packet for sending with no flags set; see
// sendMessageWithFlags for the returned context and error.
func (l *Conn) sendMessage(packet *ber.Packet) (*messageContext, error) {
	return l.sendMessageWithFlags(packet, 0)
}
   258  
   259  func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags) (*messageContext, error) {
   260  	if l.isClosing {
   261  		return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed"))
   262  	}
   263  	l.messageMutex.Lock()
   264  	l.Debug.Printf("flags&startTLS = %d", flags&startTLS)
   265  	if l.isStartingTLS {
   266  		l.messageMutex.Unlock()
   267  		return nil, NewError(ErrorNetwork, errors.New("ldap: connection is in startls phase"))
   268  	}
   269  	if flags&startTLS != 0 {
   270  		if l.outstandingRequests != 0 {
   271  			l.messageMutex.Unlock()
   272  			return nil, NewError(ErrorNetwork, errors.New("ldap: cannot StartTLS with outstanding requests"))
   273  		}
   274  		l.isStartingTLS = true
   275  	}
   276  	l.outstandingRequests++
   277  
   278  	l.messageMutex.Unlock()
   279  
   280  	responses := make(chan *PacketResponse)
   281  	messageID := packet.Children[0].Value.(int64)
   282  	message := &messagePacket{
   283  		Op:        MessageRequest,
   284  		MessageID: messageID,
   285  		Packet:    packet,
   286  		Context: &messageContext{
   287  			id:        messageID,
   288  			done:      make(chan struct{}),
   289  			responses: responses,
   290  		},
   291  	}
   292  	l.sendProcessMessage(message)
   293  	return message.Context, nil
   294  }
   295  
   296  func (l *Conn) finishMessage(msgCtx *messageContext) {
   297  	close(msgCtx.done)
   298  
   299  	if l.isClosing {
   300  		return
   301  	}
   302  
   303  	l.messageMutex.Lock()
   304  	l.outstandingRequests--
   305  	if l.isStartingTLS {
   306  		l.isStartingTLS = false
   307  	}
   308  	l.messageMutex.Unlock()
   309  
   310  	message := &messagePacket{
   311  		Op:        MessageFinish,
   312  		MessageID: msgCtx.id,
   313  	}
   314  	l.sendProcessMessage(message)
   315  }
   316  
// sendProcessMessage queues message for the processMessages loop.
// It returns false without queueing anything when the connection is closing,
// and true once the message has been placed on chanMessage (the send blocks
// while the channel is full).
func (l *Conn) sendProcessMessage(message *messagePacket) bool {
	if l.isClosing {
		return false
	}
	// Close waits on wgSender before sending the quit message and closing
	// chanMessage, so the send is bracketed by Add/Done.
	l.wgSender.Add(1)
	l.chanMessage <- message
	l.wgSender.Done()
	return true
}
   326  
// processMessages is the connection's dispatch loop, started by Start in its
// own goroutine. It hands out message IDs on chanMessageID, writes requests
// to the network, routes responses and timeouts to the matching
// messageContext, and is the only place that closes a context's responses
// channel. It runs until a MessageQuit arrives or chanMessage is closed.
func (l *Conn) processMessages() {
	// Cleanup on exit (including after a recovered panic): release every
	// waiting request, stop handing out message IDs, and confirm shutdown.
	defer func() {
		if err := recover(); err != nil {
			log.Printf("ldap: recovered panic in processMessages: %v", err)
		}
		for messageID, msgCtx := range l.messageContexts {
			// If we are closing due to an error, inform anyone who
			// is waiting about the error.
			if l.isClosing && l.closeErr != nil {
				msgCtx.sendResponse(&PacketResponse{Error: l.closeErr})
			}
			l.Debug.Printf("Closing channel for MessageID %d", messageID)
			close(msgCtx.responses)
			delete(l.messageContexts, messageID)
		}
		// After this, nextMessageID returns 0.
		close(l.chanMessageID)
		// Close blocks on chanConfirm until this send happens.
		l.chanConfirm <- true
		close(l.chanConfirm)
	}()

	// Message IDs start at 1 and increase by one each time one is taken.
	var messageID int64 = 1
	for {
		select {
		case l.chanMessageID <- messageID:
			messageID++
		case message, ok := <-l.chanMessage:
			if !ok {
				l.Debug.Printf("Shutting down - message channel is closed")
				return
			}
			switch message.Op {
			case MessageQuit:
				l.Debug.Printf("Shutting down - quit message received")
				return
			case MessageRequest:
				// Add to message list and write to network
				l.Debug.Printf("Sending message %d", message.MessageID)

				buf := message.Packet.Bytes()
				_, err := l.conn.Write(buf)
				if err != nil {
					l.Debug.Printf("Error Sending Message: %s", err.Error())
					message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)})
					close(message.Context.responses)
					// Leaves the switch only; the loop keeps running.
					break
				}

				// Only add to messageContexts if we were able to
				// successfully write the message.
				l.messageContexts[message.MessageID] = message.Context

				// Add timeout if defined
				if l.requestTimeout > 0 {
					// After requestTimeout has elapsed, queue a
					// MessageTimeout for this message ID.
					go func() {
						defer func() {
							if err := recover(); err != nil {
								log.Printf("ldap: recovered panic in RequestTimeout: %v", err)
							}
						}()
						time.Sleep(l.requestTimeout)
						timeoutMessage := &messagePacket{
							Op:        MessageTimeout,
							MessageID: message.MessageID,
						}
						l.sendProcessMessage(timeoutMessage)
					}()
				}
			case MessageResponse:
				l.Debug.Printf("Receiving message %d", message.MessageID)
				if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
					msgCtx.sendResponse(&PacketResponse{message.Packet, nil})
				} else {
					log.Printf("Received unexpected message %d, %v", message.MessageID, l.isClosing)
					ber.PrintPacket(message.Packet)
				}
			case MessageTimeout:
				// Handle the timeout by closing the channel
				// All reads will return immediately
				// A timeout for a message that is no longer registered
				// (already finished or timed out) is ignored.
				if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
					l.Debug.Printf("Receiving message timeout for %d", message.MessageID)
					msgCtx.sendResponse(&PacketResponse{message.Packet, errors.New("ldap: connection timed out")})
					delete(l.messageContexts, message.MessageID)
					close(msgCtx.responses)
				}
			case MessageFinish:
				// The requester is done: forget the context and close
				// its responses channel.
				l.Debug.Printf("Finished message %d", message.MessageID)
				if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
					delete(l.messageContexts, message.MessageID)
					close(msgCtx.responses)
				}
			}
		}
	}
}
   421  
   422  func (l *Conn) reader() {
   423  	cleanstop := false
   424  	defer func() {
   425  		if err := recover(); err != nil {
   426  			log.Printf("ldap: recovered panic in reader: %v", err)
   427  		}
   428  		if !cleanstop {
   429  			l.Close()
   430  		}
   431  	}()
   432  
   433  	for {
   434  		if cleanstop {
   435  			l.Debug.Printf("reader clean stopping (without closing the connection)")
   436  			return
   437  		}
   438  		packet, err := ber.ReadPacket(l.conn)
   439  		if err != nil {
   440  			// A read error is expected here if we are closing the connection...
   441  			if !l.isClosing {
   442  				l.closeErr = fmt.Errorf("unable to read LDAP response packet: %s", err)
   443  				l.Debug.Printf("reader error: %s", err.Error())
   444  			}
   445  			return
   446  		}
   447  		addLDAPDescriptions(packet)
   448  		if len(packet.Children) == 0 {
   449  			l.Debug.Printf("Received bad ldap packet")
   450  			continue
   451  		}
   452  		l.messageMutex.Lock()
   453  		if l.isStartingTLS {
   454  			cleanstop = true
   455  		}
   456  		l.messageMutex.Unlock()
   457  		message := &messagePacket{
   458  			Op:        MessageResponse,
   459  			MessageID: packet.Children[0].Value.(int64),
   460  			Packet:    packet,
   461  		}
   462  		if !l.sendProcessMessage(message) {
   463  			return
   464  		}
   465  	}
   466  }