github.com/bigcommerce/nomad@v0.9.3-bc/helper/pool/pool.go

package pool

import (
	"container/list"
	"fmt"
	"io"
	"log"
	"net"
	"net/rpc"
	"sync"
	"sync/atomic"
	"time"

	hclog "github.com/hashicorp/go-hclog"
	msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
	"github.com/hashicorp/nomad/helper/tlsutil"
	"github.com/hashicorp/nomad/nomad/structs"
	"github.com/hashicorp/yamux"
)

// NewClientCodec returns a new rpc.ClientCodec to be used to make RPC calls.
func NewClientCodec(conn io.ReadWriteCloser) rpc.ClientCodec {
	return msgpackrpc.NewCodecFromHandle(true, true, conn, structs.HashiMsgpackHandle)
}

// NewServerCodec returns a new rpc.ServerCodec to be used to handle RPCs.
func NewServerCodec(conn io.ReadWriteCloser) rpc.ServerCodec {
	return msgpackrpc.NewCodecFromHandle(true, true, conn, structs.HashiMsgpackHandle)
}

// StreamClient is used to wrap a stream with an RPC client
type StreamClient struct {
	stream net.Conn
	codec  rpc.ClientCodec
}

func (sc *StreamClient) Close() {
	sc.stream.Close()
}

// Conn is a pooled connection to a Nomad server
type Conn struct {
	refCount    int32
	shouldClose int32

	addr     net.Addr
	session  *yamux.Session
	lastUsed time.Time
	version  int

	pool *ConnPool

	clients    *list.List
	clientLock sync.Mutex
}

// markForUse does all the bookkeeping required to ready a connection for use.
func (c *Conn) markForUse() {
	c.lastUsed = time.Now()
	atomic.AddInt32(&c.refCount, 1)
}

func (c *Conn) Close() error {
	return c.session.Close()
}

// getClient is used to get a cached or new client
func (c *Conn) getClient() (*StreamClient, error) {
	// Check for cached client
	c.clientLock.Lock()
	front := c.clients.Front()
	if front != nil {
		c.clients.Remove(front)
	}
	c.clientLock.Unlock()
	if front != nil {
		return front.Value.(*StreamClient), nil
	}

	// Open a new stream on the multiplexed session
	stream, err := c.session.Open()
	if err != nil {
		return nil, err
	}

	// Create a client codec
	codec := NewClientCodec(stream)

	// Return a new stream client
	sc := &StreamClient{
		stream: stream,
		codec:  codec,
	}
	return sc, nil
}

// returnClient is used when done with a stream
// to allow re-use by a future RPC
func (c *Conn) returnClient(client *StreamClient) {
	didSave := false
	c.clientLock.Lock()
	if c.clients.Len() < c.pool.maxStreams && atomic.LoadInt32(&c.shouldClose) == 0 {
		c.clients.PushFront(client)
		didSave = true

		// If this is a Yamux stream, shrink the internal buffers so that
		// we can GC the idle memory
		if ys, ok := client.stream.(*yamux.Stream); ok {
			ys.Shrink()
		}
	}
	c.clientLock.Unlock()
	if !didSave {
		client.Close()
	}
}
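// A hedged sketch of the stream lifecycle above, from a caller's point of
// view; the RPC method name and argument variables are illustrative
// assumptions, not part of this file (compare ConnPool.RPC below):
//
//	sc, err := c.getClient()
//	if err != nil {
//		return err
//	}
//	// Recycle the stream on success; close it on error so a broken
//	// stream never re-enters the cache.
//	if err := msgpackrpc.CallWithCodec(sc.codec, "Status.Ping", args, &reply); err != nil {
//		sc.Close()
//		return err
//	}
//	c.returnClient(sc)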
// ConnPool is used to maintain a connection pool to other
// Nomad servers. This is used to reduce the latency of
// RPC requests between servers. It is only used to pool
// connections in the rpcNomad mode. Raft connections
// are pooled separately.
type ConnPool struct {
	sync.Mutex

	// logger is the logger to be used
	logger *log.Logger

	// The maximum time to keep a connection open
	maxTime time.Duration

	// The maximum number of open streams to keep
	maxStreams int

	// pool maps an address to an open connection
	pool map[string]*Conn

	// limiter is used to throttle the number of connect attempts
	// to a given address. The first thread will attempt a connection
	// and put a channel in here, which all other threads will wait
	// on to close.
	limiter map[string]chan struct{}

	// TLS wrapper
	tlsWrap tlsutil.RegionWrapper

	// Used to indicate the pool is shutdown
	shutdown   bool
	shutdownCh chan struct{}

	// connListener is used to notify a potential listener of a new connection
	// being made.
	connListener chan<- *yamux.Session
}

// NewPool is used to make a new connection pool. It maintains at most one
// connection per host, for up to maxTime. Set maxTime to 0 to disable
// reaping. maxStreams is used to control the number of idle streams allowed.
// If TLS settings are provided, outgoing connections use TLS.
func NewPool(logger hclog.Logger, maxTime time.Duration, maxStreams int, tlsWrap tlsutil.RegionWrapper) *ConnPool {
	pool := &ConnPool{
		logger:     logger.StandardLogger(&hclog.StandardLoggerOptions{InferLevels: true}),
		maxTime:    maxTime,
		maxStreams: maxStreams,
		pool:       make(map[string]*Conn),
		limiter:    make(map[string]chan struct{}),
		tlsWrap:    tlsWrap,
		shutdownCh: make(chan struct{}),
	}
	if maxTime > 0 {
		go pool.reap()
	}
	return pool
}

// Shutdown is used to close the connection pool
func (p *ConnPool) Shutdown() error {
	p.Lock()
	defer p.Unlock()

	for _, conn := range p.pool {
		conn.Close()
	}
	p.pool = make(map[string]*Conn)

	if p.shutdown {
		return nil
	}

	if p.connListener != nil {
		close(p.connListener)
		p.connListener = nil
	}

	p.shutdown = true
	close(p.shutdownCh)
	return nil
}

// ReloadTLS reloads TLS configuration on the fly
func (p *ConnPool) ReloadTLS(tlsWrap tlsutil.RegionWrapper) {
	p.Lock()
	defer p.Unlock()

	oldPool := p.pool
	for _, conn := range oldPool {
		conn.Close()
	}
	p.pool = make(map[string]*Conn)
	p.tlsWrap = tlsWrap
}

// SetConnListener is used to listen to new connections being made. The
// channel will be closed when the conn pool is closed or a new listener is set.
func (p *ConnPool) SetConnListener(l chan<- *yamux.Session) {
	p.Lock()
	defer p.Unlock()

	// Close the old listener
	if p.connListener != nil {
		close(p.connListener)
	}

	// Store the new listener
	p.connListener = l
}
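// A minimal construction sketch for the pool; the logger options, timeout,
// stream budget, and nil TLS wrapper are illustrative assumptions:
//
//	logger := hclog.New(&hclog.LoggerOptions{Name: "pool"})
//	p := NewPool(logger, 2*time.Minute, 64, nil)
//	defer p.Shutdown()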
// acquire is used to get a connection that is pooled or to return a new
// connection
func (p *ConnPool) acquire(region string, addr net.Addr, version int) (*Conn, error) {
	// Check to see if there's a pooled connection available. This is up
	// here since it should be the vastly more common case than the rest
	// of the code here.
	p.Lock()
	c := p.pool[addr.String()]
	if c != nil {
		c.markForUse()
		p.Unlock()
		return c, nil
	}

	// If not (while we are still locked), set up the throttling structure
	// for this address, which will make everyone else wait until our
	// attempt is done.
	var wait chan struct{}
	var ok bool
	if wait, ok = p.limiter[addr.String()]; !ok {
		wait = make(chan struct{})
		p.limiter[addr.String()] = wait
	}
	isLeadThread := !ok
	p.Unlock()

	// If we are the lead thread, make the new connection and then wake
	// everybody else up to see if we got it.
	if isLeadThread {
		c, err := p.getNewConn(region, addr, version)
		p.Lock()
		delete(p.limiter, addr.String())
		close(wait)
		if err != nil {
			p.Unlock()
			return nil, err
		}

		p.pool[addr.String()] = c

		// If there is a connection listener, notify them of the new connection.
		if p.connListener != nil {
			select {
			case p.connListener <- c.session:
			default:
			}
		}

		p.Unlock()
		return c, nil
	}

	// Otherwise, wait for the lead thread to attempt the connection
	// and use what's in the pool at that point.
	select {
	case <-p.shutdownCh:
		return nil, fmt.Errorf("rpc error: shutdown")
	case <-wait:
	}

	// See if the lead thread was able to get us a connection.
	p.Lock()
	if c := p.pool[addr.String()]; c != nil {
		c.markForUse()
		p.Unlock()
		return c, nil
	}

	p.Unlock()
	return nil, fmt.Errorf("rpc error: lead thread didn't get connection")
}

// getNewConn is used to return a new connection
func (p *ConnPool) getNewConn(region string, addr net.Addr, version int) (*Conn, error) {
	// Try to dial the conn
	conn, err := net.DialTimeout("tcp", addr.String(), 10*time.Second)
	if err != nil {
		return nil, err
	}

	// Cast to TCPConn
	if tcp, ok := conn.(*net.TCPConn); ok {
		tcp.SetKeepAlive(true)
		tcp.SetNoDelay(true)
	}

	// Check if TLS is enabled
	if p.tlsWrap != nil {
		// Switch the connection into TLS mode
		if _, err := conn.Write([]byte{byte(RpcTLS)}); err != nil {
			conn.Close()
			return nil, err
		}

		// Wrap the connection in a TLS client
		tlsConn, err := p.tlsWrap(region, conn)
		if err != nil {
			conn.Close()
			return nil, err
		}
		conn = tlsConn
	}

	// Write the multiplex byte to set the mode
	if _, err := conn.Write([]byte{byte(RpcMultiplex)}); err != nil {
		conn.Close()
		return nil, err
	}

	// Setup the logger
	conf := yamux.DefaultConfig()
	conf.LogOutput = nil
	conf.Logger = p.logger

	// Create a multiplexed session
	session, err := yamux.Client(conn, conf)
	if err != nil {
		conn.Close()
		return nil, err
	}

	// Wrap the connection
	c := &Conn{
		refCount: 1,
		addr:     addr,
		session:  session,
		clients:  list.New(),
		lastUsed: time.Now(),
		version:  version,
		pool:     p,
	}
	return c, nil
}

// clearConn is used to clear any cached connection, potentially in response to
// an error
func (p *ConnPool) clearConn(conn *Conn) {
	// Ensure returned streams are closed
	atomic.StoreInt32(&conn.shouldClose, 1)

	// Clear from the cache
	p.Lock()
	if c, ok := p.pool[conn.addr.String()]; ok && c == conn {
		delete(p.pool, conn.addr.String())
	}
	p.Unlock()

	// Close down immediately if idle
	if refCount := atomic.LoadInt32(&conn.refCount); refCount == 0 {
		conn.Close()
	}
}
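// The wire handshake performed by getNewConn above, sketched from the
// writes in that function (RpcTLS and RpcMultiplex are RPC mode bytes
// defined elsewhere in this package):
//
//	with TLS:    [RpcTLS] -> TLS handshake -> [RpcMultiplex] -> yamux client session
//	without TLS: [RpcMultiplex] -> yamux client session
//
// Every yamux stream opened on the session then carries one msgpack RPC
// exchange via the codecs at the top of this file.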
// releaseConn is invoked when we are done with a conn to reduce the ref count
func (p *ConnPool) releaseConn(conn *Conn) {
	refCount := atomic.AddInt32(&conn.refCount, -1)
	if refCount == 0 && atomic.LoadInt32(&conn.shouldClose) == 1 {
		conn.Close()
	}
}

// getClient is used to get a usable client for an address and protocol version
func (p *ConnPool) getClient(region string, addr net.Addr, version int) (*Conn, *StreamClient, error) {
	retries := 0
START:
	// Try to get a conn first
	conn, err := p.acquire(region, addr, version)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to get conn: %v", err)
	}

	// Get a client
	client, err := conn.getClient()
	if err != nil {
		p.clearConn(conn)
		p.releaseConn(conn)

		// Try to redial; it's possible that the TCP session closed due
		// to a timeout
		if retries == 0 {
			retries++
			goto START
		}
		return nil, nil, fmt.Errorf("failed to start stream: %v", err)
	}
	return conn, client, nil
}

// RPC is used to make an RPC call to a remote host
func (p *ConnPool) RPC(region string, addr net.Addr, version int, method string, args interface{}, reply interface{}) error {
	// Get a usable client
	conn, sc, err := p.getClient(region, addr, version)
	if err != nil {
		return fmt.Errorf("rpc error: %v", err)
	}

	// Make the RPC call
	err = msgpackrpc.CallWithCodec(sc.codec, method, args, reply)
	if err != nil {
		sc.Close()
		p.releaseConn(conn)
		return fmt.Errorf("rpc error: %v", err)
	}

	// Done with the connection
	conn.returnClient(sc)
	p.releaseConn(conn)
	return nil
}

// reap is used to close conns open over maxTime
func (p *ConnPool) reap() {
	for {
		// Sleep for a while
		select {
		case <-p.shutdownCh:
			return
		case <-time.After(time.Second):
		}

		// Reap all old conns
		p.Lock()
		var removed []string
		now := time.Now()
		for host, conn := range p.pool {
			// Skip recently used connections
			if now.Sub(conn.lastUsed) < p.maxTime {
				continue
			}

			// Skip connections with active streams
			if atomic.LoadInt32(&conn.refCount) > 0 {
				continue
			}

			// Close the conn
			conn.Close()

			// Remove from pool
			removed = append(removed, host)
		}
		for _, host := range removed {
			delete(p.pool, host)
		}
		p.Unlock()
	}
}
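// An end-to-end usage sketch for the pool; the region, address, version
// constant, and RPC method below are illustrative assumptions, not values
// defined in this file:
//
//	addr, err := net.ResolveTCPAddr("tcp", "10.0.0.1:4647")
//	if err != nil {
//		return err
//	}
//	var reply struct{}
//	if err := p.RPC("global", addr, structs.ApiMajorVersion, "Status.Ping", struct{}{}, &reply); err != nil {
//		return err
//	}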