github.com/hernad/nomad@v1.6.112/helper/pool/pool.go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package pool

import (
	"container/list"
	"fmt"
	"io"
	"log"
	"net"
	"net/rpc"
	"sync"
	"sync/atomic"
	"time"

	hclog "github.com/hashicorp/go-hclog"
	msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
	"github.com/hashicorp/yamux"
	"github.com/hernad/nomad/helper"
	"github.com/hernad/nomad/helper/tlsutil"
	"github.com/hernad/nomad/nomad/structs"
)

// NewClientCodec returns a new rpc.ClientCodec to be used to make RPC calls.
func NewClientCodec(conn io.ReadWriteCloser) rpc.ClientCodec {
	return msgpackrpc.NewCodecFromHandle(true, true, conn, structs.MsgpackHandle)
}

// NewServerCodec returns a new rpc.ServerCodec to be used to handle RPCs.
func NewServerCodec(conn io.ReadWriteCloser) rpc.ServerCodec {
	return msgpackrpc.NewCodecFromHandle(true, true, conn, structs.MsgpackHandle)
}

// StreamClient is used to wrap a stream with an RPC client.
type StreamClient struct {
	stream net.Conn
	codec  rpc.ClientCodec
}

func (sc *StreamClient) Close() {
	sc.stream.Close()
	sc.codec.Close()
}

// Conn is a pooled connection to a Nomad server.
type Conn struct {
	refCount    int32
	shouldClose int32

	addr     net.Addr
	session  *yamux.Session
	lastUsed atomic.Pointer[time.Time]

	pool *ConnPool

	clients    *list.List
	clientLock sync.Mutex
}

// markForUse does all the bookkeeping required to ready a connection for use,
// and ensures that active connections don't get reaped.
func (c *Conn) markForUse() {
	now := time.Now()
	c.lastUsed.Store(&now)
	atomic.AddInt32(&c.refCount, 1)
}

// releaseUse is the complement of markForUse, freeing up the reference count.
func (c *Conn) releaseUse() {
	refCount := atomic.AddInt32(&c.refCount, -1)
	if refCount == 0 && atomic.LoadInt32(&c.shouldClose) == 1 {
		c.Close()
	}
}

func (c *Conn) Close() error {
	return c.session.Close()
}

// getRPCClient is used to get a cached or new stream client.
func (c *Conn) getRPCClient() (*StreamClient, error) {
	// Check for a cached client
	c.clientLock.Lock()
	front := c.clients.Front()
	if front != nil {
		c.clients.Remove(front)
	}
	c.clientLock.Unlock()
	if front != nil {
		return front.Value.(*StreamClient), nil
	}

	// Open a new stream on the multiplexed session
	stream, err := c.session.Open()
	if err != nil {
		return nil, err
	}

	// Announce the stream mode so the server routes it to the RPC handler
	if _, err := stream.Write([]byte{byte(RpcNomad)}); err != nil {
		stream.Close()
		return nil, err
	}

	// Create a client codec
	codec := NewClientCodec(stream)

	// Return a new stream client
	sc := &StreamClient{
		stream: stream,
		codec:  codec,
	}
	return sc, nil
}

// returnClient is used when done with a stream
// to allow re-use by a future RPC.
func (c *Conn) returnClient(client *StreamClient) {
	didSave := false
	c.clientLock.Lock()
	if c.clients.Len() < c.pool.maxStreams && atomic.LoadInt32(&c.shouldClose) == 0 {
		c.clients.PushFront(client)
		didSave = true

		// If this is a Yamux stream, shrink the internal buffers so that
		// we can GC the idle memory
		if ys, ok := client.stream.(*yamux.Stream); ok {
			ys.Shrink()
		}
	}
	c.clientLock.Unlock()
	if !didSave {
		client.Close()
	}
}

func (c *Conn) IsClosed() bool {
	return c.session.IsClosed()
}
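
// A note on the lifecycle above: every path that hands out a *Conn pairs
// markForUse with a later releaseUse, and clearConn (below) only flags
// shouldClose rather than closing outright. A minimal sketch of the calling
// pattern (illustrative only, not a method of this package):
//
//	conn, err := pool.acquire(region, addr)
//	if err != nil {
//		return err
//	}
//	defer conn.releaseUse() // the final releaseUse closes a flagged conn
//
// This way a connection marked for closure survives until its in-flight
// streams have drained.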

func (c *Conn) AcceptStream() (net.Conn, error) {
	s, err := c.session.AcceptStream()
	if err != nil {
		return nil, err
	}

	c.markForUse()
	return &incomingStream{
		Stream: s,
		parent: c,
	}, nil
}

// incomingStream wraps yamux.Stream but releases the parent Conn's reference
// count when closed, which in turn frees the underlying yamux.Session once it
// has been marked for closure.
type incomingStream struct {
	*yamux.Stream

	parent *Conn
}

func (s *incomingStream) Close() error {
	err := s.Stream.Close()

	// Always release the parent, even on error
	s.parent.releaseUse()

	return err
}

// ConnPool is used to maintain a connection pool to other
// Nomad servers. This is used to reduce the latency of
// RPC requests between servers. It is only used to pool
// connections in the rpcNomad mode. Raft connections
// are pooled separately.
type ConnPool struct {
	sync.Mutex

	// logger is the logger to be used
	logger *log.Logger

	// The maximum time to keep a connection open
	maxTime time.Duration

	// The maximum number of open streams to keep
	maxStreams int

	// pool maps an address to an open connection
	pool map[string]*Conn

	// limiter is used to throttle the number of connect attempts
	// to a given address. The first thread will attempt a connection
	// and put a channel in here, which all other threads will wait
	// on to close.
	limiter map[string]chan struct{}

	// TLS wrapper
	tlsWrap tlsutil.RegionWrapper

	// Used to indicate the pool is shutdown
	shutdown   bool
	shutdownCh chan struct{}

	// connListener is used to notify a potential listener of a new connection
	// being made.
	connListener chan<- *Conn
}

// NewPool is used to make a new connection pool. The pool maintains at most
// one connection per host, for up to maxTime; set maxTime to 0 to disable
// reaping. maxStreams controls the number of idle streams allowed per
// connection. If TLS settings are provided, outgoing connections use TLS.
func NewPool(logger hclog.Logger, maxTime time.Duration, maxStreams int, tlsWrap tlsutil.RegionWrapper) *ConnPool {
	pool := &ConnPool{
		logger:     logger.StandardLogger(&hclog.StandardLoggerOptions{InferLevels: true}),
		maxTime:    maxTime,
		maxStreams: maxStreams,
		pool:       make(map[string]*Conn),
		limiter:    make(map[string]chan struct{}),
		tlsWrap:    tlsWrap,
		shutdownCh: make(chan struct{}),
	}
	if maxTime > 0 {
		go pool.reap()
	}
	return pool
}

// Shutdown is used to close the connection pool.
func (p *ConnPool) Shutdown() error {
	p.Lock()
	defer p.Unlock()

	for _, conn := range p.pool {
		conn.Close()
	}
	p.pool = make(map[string]*Conn)

	if p.shutdown {
		return nil
	}

	if p.connListener != nil {
		close(p.connListener)
		p.connListener = nil
	}

	p.shutdown = true
	close(p.shutdownCh)
	return nil
}

// ReloadTLS reloads TLS configuration on the fly.
func (p *ConnPool) ReloadTLS(tlsWrap tlsutil.RegionWrapper) {
	p.Lock()
	defer p.Unlock()

	oldPool := p.pool
	for _, conn := range oldPool {
		conn.Close()
	}
	p.pool = make(map[string]*Conn)
	p.tlsWrap = tlsWrap
}
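
// SetConnListener (below) feeds newly established connections to an observer.
// The pool sends non-blockingly, so a slow receiver misses notifications
// rather than stalling acquire. A minimal consumer sketch (illustrative; the
// buffer size is an arbitrary choice):
//
//	conns := make(chan *Conn, 32)
//	pool.SetConnListener(conns)
//	go func() {
//		for range conns { // closed on Shutdown or when replaced
//			// e.g. record metrics about new outbound server conns
//		}
//	}()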

// SetConnListener is used to listen to new connections being made. The
// channel will be closed when the conn pool is closed or a new listener is set.
func (p *ConnPool) SetConnListener(l chan<- *Conn) {
	p.Lock()
	defer p.Unlock()

	// Close the old listener
	if p.connListener != nil {
		close(p.connListener)
	}

	// Store the new listener
	p.connListener = l
}
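
// acquire (below) uses a per-address single-flight scheme: the first caller
// for an address becomes the lead thread, dials, and then closes the wait
// channel; every other caller blocks on that channel and re-checks the pool
// afterward. Reduced to a sketch (assuming a limiter map guarded by mu, both
// hypothetical names):
//
//	mu.Lock()
//	wait, found := limiter[key]
//	if !found {
//		wait = make(chan struct{})
//		limiter[key] = wait
//	}
//	mu.Unlock()
//
//	if !found {
//		// lead: do the work, publish the result, then close(wait)
//	} else {
//		<-wait // follower: block, then re-check the shared result
//	}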

// acquire returns a pooled connection to the given address, establishing a
// new one if necessary.
func (p *ConnPool) acquire(region string, addr net.Addr) (*Conn, error) {
	// Check to see if there's a pooled connection available. This is up
	// here since it should be the vastly more common case than the rest
	// of the code here.
	p.Lock()
	c := p.pool[addr.String()]
	if c != nil {
		c.markForUse()
		p.Unlock()
		return c, nil
	}

	// If not (while we are still locked), set up the throttling structure
	// for this address, which will make everyone else wait until our
	// attempt is done.
	var wait chan struct{}
	var ok bool
	if wait, ok = p.limiter[addr.String()]; !ok {
		wait = make(chan struct{})
		p.limiter[addr.String()] = wait
	}
	isLeadThread := !ok
	p.Unlock()

	// If we are the lead thread, make the new connection and then wake
	// everybody else up to see if we got it.
	if isLeadThread {
		c, err := p.getNewConn(region, addr)
		p.Lock()
		delete(p.limiter, addr.String())
		close(wait)
		if err != nil {
			p.Unlock()
			return nil, err
		}

		p.pool[addr.String()] = c

		// If there is a connection listener, notify them of the new connection.
		if p.connListener != nil {
			select {
			case p.connListener <- c:
			default:
			}
		}

		p.Unlock()
		return c, nil
	}

	// Otherwise, wait for the lead thread to attempt the connection
	// and use what's in the pool at that point.
	select {
	case <-p.shutdownCh:
		return nil, fmt.Errorf("rpc error: shutdown")
	case <-wait:
	}

	// See if the lead thread was able to get us a connection.
	p.Lock()
	if c := p.pool[addr.String()]; c != nil {
		c.markForUse()
		p.Unlock()
		return c, nil
	}

	p.Unlock()
	return nil, fmt.Errorf("rpc error: lead thread didn't get connection")
}

// getNewConn is used to establish a new connection to the given address.
func (p *ConnPool) getNewConn(region string, addr net.Addr) (*Conn, error) {
	// Try to dial the conn
	conn, err := net.DialTimeout("tcp", addr.String(), 10*time.Second)
	if err != nil {
		return nil, err
	}

	// Cast to TCPConn
	if tcp, ok := conn.(*net.TCPConn); ok {
		tcp.SetKeepAlive(true)
		tcp.SetNoDelay(true)
	}

	// Check if TLS is enabled
	if p.tlsWrap != nil {
		// Switch the connection into TLS mode
		if _, err := conn.Write([]byte{byte(RpcTLS)}); err != nil {
			conn.Close()
			return nil, err
		}

		// Wrap the connection in a TLS client
		tlsConn, err := p.tlsWrap(region, conn)
		if err != nil {
			conn.Close()
			return nil, err
		}
		conn = tlsConn
	}

	// Write the multiplex byte to set the mode
	if _, err := conn.Write([]byte{byte(RpcMultiplexV2)}); err != nil {
		conn.Close()
		return nil, err
	}

	// Set up the logger
	conf := yamux.DefaultConfig()
	conf.LogOutput = nil
	conf.Logger = p.logger

	// Create a multiplexed session
	session, err := yamux.Client(conn, conf)
	if err != nil {
		conn.Close()
		return nil, err
	}

	// Wrap the connection
	c := &Conn{
		refCount: 1,
		addr:     addr,
		session:  session,
		clients:  list.New(),
		lastUsed: atomic.Pointer[time.Time]{},
		pool:     p,
	}

	now := time.Now()
	c.lastUsed.Store(&now)
	return c, nil
}

// clearConn is used to clear any cached connection, potentially in response to
// an error.
func (p *ConnPool) clearConn(conn *Conn) {
	// Ensure returned streams are closed
	atomic.StoreInt32(&conn.shouldClose, 1)

	// Clear from the cache
	p.Lock()
	if c, ok := p.pool[conn.addr.String()]; ok && c == conn {
		delete(p.pool, conn.addr.String())
	}
	p.Unlock()

	// Close down immediately if idle
	if refCount := atomic.LoadInt32(&conn.refCount); refCount == 0 {
		conn.Close()
	}
}

// getRPCClient is used to get a usable client for an address.
func (p *ConnPool) getRPCClient(region string, addr net.Addr) (*Conn, *StreamClient, error) {
	retries := 0
START:
	// Try to get a conn first
	conn, err := p.acquire(region, addr)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to get conn: %v", err)
	}

	// Get a client
	client, err := conn.getRPCClient()
	if err != nil {
		p.clearConn(conn)
		conn.releaseUse()

		// Try to redial; it's possible that the TCP session closed due
		// to a timeout.
		if retries == 0 {
			retries++
			goto START
		}
		return nil, nil, fmt.Errorf("failed to start stream: %v", err)
	}
	return conn, client, nil
}
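
// Wire handshake recap, as implemented above: on a fresh TCP connection,
// getNewConn optionally writes RpcTLS and performs the TLS upgrade, then
// writes RpcMultiplexV2 to start a yamux session. Every stream subsequently
// opened on that session announces its own mode byte (RpcNomad in
// getRPCClient, RpcStreaming below) so the server can route it:
//
//	TCP dial -> [RpcTLS + TLS wrap, if configured] -> RpcMultiplexV2
//	         -> per yamux stream: RpcNomad | RpcStreaming -> payload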

// StreamingRPC is used to make a streaming RPC call. Callers must
// close the connection when done.
func (p *ConnPool) StreamingRPC(region string, addr net.Addr) (net.Conn, error) {
	conn, err := p.acquire(region, addr)
	if err != nil {
		return nil, fmt.Errorf("failed to get conn: %v", err)
	}

	s, err := conn.session.Open()
	if err != nil {
		return nil, fmt.Errorf("failed to open a streaming connection: %v", err)
	}

	// Announce the stream mode so the server routes it to the streaming handler
	if _, err := s.Write([]byte{byte(RpcStreaming)}); err != nil {
		conn.Close()
		return nil, err
	}

	return s, nil
}

// RPC is used to make an RPC call to a remote host.
func (p *ConnPool) RPC(region string, addr net.Addr, method string, args interface{}, reply interface{}) error {
	// Get a usable client
	conn, sc, err := p.getRPCClient(region, addr)
	if err != nil {
		return fmt.Errorf("rpc error: %w", err)
	}
	defer conn.releaseUse()

	// Make the RPC call
	err = msgpackrpc.CallWithCodec(sc.codec, method, args, reply)
	if err != nil {
		sc.Close()

		// If we read EOF, the session is toast. Clear it and open a
		// new session next time.
		// See https://github.com/hashicorp/consul/blob/v1.6.3/agent/pool/pool.go#L471-L477
		if helper.IsErrEOF(err) {
			p.clearConn(conn)
		}

		// If the error is an RPC coded error, return it without wrapping
		if structs.IsErrRPCCoded(err) {
			return err
		}

		// TODO wrap with RPCCoded error instead
		return fmt.Errorf("rpc error: %w", err)
	}

	// Done with the connection
	conn.returnClient(sc)
	return nil
}

// reap is used to close conns that have been idle for longer than maxTime.
func (p *ConnPool) reap() {
	for {
		// Sleep for a while
		select {
		case <-p.shutdownCh:
			return
		case <-time.After(time.Second):
		}

		// Reap all old conns
		p.Lock()
		var removed []string
		now := time.Now()
		for host, conn := range p.pool {
			// Skip recently used connections
			if now.Sub(*conn.lastUsed.Load()) < p.maxTime {
				continue
			}

			// Skip connections with active streams
			if atomic.LoadInt32(&conn.refCount) > 0 {
				continue
			}

			// Close the conn
			conn.Close()

			// Remove from pool
			removed = append(removed, host)
		}
		for _, host := range removed {
			delete(p.pool, host)
		}
		p.Unlock()
	}
}
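
// End-to-end usage sketch (illustrative; the method name and the request and
// reply types are hypothetical placeholders, and addr is assumed to be
// resolved by the caller):
//
//	pool := NewPool(hclog.Default(), 2*time.Minute, 4, nil)
//	defer pool.Shutdown()
//
//	var reply ExampleReply // hypothetical reply type
//	err := pool.RPC("global", addr, "Example.Method", &ExampleRequest{}, &reply)
//	if err != nil {
//		// On EOF the pool has already dropped the cached conn, so a
//		// retry will redial via acquire.
//	}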