github.com/anth0d/nomad@v0.0.0-20221214183521-ae3a0a2cad06/helper/pool/pool.go

package pool

import (
	"container/list"
	"fmt"
	"io"
	"log"
	"net"
	"net/rpc"
	"sync"
	"sync/atomic"
	"time"

	hclog "github.com/hashicorp/go-hclog"
	msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
	"github.com/hashicorp/nomad/helper"
	"github.com/hashicorp/nomad/helper/tlsutil"
	"github.com/hashicorp/nomad/nomad/structs"
	"github.com/hashicorp/yamux"
)

// NewClientCodec returns a new rpc.ClientCodec to be used to make RPC calls.
func NewClientCodec(conn io.ReadWriteCloser) rpc.ClientCodec {
	return msgpackrpc.NewCodecFromHandle(true, true, conn, structs.MsgpackHandle)
}

// NewServerCodec returns a new rpc.ServerCodec to be used to handle RPCs.
func NewServerCodec(conn io.ReadWriteCloser) rpc.ServerCodec {
	return msgpackrpc.NewCodecFromHandle(true, true, conn, structs.MsgpackHandle)
}

// StreamClient is used to wrap a stream with an RPC client
type StreamClient struct {
	stream net.Conn
	codec  rpc.ClientCodec
}

func (sc *StreamClient) Close() {
	sc.stream.Close()
	sc.codec.Close()
}

// Conn is a pooled connection to a Nomad server
type Conn struct {
	refCount    int32
	shouldClose int32

	addr     net.Addr
	session  *yamux.Session
	lastUsed atomic.Pointer[time.Time]

	pool *ConnPool

	clients    *list.List
	clientLock sync.Mutex
}

// markForUse does all the bookkeeping required to ready a connection for use,
// and ensure that active connections don't get reaped.
func (c *Conn) markForUse() {
	now := time.Now()
	c.lastUsed.Store(&now)
	atomic.AddInt32(&c.refCount, 1)
}

// releaseUse is the complement of `markForUse`, to free up the reference count
func (c *Conn) releaseUse() {
	refCount := atomic.AddInt32(&c.refCount, -1)
	if refCount == 0 && atomic.LoadInt32(&c.shouldClose) == 1 {
		c.Close()
	}
}

func (c *Conn) Close() error {
	return c.session.Close()
}

// getRPCClient is used to get a cached or new client
func (c *Conn) getRPCClient() (*StreamClient, error) {
	// Check for a cached client
	c.clientLock.Lock()
	front := c.clients.Front()
	if front != nil {
		c.clients.Remove(front)
	}
	c.clientLock.Unlock()
	if front != nil {
		return front.Value.(*StreamClient), nil
	}

	// Open a new session
	stream, err := c.session.Open()
	if err != nil {
		return nil, err
	}

	if _, err := stream.Write([]byte{byte(RpcNomad)}); err != nil {
		stream.Close()
		return nil, err
	}

	// Create a client codec
	codec := NewClientCodec(stream)

	// Return a new stream client
	sc := &StreamClient{
		stream: stream,
		codec:  codec,
	}
	return sc, nil
}

// returnClient is used when done with a stream
// to allow re-use by a future RPC
func (c *Conn) returnClient(client *StreamClient) {
	didSave := false
	c.clientLock.Lock()
	if c.clients.Len() < c.pool.maxStreams && atomic.LoadInt32(&c.shouldClose) == 0 {
		c.clients.PushFront(client)
		didSave = true

		// If this is a Yamux stream, shrink the internal buffers so that
		// we can GC the idle memory
		if ys, ok := client.stream.(*yamux.Stream); ok {
			ys.Shrink()
		}
	}
	c.clientLock.Unlock()
	if !didSave {
		client.Close()
	}
}

func (c *Conn) IsClosed() bool {
	return c.session.IsClosed()
}
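// The sketch below (illustrative only, not part of this file's API) shows how
// the pieces above fit together from a caller's point of view: a Conn
// reference is held for the duration of one RPC, and the stream client is
// either recycled via returnClient or closed on error. useConn is a
// hypothetical helper.
//
//	func useConn(c *Conn) error {
//		c.markForUse()       // bump refCount, refresh lastUsed
//		defer c.releaseUse() // may close the session if shouldClose is set
//
//		sc, err := c.getRPCClient()
//		if err != nil {
//			return err
//		}
//		// ... issue an RPC over sc.codec ...
//		c.returnClient(sc) // cache the stream for reuse, up to maxStreams
//		return nil
//	}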
func (c *Conn) AcceptStream() (net.Conn, error) {
	s, err := c.session.AcceptStream()
	if err != nil {
		return nil, err
	}

	c.markForUse()
	return &incomingStream{
		Stream: s,
		parent: c,
	}, nil
}

// incomingStream wraps yamux.Stream but frees the underlying yamux.Session
// when closed
type incomingStream struct {
	*yamux.Stream

	parent *Conn
}

func (s *incomingStream) Close() error {
	err := s.Stream.Close()

	// always release the parent, even on error
	s.parent.releaseUse()

	return err
}

// ConnPool is used to maintain a connection pool to other
// Nomad servers. This is used to reduce the latency of
// RPC requests between servers. It is only used to pool
// connections in the RpcNomad mode. Raft connections
// are pooled separately.
type ConnPool struct {
	sync.Mutex

	// logger is the logger to be used
	logger *log.Logger

	// The maximum time to keep a connection open
	maxTime time.Duration

	// The maximum number of open streams to keep
	maxStreams int

	// Pool maps an address to an open connection
	pool map[string]*Conn

	// limiter is used to throttle the number of connect attempts
	// to a given address. The first thread will attempt a connection
	// and put a channel in here, which all other threads will wait
	// on to close.
	limiter map[string]chan struct{}

	// TLS wrapper
	tlsWrap tlsutil.RegionWrapper

	// Used to indicate the pool is shutdown
	shutdown   bool
	shutdownCh chan struct{}

	// connListener is used to notify a potential listener of a new connection
	// being made.
	connListener chan<- *Conn
}

// NewPool is used to make a new connection pool.
// Maintain at most one connection per host, for up to maxTime.
// Set maxTime to 0 to disable reaping. maxStreams is used to control
// the number of idle streams allowed.
// If TLS settings are provided, outgoing connections use TLS.
func NewPool(logger hclog.Logger, maxTime time.Duration, maxStreams int, tlsWrap tlsutil.RegionWrapper) *ConnPool {
	pool := &ConnPool{
		logger:     logger.StandardLogger(&hclog.StandardLoggerOptions{InferLevels: true}),
		maxTime:    maxTime,
		maxStreams: maxStreams,
		pool:       make(map[string]*Conn),
		limiter:    make(map[string]chan struct{}),
		tlsWrap:    tlsWrap,
		shutdownCh: make(chan struct{}),
	}
	if maxTime > 0 {
		go pool.reap()
	}
	return pool
}

// Shutdown is used to close the connection pool
func (p *ConnPool) Shutdown() error {
	p.Lock()
	defer p.Unlock()

	for _, conn := range p.pool {
		conn.Close()
	}
	p.pool = make(map[string]*Conn)

	if p.shutdown {
		return nil
	}

	if p.connListener != nil {
		close(p.connListener)
		p.connListener = nil
	}

	p.shutdown = true
	close(p.shutdownCh)
	return nil
}

// ReloadTLS reloads TLS configuration on the fly
func (p *ConnPool) ReloadTLS(tlsWrap tlsutil.RegionWrapper) {
	p.Lock()
	defer p.Unlock()

	oldPool := p.pool
	for _, conn := range oldPool {
		conn.Close()
	}
	p.pool = make(map[string]*Conn)
	p.tlsWrap = tlsWrap
}
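// A minimal construction sketch (illustrative only, not part of this file),
// assuming a plain non-TLS pool: connections idle for more than two minutes
// are reaped, and up to 64 idle streams are cached per connection.
//
//	logger := hclog.Default()
//	pool := NewPool(logger, 2*time.Minute, 64, nil)
//	defer pool.Shutdown()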
// SetConnListener is used to listen to new connections being made. The
// channel will be closed when the conn pool is closed or a new listener is set.
func (p *ConnPool) SetConnListener(l chan<- *Conn) {
	p.Lock()
	defer p.Unlock()

	// Close the old listener
	if p.connListener != nil {
		close(p.connListener)
	}

	// Store the new listener
	p.connListener = l
}

// acquire is used to get a pooled connection, or to create a new one
func (p *ConnPool) acquire(region string, addr net.Addr) (*Conn, error) {
	// Check to see if there's a pooled connection available. This is up
	// here since it should be the vastly more common case than the rest
	// of the code here.
	p.Lock()
	c := p.pool[addr.String()]
	if c != nil {
		c.markForUse()
		p.Unlock()
		return c, nil
	}

	// If not (while we are still locked), set up the throttling structure
	// for this address, which will make everyone else wait until our
	// attempt is done.
	var wait chan struct{}
	var ok bool
	if wait, ok = p.limiter[addr.String()]; !ok {
		wait = make(chan struct{})
		p.limiter[addr.String()] = wait
	}
	isLeadThread := !ok
	p.Unlock()

	// If we are the lead thread, make the new connection and then wake
	// everybody else up to see if we got it.
	if isLeadThread {
		c, err := p.getNewConn(region, addr)
		p.Lock()
		delete(p.limiter, addr.String())
		close(wait)
		if err != nil {
			p.Unlock()
			return nil, err
		}

		p.pool[addr.String()] = c

		// If there is a connection listener, notify them of the new connection.
		if p.connListener != nil {
			select {
			case p.connListener <- c:
			default:
			}
		}

		p.Unlock()
		return c, nil
	}

	// Otherwise, wait for the lead thread to attempt the connection
	// and use what's in the pool at that point.
	select {
	case <-p.shutdownCh:
		return nil, fmt.Errorf("rpc error: shutdown")
	case <-wait:
	}

	// See if the lead thread was able to get us a connection.
	p.Lock()
	if c := p.pool[addr.String()]; c != nil {
		c.markForUse()
		p.Unlock()
		return c, nil
	}

	p.Unlock()
	return nil, fmt.Errorf("rpc error: lead thread didn't get connection")
}
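// Usage sketch for SetConnListener (illustrative only, not part of this
// file): a consumer can watch for new outbound connections, e.g. to accept
// server-initiated streams. The channel should be buffered, since acquire
// drops the notification rather than block when the listener is full.
//
//	conns := make(chan *Conn, 4)
//	pool.SetConnListener(conns)
//	go func() {
//		for conn := range conns { // closed on Shutdown or when replaced
//			// ... call conn.AcceptStream() in a loop ...
//		}
//	}()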
// getNewConn is used to return a new connection
func (p *ConnPool) getNewConn(region string, addr net.Addr) (*Conn, error) {
	// Try to dial the conn
	conn, err := net.DialTimeout("tcp", addr.String(), 10*time.Second)
	if err != nil {
		return nil, err
	}

	// Cast to TCPConn
	if tcp, ok := conn.(*net.TCPConn); ok {
		tcp.SetKeepAlive(true)
		tcp.SetNoDelay(true)
	}

	// Check if TLS is enabled
	if p.tlsWrap != nil {
		// Switch the connection into TLS mode
		if _, err := conn.Write([]byte{byte(RpcTLS)}); err != nil {
			conn.Close()
			return nil, err
		}

		// Wrap the connection in a TLS client
		tlsConn, err := p.tlsWrap(region, conn)
		if err != nil {
			conn.Close()
			return nil, err
		}
		conn = tlsConn
	}

	// Write the multiplex byte to set the mode
	if _, err := conn.Write([]byte{byte(RpcMultiplexV2)}); err != nil {
		conn.Close()
		return nil, err
	}

	// Setup the logger
	conf := yamux.DefaultConfig()
	conf.LogOutput = nil
	conf.Logger = p.logger

	// Create a multiplexed session
	session, err := yamux.Client(conn, conf)
	if err != nil {
		conn.Close()
		return nil, err
	}

	// Wrap the connection
	c := &Conn{
		refCount: 1,
		addr:     addr,
		session:  session,
		clients:  list.New(),
		lastUsed: atomic.Pointer[time.Time]{},
		pool:     p,
	}

	now := time.Now()
	c.lastUsed.Store(&now)
	return c, nil
}

// clearConn is used to clear any cached connection, potentially in response to
// an error
func (p *ConnPool) clearConn(conn *Conn) {
	// Ensure returned streams are closed
	atomic.StoreInt32(&conn.shouldClose, 1)

	// Clear from the cache
	p.Lock()
	if c, ok := p.pool[conn.addr.String()]; ok && c == conn {
		delete(p.pool, conn.addr.String())
	}
	p.Unlock()

	// Close down immediately if idle
	if refCount := atomic.LoadInt32(&conn.refCount); refCount == 0 {
		conn.Close()
	}
}

// getRPCClient is used to get a usable client for an address
func (p *ConnPool) getRPCClient(region string, addr net.Addr) (*Conn, *StreamClient, error) {
	retries := 0
START:
	// Try to get a conn first
	conn, err := p.acquire(region, addr)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to get conn: %v", err)
	}

	// Get a client
	client, err := conn.getRPCClient()
	if err != nil {
		p.clearConn(conn)
		conn.releaseUse()

		// Try to redial, possible that the TCP session closed due to timeout
		if retries == 0 {
			retries++
			goto START
		}
		return nil, nil, fmt.Errorf("failed to start stream: %v", err)
	}
	return conn, client, nil
}
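// getNewConn's handshake, summarized (a sketch of what the code above does,
// not additional protocol): each byte written tells the remote server how to
// interpret the rest of the connection or stream.
//
//	TCP dial (10s timeout)
//	  -> RpcTLS byte (only when tlsWrap is set), then the TLS handshake
//	  -> RpcMultiplexV2 byte, then a yamux client session
//	  per yamux stream: RpcNomad byte (getRPCClient) or RpcStreaming byte (StreamingRPC)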
// StreamingRPC is used to make a streaming RPC call. Callers must
// close the connection when done.
func (p *ConnPool) StreamingRPC(region string, addr net.Addr) (net.Conn, error) {
	conn, err := p.acquire(region, addr)
	if err != nil {
		return nil, fmt.Errorf("failed to get conn: %v", err)
	}

	s, err := conn.session.Open()
	if err != nil {
		conn.releaseUse()
		return nil, fmt.Errorf("failed to open a streaming connection: %v", err)
	}

	if _, err := s.Write([]byte{byte(RpcStreaming)}); err != nil {
		conn.releaseUse()
		conn.Close()
		return nil, err
	}

	// The reference taken by acquire is retained on success, so the reaper
	// will not close the session while the stream is in use.
	return s, nil
}

// RPC is used to make an RPC call to a remote host
func (p *ConnPool) RPC(region string, addr net.Addr, method string, args interface{}, reply interface{}) error {
	// Get a usable client
	conn, sc, err := p.getRPCClient(region, addr)
	if err != nil {
		return fmt.Errorf("rpc error: %w", err)
	}
	defer conn.releaseUse()

	// Make the RPC call
	err = msgpackrpc.CallWithCodec(sc.codec, method, args, reply)
	if err != nil {
		sc.Close()

		// If we read EOF, the session is toast. Clear it and open a
		// new session next time.
		// See https://github.com/hashicorp/consul/blob/v1.6.3/agent/pool/pool.go#L471-L477
		if helper.IsErrEOF(err) {
			p.clearConn(conn)
		}

		// If the error is an RPC coded error,
		// return the coded error without wrapping
		if structs.IsErrRPCCoded(err) {
			return err
		}

		// TODO wrap with an RPCCoded error instead
		return fmt.Errorf("rpc error: %w", err)
	}

	// Done with the connection
	conn.returnClient(sc)
	return nil
}

// reap is used to close conns open longer than maxTime
func (p *ConnPool) reap() {
	for {
		// Sleep for a while
		select {
		case <-p.shutdownCh:
			return
		case <-time.After(time.Second):
		}

		// Reap all old conns
		p.Lock()
		var removed []string
		now := time.Now()
		for host, conn := range p.pool {
			// Skip recently used connections
			if now.Sub(*conn.lastUsed.Load()) < p.maxTime {
				continue
			}

			// Skip connections with active streams
			if atomic.LoadInt32(&conn.refCount) > 0 {
				continue
			}

			// Close the conn
			conn.Close()

			// Remove from pool
			removed = append(removed, host)
		}
		for _, host := range removed {
			delete(p.pool, host)
		}
		p.Unlock()
	}
}
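// End-to-end usage sketch (illustrative only, not part of this file). The
// method name and request/response types below are placeholders; a real call
// passes the request and reply structs matching the target RPC endpoint.
//
//	addr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:4647")
//	args := &structs.GenericRequest{} // placeholder request type
//	var reply structs.GenericResponse // placeholder response type
//	if err := pool.RPC("global", addr, "Some.Endpoint", args, &reply); err != nil {
//		// on failure the pool has already closed or cleared the
//		// underlying stream and connection as appropriate
//	}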