github.com/quite/nomad@v0.8.6/helper/pool/pool.go

package pool

import (
    "container/list"
    "fmt"
    "io"
    "net"
    "net/rpc"
    "sync"
    "sync/atomic"
    "time"

    msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
    "github.com/hashicorp/nomad/helper/tlsutil"
    "github.com/hashicorp/nomad/nomad/structs"
    "github.com/hashicorp/yamux"
)

// NewClientCodec returns a new rpc.ClientCodec to be used to make RPC calls.
func NewClientCodec(conn io.ReadWriteCloser) rpc.ClientCodec {
    return msgpackrpc.NewCodecFromHandle(true, true, conn, structs.HashiMsgpackHandle)
}

// NewServerCodec returns a new rpc.ServerCodec to be used to handle RPCs.
func NewServerCodec(conn io.ReadWriteCloser) rpc.ServerCodec {
    return msgpackrpc.NewCodecFromHandle(true, true, conn, structs.HashiMsgpackHandle)
}

// StreamClient is used to wrap a stream with an RPC client
type StreamClient struct {
    stream net.Conn
    codec  rpc.ClientCodec
}

func (sc *StreamClient) Close() {
    sc.stream.Close()
}

// Conn is a pooled connection to a Nomad server
type Conn struct {
    refCount    int32
    shouldClose int32

    addr     net.Addr
    session  *yamux.Session
    lastUsed time.Time
    version  int

    pool *ConnPool

    clients    *list.List
    clientLock sync.Mutex
}

// markForUse does all the bookkeeping required to ready a connection for use.
func (c *Conn) markForUse() {
    c.lastUsed = time.Now()
    atomic.AddInt32(&c.refCount, 1)
}

func (c *Conn) Close() error {
    return c.session.Close()
}

// getClient is used to get a cached or new client
func (c *Conn) getClient() (*StreamClient, error) {
    // Check for cached client
    c.clientLock.Lock()
    front := c.clients.Front()
    if front != nil {
        c.clients.Remove(front)
    }
    c.clientLock.Unlock()
    if front != nil {
        return front.Value.(*StreamClient), nil
    }

    // Open a new session
    stream, err := c.session.Open()
    if err != nil {
        return nil, err
    }

    // Create a client codec
    codec := NewClientCodec(stream)

    // Return a new stream client
    sc := &StreamClient{
        stream: stream,
        codec:  codec,
    }
    return sc, nil
}

// returnClient is used when done with a stream
// to allow re-use by a future RPC
func (c *Conn) returnClient(client *StreamClient) {
    didSave := false
    c.clientLock.Lock()
    if c.clients.Len() < c.pool.maxStreams && atomic.LoadInt32(&c.shouldClose) == 0 {
        c.clients.PushFront(client)
        didSave = true

        // If this is a Yamux stream, shrink the internal buffers so that
        // we can GC the idle memory
        if ys, ok := client.stream.(*yamux.Stream); ok {
            ys.Shrink()
        }
    }
    c.clientLock.Unlock()
    if !didSave {
        client.Close()
    }
}

// ConnPool is used to maintain a connection pool to other
// Nomad servers. This is used to reduce the latency of
// RPC requests between servers. It is only used to pool
// connections in the rpcNomad mode. Raft connections
// are pooled separately.
type ConnPool struct {
    sync.Mutex

    // logOutput is used to control logging
    logOutput io.Writer

    // The maximum time to keep a connection open
    maxTime time.Duration

    // The maximum number of open streams to keep
    maxStreams int

    // pool maps an address to an open connection
    pool map[string]*Conn

    // limiter is used to throttle the number of connect attempts
    // to a given address.
    // The first thread will attempt a connection
    // and put a channel in here, which all other threads will wait
    // on to close.
    limiter map[string]chan struct{}

    // TLS wrapper
    tlsWrap tlsutil.RegionWrapper

    // Used to indicate the pool is shutdown
    shutdown   bool
    shutdownCh chan struct{}

    // connListener is used to notify a potential listener of a new connection
    // being made.
    connListener chan<- *yamux.Session
}

// NewPool is used to make a new connection pool.
// It maintains at most one connection per host, for up to maxTime.
// Set maxTime to 0 to disable reaping. maxStreams is used to control
// the number of idle streams allowed.
// If TLS settings are provided, outgoing connections use TLS.
func NewPool(logOutput io.Writer, maxTime time.Duration, maxStreams int, tlsWrap tlsutil.RegionWrapper) *ConnPool {
    pool := &ConnPool{
        logOutput:  logOutput,
        maxTime:    maxTime,
        maxStreams: maxStreams,
        pool:       make(map[string]*Conn),
        limiter:    make(map[string]chan struct{}),
        tlsWrap:    tlsWrap,
        shutdownCh: make(chan struct{}),
    }
    if maxTime > 0 {
        go pool.reap()
    }
    return pool
}

// Shutdown is used to close the connection pool
func (p *ConnPool) Shutdown() error {
    p.Lock()
    defer p.Unlock()

    for _, conn := range p.pool {
        conn.Close()
    }
    p.pool = make(map[string]*Conn)

    if p.shutdown {
        return nil
    }

    if p.connListener != nil {
        close(p.connListener)
        p.connListener = nil
    }

    p.shutdown = true
    close(p.shutdownCh)
    return nil
}

// ReloadTLS reloads TLS configuration on the fly
func (p *ConnPool) ReloadTLS(tlsWrap tlsutil.RegionWrapper) {
    p.Lock()
    defer p.Unlock()

    oldPool := p.pool
    for _, conn := range oldPool {
        conn.Close()
    }
    p.pool = make(map[string]*Conn)
    p.tlsWrap = tlsWrap
}

// SetConnListener is used to listen to new connections being made. The
// channel will be closed when the conn pool is closed or a new listener is set.
func (p *ConnPool) SetConnListener(l chan<- *yamux.Session) {
    p.Lock()
    defer p.Unlock()

    // Close the old listener
    if p.connListener != nil {
        close(p.connListener)
    }

    // Store the new listener
    p.connListener = l
}

// acquire is used to get a pooled connection, or to establish a new one.
func (p *ConnPool) acquire(region string, addr net.Addr, version int) (*Conn, error) {
    // Check to see if there's a pooled connection available. This is up
    // here since it should be the vastly more common case than the rest
    // of the code here.
    p.Lock()
    c := p.pool[addr.String()]
    if c != nil {
        c.markForUse()
        p.Unlock()
        return c, nil
    }

    // If not (while we are still locked), set up the throttling structure
    // for this address, which will make everyone else wait until our
    // attempt is done.
    var wait chan struct{}
    var ok bool
    if wait, ok = p.limiter[addr.String()]; !ok {
        wait = make(chan struct{})
        p.limiter[addr.String()] = wait
    }
    isLeadThread := !ok
    p.Unlock()

    // If we are the lead thread, make the new connection and then wake
    // everybody else up to see if we got it.
    if isLeadThread {
        c, err := p.getNewConn(region, addr, version)
        p.Lock()
        delete(p.limiter, addr.String())
        close(wait)
        if err != nil {
            p.Unlock()
            return nil, err
        }

        p.pool[addr.String()] = c

        // If there is a connection listener, notify them of the new connection.
        if p.connListener != nil {
            select {
            case p.connListener <- c.session:
            default:
            }
        }

        p.Unlock()
        return c, nil
    }

    // Otherwise, wait for the lead thread to attempt the connection
    // and use what's in the pool at that point.
    select {
    case <-p.shutdownCh:
        return nil, fmt.Errorf("rpc error: shutdown")
    case <-wait:
    }

    // See if the lead thread was able to get us a connection.
    p.Lock()
    if c := p.pool[addr.String()]; c != nil {
        c.markForUse()
        p.Unlock()
        return c, nil
    }

    p.Unlock()
    return nil, fmt.Errorf("rpc error: lead thread didn't get connection")
}

// getNewConn is used to return a new connection
func (p *ConnPool) getNewConn(region string, addr net.Addr, version int) (*Conn, error) {
    // Try to dial the conn
    conn, err := net.DialTimeout("tcp", addr.String(), 10*time.Second)
    if err != nil {
        return nil, err
    }

    // Cast to TCPConn
    if tcp, ok := conn.(*net.TCPConn); ok {
        tcp.SetKeepAlive(true)
        tcp.SetNoDelay(true)
    }

    // Check if TLS is enabled
    if p.tlsWrap != nil {
        // Switch the connection into TLS mode
        if _, err := conn.Write([]byte{byte(RpcTLS)}); err != nil {
            conn.Close()
            return nil, err
        }

        // Wrap the connection in a TLS client
        tlsConn, err := p.tlsWrap(region, conn)
        if err != nil {
            conn.Close()
            return nil, err
        }
        conn = tlsConn
    }

    // Write the multiplex byte to set the mode
    if _, err := conn.Write([]byte{byte(RpcMultiplex)}); err != nil {
        conn.Close()
        return nil, err
    }

    // Setup the logger
    conf := yamux.DefaultConfig()
    conf.LogOutput = p.logOutput

    // Create a multiplexed session
    session, err := yamux.Client(conn, conf)
    if err != nil {
        conn.Close()
        return nil, err
    }

    // Wrap the connection
    c := &Conn{
        refCount: 1,
        addr:     addr,
        session:  session,
        clients:  list.New(),
        lastUsed: time.Now(),
        version:  version,
        pool:     p,
    }
    return c, nil
}

// clearConn is used to clear any cached connection, potentially in response to
// an error
func (p *ConnPool) clearConn(conn *Conn) {
    // Ensure returned streams are closed
    atomic.StoreInt32(&conn.shouldClose, 1)

    // Clear from the cache
    p.Lock()
    if c, ok := p.pool[conn.addr.String()]; ok && c == conn {
        delete(p.pool, conn.addr.String())
    }
    p.Unlock()

    // Close down immediately if idle
    if refCount := atomic.LoadInt32(&conn.refCount); refCount == 0 {
        conn.Close()
    }
}

// releaseConn is invoked when we are done with a conn to reduce the ref count
func (p *ConnPool) releaseConn(conn *Conn) {
    refCount := atomic.AddInt32(&conn.refCount, -1)
    if refCount == 0 && atomic.LoadInt32(&conn.shouldClose) == 1 {
        conn.Close()
    }
}

// getClient is used to get a usable client for an address and protocol version
func (p *ConnPool) getClient(region string, addr net.Addr, version int) (*Conn, *StreamClient, error) {
    retries := 0
START:
    // Try to get a conn first
    conn, err := p.acquire(region, addr, version)
    if err != nil {
        return nil, nil, fmt.Errorf("failed to get conn: %v", err)
    }

    // Get a client
    client, err := conn.getClient()
    if err != nil {
        p.clearConn(conn)
        p.releaseConn(conn)

        // Try to redial, possible that the TCP session closed due to timeout
        if retries == 0 {
            retries++
            goto START
        }
        return nil, nil, fmt.Errorf("failed to start stream: %v", err)
    }
    return conn, client, nil
}

// RPC is used to make an RPC call to a remote host
func (p *ConnPool) RPC(region string, addr net.Addr, version int, method string, args interface{}, reply interface{}) error {
    // Get a usable client
    conn, sc, err := p.getClient(region, addr, version)
    if err != nil {
        return fmt.Errorf("rpc error: %v", err)
    }

    // Make the RPC call
    err = msgpackrpc.CallWithCodec(sc.codec, method, args, reply)
    if err != nil {
        sc.Close()
        p.releaseConn(conn)
        return fmt.Errorf("rpc error: %v", err)
    }

    // Done with the connection
    conn.returnClient(sc)
    p.releaseConn(conn)
    return nil
}

// reap is used to close conns that have sat idle for longer than maxTime
func (p *ConnPool) reap() {
    for {
        // Sleep for a while
        select {
        case <-p.shutdownCh:
            return
        case <-time.After(time.Second):
        }

        // Reap all old conns
        p.Lock()
        var removed []string
        now := time.Now()
        for host, conn := range p.pool {
            // Skip recently used connections
            if now.Sub(conn.lastUsed) < p.maxTime {
                continue
            }

            // Skip connections with active streams
            if atomic.LoadInt32(&conn.refCount) > 0 {
                continue
            }

            // Close the conn
            conn.Close()

            // Remove from pool
            removed = append(removed, host)
        }
        for _, host := range removed {
            delete(p.pool, host)
        }
        p.Unlock()
    }
}
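
// Example usage (illustrative sketch): a caller typically creates one ConnPool
// per process and issues server-to-server RPCs through it. The address, region,
// protocol version, and RPC method name below are assumed placeholder values,
// not something defined by this package.
//
//  p := pool.NewPool(os.Stderr, 2*time.Minute, 64, nil) // nil tlsWrap => plaintext RPC
//  addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:4647")
//  if err != nil {
//      log.Fatal(err)
//  }
//  var reply struct{}
//  // "Status.Ping" stands in here for whichever server RPC method the caller needs.
//  if err := p.RPC("global", addr, 1, "Status.Ping", struct{}{}, &reply); err != nil {
//      log.Printf("rpc failed: %v", err) // RPC has already released the conn on error
//  }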