github.com/hellobchain/third_party@v0.0.0-20230331131523-deb0478a2e52/ldap.v2/conn.go (about) 1 // Copyright 2011 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package ldap 6 7 import ( 8 "errors" 9 "fmt" 10 "github.com/hellobchain/newcryptosm/tls" 11 ber "gopkg.in/asn1-ber.v1" 12 "log" 13 "net" 14 "sync" 15 "time" 16 ) 17 18 const ( 19 // MessageQuit causes the processMessages loop to exit 20 MessageQuit = 0 21 // MessageRequest sends a request to the server 22 MessageRequest = 1 23 // MessageResponse receives a response from the server 24 MessageResponse = 2 25 // MessageFinish indicates the client considers a particular message ID to be finished 26 MessageFinish = 3 27 // MessageTimeout indicates the client-specified timeout for a particular message ID has been reached 28 MessageTimeout = 4 29 ) 30 31 // PacketResponse contains the packet or error encountered reading a response 32 type PacketResponse struct { 33 // Packet is the packet read from the server 34 Packet *ber.Packet 35 // Error is an error encountered while reading 36 Error error 37 } 38 39 // ReadPacket returns the packet or an error 40 func (pr *PacketResponse) ReadPacket() (*ber.Packet, error) { 41 if (pr == nil) || (pr.Packet == nil && pr.Error == nil) { 42 return nil, NewError(ErrorNetwork, errors.New("ldap: could not retrieve response")) 43 } 44 return pr.Packet, pr.Error 45 } 46 47 type messageContext struct { 48 id int64 49 // close(done) should only be called from finishMessage() 50 done chan struct{} 51 // close(responses) should only be called from processMessages(), and only sent to from sendResponse() 52 responses chan *PacketResponse 53 } 54 55 // sendResponse should only be called within the processMessages() loop which 56 // is also responsible for closing the responses channel. 57 func (msgCtx *messageContext) sendResponse(packet *PacketResponse) { 58 select { 59 case msgCtx.responses <- packet: 60 // Successfully sent packet to message handler. 61 case <-msgCtx.done: 62 // The request handler is done and will not receive more 63 // packets. 64 } 65 } 66 67 type messagePacket struct { 68 Op int 69 MessageID int64 70 Packet *ber.Packet 71 Context *messageContext 72 } 73 74 type sendMessageFlags uint 75 76 const ( 77 startTLS sendMessageFlags = 1 << iota 78 ) 79 80 // Conn represents an LDAP Connection 81 type Conn struct { 82 conn net.Conn 83 isTLS bool 84 isClosing bool 85 closeErr error 86 isStartingTLS bool 87 Debug debugging 88 chanConfirm chan bool 89 messageContexts map[int64]*messageContext 90 chanMessage chan *messagePacket 91 chanMessageID chan int64 92 wgSender sync.WaitGroup 93 wgClose sync.WaitGroup 94 once sync.Once 95 outstandingRequests uint 96 messageMutex sync.Mutex 97 requestTimeout time.Duration 98 } 99 100 var _ Client = &Conn{} 101 102 // DefaultTimeout is a package-level variable that sets the timeout value 103 // used for the Dial and DialTLS methods. 104 // 105 // WARNING: since this is a package-level variable, setting this value from 106 // multiple places will probably result in undesired behaviour. 107 var DefaultTimeout = 60 * time.Second 108 109 // Dial connects to the given address on the given network using net.Dial 110 // and then returns a new Conn for the connection. 111 func Dial(network, addr string) (*Conn, error) { 112 c, err := net.DialTimeout(network, addr, DefaultTimeout) 113 if err != nil { 114 return nil, NewError(ErrorNetwork, err) 115 } 116 conn := NewConn(c, false) 117 conn.Start() 118 return conn, nil 119 } 120 121 // DialTLS connects to the given address on the given network using tls.Dial 122 // and then returns a new Conn for the connection. 123 func DialTLS(network, addr string, config *tls.Config) (*Conn, error) { 124 dc, err := net.DialTimeout(network, addr, DefaultTimeout) 125 if err != nil { 126 return nil, NewError(ErrorNetwork, err) 127 } 128 c := tls.Client(dc, config) 129 err = c.Handshake() 130 if err != nil { 131 // Handshake error, close the established connection before we return an error 132 dc.Close() 133 return nil, NewError(ErrorNetwork, err) 134 } 135 conn := NewConn(c, true) 136 conn.Start() 137 return conn, nil 138 } 139 140 // NewConn returns a new Conn using conn for network I/O. 141 func NewConn(conn net.Conn, isTLS bool) *Conn { 142 return &Conn{ 143 conn: conn, 144 chanConfirm: make(chan bool), 145 chanMessageID: make(chan int64), 146 chanMessage: make(chan *messagePacket, 10), 147 messageContexts: map[int64]*messageContext{}, 148 requestTimeout: 0, 149 isTLS: isTLS, 150 } 151 } 152 153 // Start initializes goroutines to read responses and process messages 154 func (l *Conn) Start() { 155 go l.reader() 156 go l.processMessages() 157 l.wgClose.Add(1) 158 } 159 160 // Close closes the connection. 161 func (l *Conn) Close() { 162 l.once.Do(func() { 163 l.isClosing = true 164 l.wgSender.Wait() 165 166 l.Debug.Printf("Sending quit message and waiting for confirmation") 167 l.chanMessage <- &messagePacket{Op: MessageQuit} 168 <-l.chanConfirm 169 close(l.chanMessage) 170 171 l.Debug.Printf("Closing network connection") 172 if err := l.conn.Close(); err != nil { 173 log.Print(err) 174 } 175 176 l.wgClose.Done() 177 }) 178 l.wgClose.Wait() 179 } 180 181 // SetTimeout sets the time after a request is sent that a MessageTimeout triggers 182 func (l *Conn) SetTimeout(timeout time.Duration) { 183 if timeout > 0 { 184 l.requestTimeout = timeout 185 } 186 } 187 188 // Returns the next available messageID 189 func (l *Conn) nextMessageID() int64 { 190 if l.chanMessageID != nil { 191 if messageID, ok := <-l.chanMessageID; ok { 192 return messageID 193 } 194 } 195 return 0 196 } 197 198 // StartTLS sends the command to start a TLS session and then creates a new TLS Client 199 func (l *Conn) StartTLS(config *tls.Config) error { 200 if l.isTLS { 201 return NewError(ErrorNetwork, errors.New("ldap: already encrypted")) 202 } 203 204 packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") 205 packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) 206 request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Start TLS") 207 request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, "1.3.6.1.4.1.1466.20037", "TLS Extended Command")) 208 packet.AppendChild(request) 209 l.Debug.PrintPacket(packet) 210 211 msgCtx, err := l.sendMessageWithFlags(packet, startTLS) 212 if err != nil { 213 return err 214 } 215 defer l.finishMessage(msgCtx) 216 217 l.Debug.Printf("%d: waiting for response", msgCtx.id) 218 219 packetResponse, ok := <-msgCtx.responses 220 if !ok { 221 return NewError(ErrorNetwork, errors.New("ldap: response channel closed")) 222 } 223 packet, err = packetResponse.ReadPacket() 224 l.Debug.Printf("%d: got response %p", msgCtx.id, packet) 225 if err != nil { 226 return err 227 } 228 229 if l.Debug { 230 if err := addLDAPDescriptions(packet); err != nil { 231 l.Close() 232 return err 233 } 234 ber.PrintPacket(packet) 235 } 236 237 if resultCode, message := getLDAPResultCode(packet); resultCode == LDAPResultSuccess { 238 conn := tls.Client(l.conn, config) 239 240 if err := conn.Handshake(); err != nil { 241 l.Close() 242 return NewError(ErrorNetwork, fmt.Errorf("TLS handshake failed (%v)", err)) 243 } 244 245 l.isTLS = true 246 l.conn = conn 247 } else { 248 return NewError(resultCode, fmt.Errorf("ldap: cannot StartTLS (%s)", message)) 249 } 250 go l.reader() 251 252 return nil 253 } 254 255 func (l *Conn) sendMessage(packet *ber.Packet) (*messageContext, error) { 256 return l.sendMessageWithFlags(packet, 0) 257 } 258 259 func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags) (*messageContext, error) { 260 if l.isClosing { 261 return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed")) 262 } 263 l.messageMutex.Lock() 264 l.Debug.Printf("flags&startTLS = %d", flags&startTLS) 265 if l.isStartingTLS { 266 l.messageMutex.Unlock() 267 return nil, NewError(ErrorNetwork, errors.New("ldap: connection is in startls phase")) 268 } 269 if flags&startTLS != 0 { 270 if l.outstandingRequests != 0 { 271 l.messageMutex.Unlock() 272 return nil, NewError(ErrorNetwork, errors.New("ldap: cannot StartTLS with outstanding requests")) 273 } 274 l.isStartingTLS = true 275 } 276 l.outstandingRequests++ 277 278 l.messageMutex.Unlock() 279 280 responses := make(chan *PacketResponse) 281 messageID := packet.Children[0].Value.(int64) 282 message := &messagePacket{ 283 Op: MessageRequest, 284 MessageID: messageID, 285 Packet: packet, 286 Context: &messageContext{ 287 id: messageID, 288 done: make(chan struct{}), 289 responses: responses, 290 }, 291 } 292 l.sendProcessMessage(message) 293 return message.Context, nil 294 } 295 296 func (l *Conn) finishMessage(msgCtx *messageContext) { 297 close(msgCtx.done) 298 299 if l.isClosing { 300 return 301 } 302 303 l.messageMutex.Lock() 304 l.outstandingRequests-- 305 if l.isStartingTLS { 306 l.isStartingTLS = false 307 } 308 l.messageMutex.Unlock() 309 310 message := &messagePacket{ 311 Op: MessageFinish, 312 MessageID: msgCtx.id, 313 } 314 l.sendProcessMessage(message) 315 } 316 317 func (l *Conn) sendProcessMessage(message *messagePacket) bool { 318 if l.isClosing { 319 return false 320 } 321 l.wgSender.Add(1) 322 l.chanMessage <- message 323 l.wgSender.Done() 324 return true 325 } 326 327 func (l *Conn) processMessages() { 328 defer func() { 329 if err := recover(); err != nil { 330 log.Printf("ldap: recovered panic in processMessages: %v", err) 331 } 332 for messageID, msgCtx := range l.messageContexts { 333 // If we are closing due to an error, inform anyone who 334 // is waiting about the error. 335 if l.isClosing && l.closeErr != nil { 336 msgCtx.sendResponse(&PacketResponse{Error: l.closeErr}) 337 } 338 l.Debug.Printf("Closing channel for MessageID %d", messageID) 339 close(msgCtx.responses) 340 delete(l.messageContexts, messageID) 341 } 342 close(l.chanMessageID) 343 l.chanConfirm <- true 344 close(l.chanConfirm) 345 }() 346 347 var messageID int64 = 1 348 for { 349 select { 350 case l.chanMessageID <- messageID: 351 messageID++ 352 case message, ok := <-l.chanMessage: 353 if !ok { 354 l.Debug.Printf("Shutting down - message channel is closed") 355 return 356 } 357 switch message.Op { 358 case MessageQuit: 359 l.Debug.Printf("Shutting down - quit message received") 360 return 361 case MessageRequest: 362 // Add to message list and write to network 363 l.Debug.Printf("Sending message %d", message.MessageID) 364 365 buf := message.Packet.Bytes() 366 _, err := l.conn.Write(buf) 367 if err != nil { 368 l.Debug.Printf("Error Sending Message: %s", err.Error()) 369 message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)}) 370 close(message.Context.responses) 371 break 372 } 373 374 // Only add to messageContexts if we were able to 375 // successfully write the message. 376 l.messageContexts[message.MessageID] = message.Context 377 378 // Add timeout if defined 379 if l.requestTimeout > 0 { 380 go func() { 381 defer func() { 382 if err := recover(); err != nil { 383 log.Printf("ldap: recovered panic in RequestTimeout: %v", err) 384 } 385 }() 386 time.Sleep(l.requestTimeout) 387 timeoutMessage := &messagePacket{ 388 Op: MessageTimeout, 389 MessageID: message.MessageID, 390 } 391 l.sendProcessMessage(timeoutMessage) 392 }() 393 } 394 case MessageResponse: 395 l.Debug.Printf("Receiving message %d", message.MessageID) 396 if msgCtx, ok := l.messageContexts[message.MessageID]; ok { 397 msgCtx.sendResponse(&PacketResponse{message.Packet, nil}) 398 } else { 399 log.Printf("Received unexpected message %d, %v", message.MessageID, l.isClosing) 400 ber.PrintPacket(message.Packet) 401 } 402 case MessageTimeout: 403 // Handle the timeout by closing the channel 404 // All reads will return immediately 405 if msgCtx, ok := l.messageContexts[message.MessageID]; ok { 406 l.Debug.Printf("Receiving message timeout for %d", message.MessageID) 407 msgCtx.sendResponse(&PacketResponse{message.Packet, errors.New("ldap: connection timed out")}) 408 delete(l.messageContexts, message.MessageID) 409 close(msgCtx.responses) 410 } 411 case MessageFinish: 412 l.Debug.Printf("Finished message %d", message.MessageID) 413 if msgCtx, ok := l.messageContexts[message.MessageID]; ok { 414 delete(l.messageContexts, message.MessageID) 415 close(msgCtx.responses) 416 } 417 } 418 } 419 } 420 } 421 422 func (l *Conn) reader() { 423 cleanstop := false 424 defer func() { 425 if err := recover(); err != nil { 426 log.Printf("ldap: recovered panic in reader: %v", err) 427 } 428 if !cleanstop { 429 l.Close() 430 } 431 }() 432 433 for { 434 if cleanstop { 435 l.Debug.Printf("reader clean stopping (without closing the connection)") 436 return 437 } 438 packet, err := ber.ReadPacket(l.conn) 439 if err != nil { 440 // A read error is expected here if we are closing the connection... 441 if !l.isClosing { 442 l.closeErr = fmt.Errorf("unable to read LDAP response packet: %s", err) 443 l.Debug.Printf("reader error: %s", err.Error()) 444 } 445 return 446 } 447 addLDAPDescriptions(packet) 448 if len(packet.Children) == 0 { 449 l.Debug.Printf("Received bad ldap packet") 450 continue 451 } 452 l.messageMutex.Lock() 453 if l.isStartingTLS { 454 cleanstop = true 455 } 456 l.messageMutex.Unlock() 457 message := &messagePacket{ 458 Op: MessageResponse, 459 MessageID: packet.Children[0].Value.(int64), 460 Packet: packet, 461 } 462 if !l.sendProcessMessage(message) { 463 return 464 } 465 } 466 }