github.com/ydb-platform/ydb-go-sdk/v3@v3.57.0/internal/conn/pool.go

package conn

import (
	"context"
	"sync"
	"sync/atomic"
	"time"

	"google.golang.org/grpc"

	"github.com/ydb-platform/ydb-go-sdk/v3/internal/closer"
	"github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint"
	"github.com/ydb-platform/ydb-go-sdk/v3/internal/stack"
	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext"
	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xsync"
	"github.com/ydb-platform/ydb-go-sdk/v3/trace"
)

// connsKey identifies a cached connection by endpoint address and node ID.
type connsKey struct {
	address string
	nodeID  uint32
}

// Pool is a reference-counted cache of per-endpoint gRPC connections.
// It bans and unbans endpoints on demand and parks idle connections
// in the background.
type Pool struct {
	usages int64
	config Config
	mtx    xsync.RWMutex
	opts   []grpc.DialOption
	conns  map[connsKey]*conn
	done   chan struct{}
}

// Get returns the cached conn for the endpoint's (address, nodeID) pair,
// creating and caching a new one if it is not present yet.
func (p *Pool) Get(endpoint endpoint.Endpoint) Conn {
	p.mtx.Lock()
	defer p.mtx.Unlock()

	var (
		address = endpoint.Address()
		cc      *conn
		has     bool
	)

	key := connsKey{address, endpoint.NodeID()}

	if cc, has = p.conns[key]; has {
		return cc
	}

	cc = newConn(
		endpoint,
		p.config,
		withOnClose(p.remove),
		withOnTransportError(p.Ban),
	)

	p.conns[key] = cc

	return cc
}
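
// Keying sketch (illustrative, not part of the upstream file): conns are cached
// per (address, nodeID), so endpoints that share an address but report different
// node IDs get distinct conns. epWithNodeID below is a hypothetical helper.
//
//	a := pool.Get(epWithNodeID("host:2135", 1)) // creates and caches under {host:2135, 1}
//	b := pool.Get(epWithNodeID("host:2135", 1)) // returns the same cached conn
//	c := pool.Get(epWithNodeID("host:2135", 2)) // separate conn under {host:2135, 2}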

// remove drops the conn from the cache; it is wired into every conn via withOnClose.
func (p *Pool) remove(c *conn) {
	p.mtx.Lock()
	defer p.mtx.Unlock()
	delete(p.conns, connsKey{c.Endpoint().Address(), c.Endpoint().NodeID()})
}

// isClosed reports whether the final Release has already shut the pool down.
func (p *Pool) isClosed() bool {
	select {
	case <-p.done:
		return true
	default:
		return false
	}
}

// Ban marks the endpoint's cached conn as Banned, e.g. after a transport error
// (it is wired into every conn via withOnTransportError).
func (p *Pool) Ban(ctx context.Context, cc Conn, cause error) {
	if p.isClosed() {
		return
	}

	e := cc.Endpoint().Copy()

	p.mtx.RLock()
	defer p.mtx.RUnlock()

	cc, ok := p.conns[connsKey{e.Address(), e.NodeID()}]
	if !ok {
		return
	}

	trace.DriverOnConnBan(
		p.config.Trace(), &ctx,
		stack.FunctionID(""),
		e, cc.GetState(), cause,
	)(cc.SetState(ctx, Banned))
}

// Allow lifts a previous Ban from the endpoint's cached conn.
func (p *Pool) Allow(ctx context.Context, cc Conn) {
	if p.isClosed() {
		return
	}

	e := cc.Endpoint().Copy()

	p.mtx.RLock()
	defer p.mtx.RUnlock()

	cc, ok := p.conns[connsKey{e.Address(), e.NodeID()}]
	if !ok {
		return
	}

	trace.DriverOnConnAllow(
		p.config.Trace(), &ctx,
		stack.FunctionID(""),
		e, cc.GetState(),
	)(cc.Unban(ctx))
}
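
// Pessimization sketch (illustrative, not part of the upstream file): Ban and
// Allow are the two halves of endpoint pessimization. A conn created by Get is
// wired with withOnTransportError(p.Ban), so a failing call bans its endpoint;
// a caller (assumed here to be the balancer) may later call Allow to put the
// endpoint back into rotation.
//
//	cc := pool.Get(ep)  // conn created with withOnTransportError(pool.Ban)
//	// a transport error on cc ends up in pool.Ban(ctx, cc, err) -> state Banned
//	pool.Allow(ctx, cc) // cc.Unban(ctx) clears the Banned state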

// Take increments the pool's usage counter; each Take must be paired with a Release.
func (p *Pool) Take(context.Context) error {
	atomic.AddInt64(&p.usages, 1)

	return nil
}

// Release decrements the usage counter; the final Release closes the pool and
// all cached conns in parallel, collecting any close errors into one issue.
func (p *Pool) Release(ctx context.Context) (finalErr error) {
	onDone := trace.DriverOnPoolRelease(p.config.Trace(), &ctx, stack.FunctionID(""))
	defer func() {
		onDone(finalErr)
	}()

	if atomic.AddInt64(&p.usages, -1) > 0 {
		return nil
	}

	close(p.done)

	var conns []closer.Closer
	p.mtx.WithRLock(func() {
		conns = make([]closer.Closer, 0, len(p.conns))
		for _, c := range p.conns {
			conns = append(conns, c)
		}
	})

	var (
		errCh = make(chan error, len(conns))
		wg    sync.WaitGroup
	)

	wg.Add(len(conns))
	for _, c := range conns {
		go func(c closer.Closer) {
			defer wg.Done()
			if err := c.Close(ctx); err != nil {
				errCh <- err
			}
		}(c)
	}
	wg.Wait()
	close(errCh)

	issues := make([]error, 0, len(conns))
	for err := range errCh {
		issues = append(issues, err)
	}

	if len(issues) > 0 {
		return xerrors.WithStackTrace(xerrors.NewWithIssues("connection pool close failed", issues...))
	}

	return nil
}
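
// Lifecycle sketch (illustrative, not part of the upstream file): usages is a
// reference count. NewPool starts it at 1, Take adds one, Release subtracts one,
// and only the Release that drops it to zero closes done and the cached conns.
//
//	pool := NewPool(ctx, cfg) // usages == 1
//	_ = pool.Take(ctx)        // usages == 2 (a second owner)
//	_ = pool.Release(ctx)     // usages == 1, pool stays open
//	_ = pool.Release(ctx)     // usages == 0, conns are closed in parallel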

// connParker periodically parks Online or Banned conns that have been idle
// longer than ttl. It stops when the pool is closed.
func (p *Pool) connParker(ctx context.Context, ttl, interval time.Duration) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-p.done:
			return
		case <-ticker.C:
			for _, c := range p.collectConns() {
				if time.Since(c.LastUsage()) > ttl {
					switch c.GetState() {
					case Online, Banned:
						_ = c.park(ctx)
					default:
						// nop
					}
				}
			}
		}
	}
}
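
// Timing sketch (illustrative, not part of the upstream file): NewPool starts
// connParker with interval = ttl/2, so with ConnectionTTL() == 10s the conns are
// scanned every 5s and any Online/Banned conn unused for more than 10s is parked
// (park is assumed to release the underlying transport until the conn is needed again).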

// collectConns snapshots the current conns under a read lock so that callers
// can iterate without holding the pool mutex.
func (p *Pool) collectConns() []*conn {
	p.mtx.RLock()
	defer p.mtx.RUnlock()
	conns := make([]*conn, 0, len(p.conns))
	for _, c := range p.conns {
		conns = append(conns, c)
	}

	return conns
}

// NewPool creates a pool with a usage count of 1 and, if a connection TTL is
// configured, starts the background parker with a check interval of ttl/2.
func NewPool(ctx context.Context, config Config) *Pool {
	onDone := trace.DriverOnPoolNew(config.Trace(), &ctx, stack.FunctionID(""))
	defer onDone()

	p := &Pool{
		usages: 1,
		config: config,
		opts:   config.GrpcDialOptions(),
		conns:  make(map[connsKey]*conn),
		done:   make(chan struct{}),
	}
	if ttl := config.ConnectionTTL(); ttl > 0 {
		go p.connParker(xcontext.WithoutDeadline(ctx), ttl, ttl/2)
	}

	return p
}