github.com/diptanu/nomad@v0.5.7-0.20170516172507-d72e86cbe3d9/nomad/pool.go (about) 1 package nomad 2 3 import ( 4 "container/list" 5 "fmt" 6 "io" 7 "net" 8 "net/rpc" 9 "sync" 10 "sync/atomic" 11 "time" 12 13 "github.com/hashicorp/net-rpc-msgpackrpc" 14 "github.com/hashicorp/nomad/helper/tlsutil" 15 "github.com/hashicorp/yamux" 16 ) 17 18 // streamClient is used to wrap a stream with an RPC client 19 type StreamClient struct { 20 stream net.Conn 21 codec rpc.ClientCodec 22 } 23 24 func (sc *StreamClient) Close() { 25 sc.stream.Close() 26 } 27 28 // Conn is a pooled connection to a Nomad server 29 type Conn struct { 30 refCount int32 31 shouldClose int32 32 33 addr net.Addr 34 session *yamux.Session 35 lastUsed time.Time 36 version int 37 38 pool *ConnPool 39 40 clients *list.List 41 clientLock sync.Mutex 42 } 43 44 // markForUse does all the bookkeeping required to ready a connection for use. 45 func (c *Conn) markForUse() { 46 c.lastUsed = time.Now() 47 atomic.AddInt32(&c.refCount, 1) 48 } 49 50 func (c *Conn) Close() error { 51 return c.session.Close() 52 } 53 54 // getClient is used to get a cached or new client 55 func (c *Conn) getClient() (*StreamClient, error) { 56 // Check for cached client 57 c.clientLock.Lock() 58 front := c.clients.Front() 59 if front != nil { 60 c.clients.Remove(front) 61 } 62 c.clientLock.Unlock() 63 if front != nil { 64 return front.Value.(*StreamClient), nil 65 } 66 67 // Open a new session 68 stream, err := c.session.Open() 69 if err != nil { 70 return nil, err 71 } 72 73 // Create a client codec 74 codec := NewClientCodec(stream) 75 76 // Return a new stream client 77 sc := &StreamClient{ 78 stream: stream, 79 codec: codec, 80 } 81 return sc, nil 82 } 83 84 // returnStream is used when done with a stream 85 // to allow re-use by a future RPC 86 func (c *Conn) returnClient(client *StreamClient) { 87 didSave := false 88 c.clientLock.Lock() 89 if c.clients.Len() < c.pool.maxStreams && atomic.LoadInt32(&c.shouldClose) == 0 { 90 c.clients.PushFront(client) 91 didSave = true 92 93 // If this is a Yamux stream, shrink the internal buffers so that 94 // we can GC the idle memory 95 if ys, ok := client.stream.(*yamux.Stream); ok { 96 ys.Shrink() 97 } 98 } 99 c.clientLock.Unlock() 100 if !didSave { 101 client.Close() 102 } 103 } 104 105 // ConnPool is used to maintain a connection pool to other 106 // Nomad servers. This is used to reduce the latency of 107 // RPC requests between servers. It is only used to pool 108 // connections in the rpcNomad mode. Raft connections 109 // are pooled separately. 110 type ConnPool struct { 111 sync.Mutex 112 113 // LogOutput is used to control logging 114 logOutput io.Writer 115 116 // The maximum time to keep a connection open 117 maxTime time.Duration 118 119 // The maximum number of open streams to keep 120 maxStreams int 121 122 // Pool maps an address to a open connection 123 pool map[string]*Conn 124 125 // limiter is used to throttle the number of connect attempts 126 // to a given address. The first thread will attempt a connection 127 // and put a channel in here, which all other threads will wait 128 // on to close. 129 limiter map[string]chan struct{} 130 131 // TLS wrapper 132 tlsWrap tlsutil.RegionWrapper 133 134 // Used to indicate the pool is shutdown 135 shutdown bool 136 shutdownCh chan struct{} 137 } 138 139 // NewPool is used to make a new connection pool 140 // Maintain at most one connection per host, for up to maxTime. 141 // Set maxTime to 0 to disable reaping. maxStreams is used to control 142 // the number of idle streams allowed. 143 // If TLS settings are provided outgoing connections use TLS. 144 func NewPool(logOutput io.Writer, maxTime time.Duration, maxStreams int, tlsWrap tlsutil.RegionWrapper) *ConnPool { 145 pool := &ConnPool{ 146 logOutput: logOutput, 147 maxTime: maxTime, 148 maxStreams: maxStreams, 149 pool: make(map[string]*Conn), 150 limiter: make(map[string]chan struct{}), 151 tlsWrap: tlsWrap, 152 shutdownCh: make(chan struct{}), 153 } 154 if maxTime > 0 { 155 go pool.reap() 156 } 157 return pool 158 } 159 160 // Shutdown is used to close the connection pool 161 func (p *ConnPool) Shutdown() error { 162 p.Lock() 163 defer p.Unlock() 164 165 for _, conn := range p.pool { 166 conn.Close() 167 } 168 p.pool = make(map[string]*Conn) 169 170 if p.shutdown { 171 return nil 172 } 173 p.shutdown = true 174 close(p.shutdownCh) 175 return nil 176 } 177 178 // Acquire is used to get a connection that is 179 // pooled or to return a new connection 180 func (p *ConnPool) acquire(region string, addr net.Addr, version int) (*Conn, error) { 181 // Check to see if there's a pooled connection available. This is up 182 // here since it should the vastly more common case than the rest 183 // of the code here. 184 p.Lock() 185 c := p.pool[addr.String()] 186 if c != nil { 187 c.markForUse() 188 p.Unlock() 189 return c, nil 190 } 191 192 // If not (while we are still locked), set up the throttling structure 193 // for this address, which will make everyone else wait until our 194 // attempt is done. 195 var wait chan struct{} 196 var ok bool 197 if wait, ok = p.limiter[addr.String()]; !ok { 198 wait = make(chan struct{}) 199 p.limiter[addr.String()] = wait 200 } 201 isLeadThread := !ok 202 p.Unlock() 203 204 // If we are the lead thread, make the new connection and then wake 205 // everybody else up to see if we got it. 206 if isLeadThread { 207 c, err := p.getNewConn(region, addr, version) 208 p.Lock() 209 delete(p.limiter, addr.String()) 210 close(wait) 211 if err != nil { 212 p.Unlock() 213 return nil, err 214 } 215 216 p.pool[addr.String()] = c 217 p.Unlock() 218 return c, nil 219 } 220 221 // Otherwise, wait for the lead thread to attempt the connection 222 // and use what's in the pool at that point. 223 select { 224 case <-p.shutdownCh: 225 return nil, fmt.Errorf("rpc error: shutdown") 226 case <-wait: 227 } 228 229 // See if the lead thread was able to get us a connection. 230 p.Lock() 231 if c := p.pool[addr.String()]; c != nil { 232 c.markForUse() 233 p.Unlock() 234 return c, nil 235 } 236 237 p.Unlock() 238 return nil, fmt.Errorf("rpc error: lead thread didn't get connection") 239 } 240 241 // getNewConn is used to return a new connection 242 func (p *ConnPool) getNewConn(region string, addr net.Addr, version int) (*Conn, error) { 243 // Try to dial the conn 244 conn, err := net.DialTimeout("tcp", addr.String(), 10*time.Second) 245 if err != nil { 246 return nil, err 247 } 248 249 // Cast to TCPConn 250 if tcp, ok := conn.(*net.TCPConn); ok { 251 tcp.SetKeepAlive(true) 252 tcp.SetNoDelay(true) 253 } 254 255 // Check if TLS is enabled 256 if p.tlsWrap != nil { 257 // Switch the connection into TLS mode 258 if _, err := conn.Write([]byte{byte(rpcTLS)}); err != nil { 259 conn.Close() 260 return nil, err 261 } 262 263 // Wrap the connection in a TLS client 264 tlsConn, err := p.tlsWrap(region, conn) 265 if err != nil { 266 conn.Close() 267 return nil, err 268 } 269 conn = tlsConn 270 } 271 272 // Write the multiplex byte to set the mode 273 if _, err := conn.Write([]byte{byte(rpcMultiplex)}); err != nil { 274 conn.Close() 275 return nil, err 276 } 277 278 // Setup the logger 279 conf := yamux.DefaultConfig() 280 conf.LogOutput = p.logOutput 281 282 // Create a multiplexed session 283 session, err := yamux.Client(conn, conf) 284 if err != nil { 285 conn.Close() 286 return nil, err 287 } 288 289 // Wrap the connection 290 c := &Conn{ 291 refCount: 1, 292 addr: addr, 293 session: session, 294 clients: list.New(), 295 lastUsed: time.Now(), 296 version: version, 297 pool: p, 298 } 299 return c, nil 300 } 301 302 // clearConn is used to clear any cached connection, potentially in response to an erro 303 func (p *ConnPool) clearConn(conn *Conn) { 304 // Ensure returned streams are closed 305 atomic.StoreInt32(&conn.shouldClose, 1) 306 307 // Clear from the cache 308 p.Lock() 309 if c, ok := p.pool[conn.addr.String()]; ok && c == conn { 310 delete(p.pool, conn.addr.String()) 311 } 312 p.Unlock() 313 314 // Close down immediately if idle 315 if refCount := atomic.LoadInt32(&conn.refCount); refCount == 0 { 316 conn.Close() 317 } 318 } 319 320 // releaseConn is invoked when we are done with a conn to reduce the ref count 321 func (p *ConnPool) releaseConn(conn *Conn) { 322 refCount := atomic.AddInt32(&conn.refCount, -1) 323 if refCount == 0 && atomic.LoadInt32(&conn.shouldClose) == 1 { 324 conn.Close() 325 } 326 } 327 328 // getClient is used to get a usable client for an address and protocol version 329 func (p *ConnPool) getClient(region string, addr net.Addr, version int) (*Conn, *StreamClient, error) { 330 retries := 0 331 START: 332 // Try to get a conn first 333 conn, err := p.acquire(region, addr, version) 334 if err != nil { 335 return nil, nil, fmt.Errorf("failed to get conn: %v", err) 336 } 337 338 // Get a client 339 client, err := conn.getClient() 340 if err != nil { 341 p.clearConn(conn) 342 p.releaseConn(conn) 343 344 // Try to redial, possible that the TCP session closed due to timeout 345 if retries == 0 { 346 retries++ 347 goto START 348 } 349 return nil, nil, fmt.Errorf("failed to start stream: %v", err) 350 } 351 return conn, client, nil 352 } 353 354 // RPC is used to make an RPC call to a remote host 355 func (p *ConnPool) RPC(region string, addr net.Addr, version int, method string, args interface{}, reply interface{}) error { 356 // Get a usable client 357 conn, sc, err := p.getClient(region, addr, version) 358 if err != nil { 359 return fmt.Errorf("rpc error: %v", err) 360 } 361 362 // Make the RPC call 363 err = msgpackrpc.CallWithCodec(sc.codec, method, args, reply) 364 if err != nil { 365 sc.Close() 366 p.releaseConn(conn) 367 return fmt.Errorf("rpc error: %v", err) 368 } 369 370 // Done with the connection 371 conn.returnClient(sc) 372 p.releaseConn(conn) 373 return nil 374 } 375 376 // Reap is used to close conns open over maxTime 377 func (p *ConnPool) reap() { 378 for { 379 // Sleep for a while 380 select { 381 case <-p.shutdownCh: 382 return 383 case <-time.After(time.Second): 384 } 385 386 // Reap all old conns 387 p.Lock() 388 var removed []string 389 now := time.Now() 390 for host, conn := range p.pool { 391 // Skip recently used connections 392 if now.Sub(conn.lastUsed) < p.maxTime { 393 continue 394 } 395 396 // Skip connections with active streams 397 if atomic.LoadInt32(&conn.refCount) > 0 { 398 continue 399 } 400 401 // Close the conn 402 conn.Close() 403 404 // Remove from pool 405 removed = append(removed, host) 406 } 407 for _, host := range removed { 408 delete(p.pool, host) 409 } 410 p.Unlock() 411 } 412 }