github.com/jmitchell/nomad@v0.1.3-0.20151007230021-7ab84c2862d8/nomad/pool.go (about) 1 package nomad 2 3 import ( 4 "container/list" 5 "fmt" 6 "io" 7 "net" 8 "net/rpc" 9 "sync" 10 "sync/atomic" 11 "time" 12 13 "github.com/hashicorp/consul/tlsutil" 14 "github.com/hashicorp/net-rpc-msgpackrpc" 15 "github.com/hashicorp/yamux" 16 ) 17 18 // streamClient is used to wrap a stream with an RPC client 19 type StreamClient struct { 20 stream net.Conn 21 codec rpc.ClientCodec 22 } 23 24 func (sc *StreamClient) Close() { 25 sc.stream.Close() 26 } 27 28 // Conn is a pooled connection to a Nomad server 29 type Conn struct { 30 refCount int32 31 shouldClose int32 32 33 addr net.Addr 34 session *yamux.Session 35 lastUsed time.Time 36 version int 37 38 pool *ConnPool 39 40 clients *list.List 41 clientLock sync.Mutex 42 } 43 44 // markForUse does all the bookkeeping required to ready a connection for use. 45 func (c *Conn) markForUse() { 46 c.lastUsed = time.Now() 47 atomic.AddInt32(&c.refCount, 1) 48 } 49 50 func (c *Conn) Close() error { 51 return c.session.Close() 52 } 53 54 // getClient is used to get a cached or new client 55 func (c *Conn) getClient() (*StreamClient, error) { 56 // Check for cached client 57 c.clientLock.Lock() 58 front := c.clients.Front() 59 if front != nil { 60 c.clients.Remove(front) 61 } 62 c.clientLock.Unlock() 63 if front != nil { 64 return front.Value.(*StreamClient), nil 65 } 66 67 // Open a new session 68 stream, err := c.session.Open() 69 if err != nil { 70 return nil, err 71 } 72 73 // Create a client codec 74 codec := msgpackrpc.NewClientCodec(stream) 75 76 // Return a new stream client 77 sc := &StreamClient{ 78 stream: stream, 79 codec: codec, 80 } 81 return sc, nil 82 } 83 84 // returnStream is used when done with a stream 85 // to allow re-use by a future RPC 86 func (c *Conn) returnClient(client *StreamClient) { 87 didSave := false 88 c.clientLock.Lock() 89 if c.clients.Len() < c.pool.maxStreams && atomic.LoadInt32(&c.shouldClose) == 0 { 90 c.clients.PushFront(client) 91 didSave = true 92 } 93 c.clientLock.Unlock() 94 if !didSave { 95 client.Close() 96 } 97 } 98 99 // ConnPool is used to maintain a connection pool to other 100 // Nomad servers. This is used to reduce the latency of 101 // RPC requests between servers. It is only used to pool 102 // connections in the rpcNomad mode. Raft connections 103 // are pooled separately. 104 type ConnPool struct { 105 sync.Mutex 106 107 // LogOutput is used to control logging 108 logOutput io.Writer 109 110 // The maximum time to keep a connection open 111 maxTime time.Duration 112 113 // The maximum number of open streams to keep 114 maxStreams int 115 116 // Pool maps an address to a open connection 117 pool map[string]*Conn 118 119 // limiter is used to throttle the number of connect attempts 120 // to a given address. The first thread will attempt a connection 121 // and put a channel in here, which all other threads will wait 122 // on to close. 123 limiter map[string]chan struct{} 124 125 // TLS wrapper 126 tlsWrap tlsutil.DCWrapper 127 128 // Used to indicate the pool is shutdown 129 shutdown bool 130 shutdownCh chan struct{} 131 } 132 133 // NewPool is used to make a new connection pool 134 // Maintain at most one connection per host, for up to maxTime. 135 // Set maxTime to 0 to disable reaping. maxStreams is used to control 136 // the number of idle streams allowed. 137 // If TLS settings are provided outgoing connections use TLS. 138 func NewPool(logOutput io.Writer, maxTime time.Duration, maxStreams int, tlsWrap tlsutil.DCWrapper) *ConnPool { 139 pool := &ConnPool{ 140 logOutput: logOutput, 141 maxTime: maxTime, 142 maxStreams: maxStreams, 143 pool: make(map[string]*Conn), 144 limiter: make(map[string]chan struct{}), 145 tlsWrap: tlsWrap, 146 shutdownCh: make(chan struct{}), 147 } 148 if maxTime > 0 { 149 go pool.reap() 150 } 151 return pool 152 } 153 154 // Shutdown is used to close the connection pool 155 func (p *ConnPool) Shutdown() error { 156 p.Lock() 157 defer p.Unlock() 158 159 for _, conn := range p.pool { 160 conn.Close() 161 } 162 p.pool = make(map[string]*Conn) 163 164 if p.shutdown { 165 return nil 166 } 167 p.shutdown = true 168 close(p.shutdownCh) 169 return nil 170 } 171 172 // Acquire is used to get a connection that is 173 // pooled or to return a new connection 174 func (p *ConnPool) acquire(region string, addr net.Addr, version int) (*Conn, error) { 175 // Check to see if there's a pooled connection available. This is up 176 // here since it should the the vastly more common case than the rest 177 // of the code here. 178 p.Lock() 179 c := p.pool[addr.String()] 180 if c != nil { 181 c.markForUse() 182 p.Unlock() 183 return c, nil 184 } 185 186 // If not (while we are still locked), set up the throttling structure 187 // for this address, which will make everyone else wait until our 188 // attempt is done. 189 var wait chan struct{} 190 var ok bool 191 if wait, ok = p.limiter[addr.String()]; !ok { 192 wait = make(chan struct{}) 193 p.limiter[addr.String()] = wait 194 } 195 isLeadThread := !ok 196 p.Unlock() 197 198 // If we are the lead thread, make the new connection and then wake 199 // everybody else up to see if we got it. 200 if isLeadThread { 201 c, err := p.getNewConn(region, addr, version) 202 p.Lock() 203 delete(p.limiter, addr.String()) 204 close(wait) 205 if err != nil { 206 p.Unlock() 207 return nil, err 208 } 209 210 p.pool[addr.String()] = c 211 p.Unlock() 212 return c, nil 213 } 214 215 // Otherwise, wait for the lead thread to attempt the connection 216 // and use what's in the pool at that point. 217 select { 218 case <-p.shutdownCh: 219 return nil, fmt.Errorf("rpc error: shutdown") 220 case <-wait: 221 } 222 223 // See if the lead thread was able to get us a connection. 224 p.Lock() 225 if c := p.pool[addr.String()]; c != nil { 226 c.markForUse() 227 p.Unlock() 228 return c, nil 229 } 230 231 p.Unlock() 232 return nil, fmt.Errorf("rpc error: lead thread didn't get connection") 233 } 234 235 // getNewConn is used to return a new connection 236 func (p *ConnPool) getNewConn(region string, addr net.Addr, version int) (*Conn, error) { 237 // Try to dial the conn 238 conn, err := net.DialTimeout("tcp", addr.String(), 10*time.Second) 239 if err != nil { 240 return nil, err 241 } 242 243 // Cast to TCPConn 244 if tcp, ok := conn.(*net.TCPConn); ok { 245 tcp.SetKeepAlive(true) 246 tcp.SetNoDelay(true) 247 } 248 249 // Check if TLS is enabled 250 if p.tlsWrap != nil { 251 // Switch the connection into TLS mode 252 if _, err := conn.Write([]byte{byte(rpcTLS)}); err != nil { 253 conn.Close() 254 return nil, err 255 } 256 257 // Wrap the connection in a TLS client 258 tlsConn, err := p.tlsWrap(region, conn) 259 if err != nil { 260 conn.Close() 261 return nil, err 262 } 263 conn = tlsConn 264 } 265 266 // Write the multiplex byte to set the mode 267 if _, err := conn.Write([]byte{byte(rpcMultiplex)}); err != nil { 268 conn.Close() 269 return nil, err 270 } 271 272 // Setup the logger 273 conf := yamux.DefaultConfig() 274 conf.LogOutput = p.logOutput 275 276 // Create a multiplexed session 277 session, err := yamux.Client(conn, conf) 278 if err != nil { 279 conn.Close() 280 return nil, err 281 } 282 283 // Wrap the connection 284 c := &Conn{ 285 refCount: 1, 286 addr: addr, 287 session: session, 288 clients: list.New(), 289 lastUsed: time.Now(), 290 version: version, 291 pool: p, 292 } 293 return c, nil 294 } 295 296 // clearConn is used to clear any cached connection, potentially in response to an erro 297 func (p *ConnPool) clearConn(conn *Conn) { 298 // Ensure returned streams are closed 299 atomic.StoreInt32(&conn.shouldClose, 1) 300 301 // Clear from the cache 302 p.Lock() 303 if c, ok := p.pool[conn.addr.String()]; ok && c == conn { 304 delete(p.pool, conn.addr.String()) 305 } 306 p.Unlock() 307 308 // Close down immediately if idle 309 if refCount := atomic.LoadInt32(&conn.refCount); refCount == 0 { 310 conn.Close() 311 } 312 } 313 314 // releaseConn is invoked when we are done with a conn to reduce the ref count 315 func (p *ConnPool) releaseConn(conn *Conn) { 316 refCount := atomic.AddInt32(&conn.refCount, -1) 317 if refCount == 0 && atomic.LoadInt32(&conn.shouldClose) == 1 { 318 conn.Close() 319 } 320 } 321 322 // getClient is used to get a usable client for an address and protocol version 323 func (p *ConnPool) getClient(region string, addr net.Addr, version int) (*Conn, *StreamClient, error) { 324 retries := 0 325 START: 326 // Try to get a conn first 327 conn, err := p.acquire(region, addr, version) 328 if err != nil { 329 return nil, nil, fmt.Errorf("failed to get conn: %v", err) 330 } 331 332 // Get a client 333 client, err := conn.getClient() 334 if err != nil { 335 p.clearConn(conn) 336 p.releaseConn(conn) 337 338 // Try to redial, possible that the TCP session closed due to timeout 339 if retries == 0 { 340 retries++ 341 goto START 342 } 343 return nil, nil, fmt.Errorf("failed to start stream: %v", err) 344 } 345 return conn, client, nil 346 } 347 348 // RPC is used to make an RPC call to a remote host 349 func (p *ConnPool) RPC(region string, addr net.Addr, version int, method string, args interface{}, reply interface{}) error { 350 // Get a usable client 351 conn, sc, err := p.getClient(region, addr, version) 352 if err != nil { 353 return fmt.Errorf("rpc error: %v", err) 354 } 355 356 // Make the RPC call 357 err = msgpackrpc.CallWithCodec(sc.codec, method, args, reply) 358 if err != nil { 359 sc.Close() 360 p.releaseConn(conn) 361 return fmt.Errorf("rpc error: %v", err) 362 } 363 364 // Done with the connection 365 conn.returnClient(sc) 366 p.releaseConn(conn) 367 return nil 368 } 369 370 // Reap is used to close conns open over maxTime 371 func (p *ConnPool) reap() { 372 for { 373 // Sleep for a while 374 select { 375 case <-p.shutdownCh: 376 return 377 case <-time.After(time.Second): 378 } 379 380 // Reap all old conns 381 p.Lock() 382 var removed []string 383 now := time.Now() 384 for host, conn := range p.pool { 385 // Skip recently used connections 386 if now.Sub(conn.lastUsed) < p.maxTime { 387 continue 388 } 389 390 // Skip connections with active streams 391 if atomic.LoadInt32(&conn.refCount) > 0 { 392 continue 393 } 394 395 // Close the conn 396 conn.Close() 397 398 // Remove from pool 399 removed = append(removed, host) 400 } 401 for _, host := range removed { 402 delete(p.pool, host) 403 } 404 p.Unlock() 405 } 406 }