github.com/anycable/anycable-go@v1.5.1/pool/pool.go (about)

     1  package pool
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"sync"
     7  
     8  	"google.golang.org/grpc"
     9  )
    10  
    11  var (
    12  	errClosed = errors.New("pool is closed")
    13  )
    14  
    15  // Pool represents connection pool
    16  type Pool interface {
    17  	Get() (Conn, error)
    18  	Close()
    19  	Available() int
    20  	Busy() int
    21  }
    22  
    23  type channelPool struct {
    24  	mu          sync.Mutex
    25  	conns       chan *grpc.ClientConn
    26  	activeCount int
    27  	factory     Factory
    28  }
    29  
    30  // Factory is a connection pool factory fun interface
    31  type Factory func() (*grpc.ClientConn, error)
    32  
    33  // Conn is a single connection to pool wrapper
    34  type Conn struct {
    35  	Conn *grpc.ClientConn
    36  	c    *channelPool
    37  }
    38  
    39  // Close connection
    40  func (p Conn) Close() error {
    41  	return p.c.put(p.Conn)
    42  }
    43  
    44  func (c *channelPool) wrapConn(conn *grpc.ClientConn) Conn {
    45  	c.mu.Lock()
    46  	p := Conn{Conn: conn, c: c}
    47  	c.activeCount++
    48  	c.mu.Unlock()
    49  	return p
    50  }
    51  
    52  // NewChannelPool builds a new pool with provided configuration
    53  func NewChannelPool(initialCap, maxCap int, factory Factory) (Pool, error) {
    54  	if initialCap < 0 || maxCap <= 0 || initialCap > maxCap {
    55  		return nil, errors.New("invalid capacity settings")
    56  	}
    57  
    58  	c := &channelPool{
    59  		conns:       make(chan *grpc.ClientConn, maxCap),
    60  		factory:     factory,
    61  		activeCount: 0,
    62  	}
    63  
    64  	for i := 0; i < initialCap; i++ {
    65  		conn, err := factory()
    66  		if err != nil {
    67  			c.Close()
    68  			return nil, fmt.Errorf("factory is not able to fill the pool: %s", err)
    69  		}
    70  		c.conns <- conn
    71  	}
    72  
    73  	return c, nil
    74  }
    75  
    76  func (c *channelPool) getConns() chan *grpc.ClientConn {
    77  	c.mu.Lock()
    78  	conns := c.conns
    79  	c.mu.Unlock()
    80  	return conns
    81  }
    82  
    83  func (c *channelPool) Get() (Conn, error) {
    84  	conns := c.getConns()
    85  	if conns == nil {
    86  		return Conn{}, errClosed
    87  	}
    88  
    89  	// wrap our connections with out custom grpc.ClientConn implementation (wrapConn
    90  	// method) that puts the connection back to the pool if it's closed.
    91  	select {
    92  	case conn := <-conns:
    93  		if conn == nil {
    94  			return Conn{}, errClosed
    95  		}
    96  
    97  		return c.wrapConn(conn), nil
    98  	default:
    99  		conn, err := c.factory()
   100  		if err != nil {
   101  			return Conn{}, err
   102  		}
   103  
   104  		return c.wrapConn(conn), nil
   105  	}
   106  }
   107  
   108  func (c *channelPool) put(conn *grpc.ClientConn) error {
   109  	if conn == nil {
   110  		return errors.New("connection is nil. rejecting")
   111  	}
   112  
   113  	c.mu.Lock()
   114  	defer c.mu.Unlock()
   115  
   116  	c.activeCount--
   117  
   118  	if c.conns == nil {
   119  		// pool is closed, close passed connection
   120  		return conn.Close()
   121  	}
   122  
   123  	// put the resource back into the pool. If the pool is full, this will
   124  	// block and the default case will be executed.
   125  	select {
   126  	case c.conns <- conn:
   127  		return nil
   128  	default:
   129  		// pool is full, close passed connection
   130  		return conn.Close()
   131  	}
   132  }
   133  
   134  // Close all connections and pool's channel
   135  func (c *channelPool) Close() {
   136  	c.mu.Lock()
   137  	conns := c.conns
   138  	c.conns = nil
   139  	c.factory = nil
   140  	c.mu.Unlock()
   141  
   142  	if conns == nil {
   143  		return
   144  	}
   145  
   146  	close(conns)
   147  	for conn := range conns {
   148  		conn.Close()
   149  	}
   150  }
   151  
   152  func (c *channelPool) Busy() int { return c.activeCount }
   153  
   154  func (c *channelPool) Available() int { return len(c.getConns()) }