github.com/ydb-platform/ydb-go-sdk/v3@v3.57.0/internal/conn/pool.go (about) 1 package conn 2 3 import ( 4 "context" 5 "sync" 6 "sync/atomic" 7 "time" 8 9 "google.golang.org/grpc" 10 11 "github.com/ydb-platform/ydb-go-sdk/v3/internal/closer" 12 "github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint" 13 "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack" 14 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext" 15 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" 16 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xsync" 17 "github.com/ydb-platform/ydb-go-sdk/v3/trace" 18 ) 19 20 type connsKey struct { 21 address string 22 nodeID uint32 23 } 24 25 type Pool struct { 26 usages int64 27 config Config 28 mtx xsync.RWMutex 29 opts []grpc.DialOption 30 conns map[connsKey]*conn 31 done chan struct{} 32 } 33 34 func (p *Pool) Get(endpoint endpoint.Endpoint) Conn { 35 p.mtx.Lock() 36 defer p.mtx.Unlock() 37 38 var ( 39 address = endpoint.Address() 40 cc *conn 41 has bool 42 ) 43 44 key := connsKey{address, endpoint.NodeID()} 45 46 if cc, has = p.conns[key]; has { 47 return cc 48 } 49 50 cc = newConn( 51 endpoint, 52 p.config, 53 withOnClose(p.remove), 54 withOnTransportError(p.Ban), 55 ) 56 57 p.conns[key] = cc 58 59 return cc 60 } 61 62 func (p *Pool) remove(c *conn) { 63 p.mtx.Lock() 64 defer p.mtx.Unlock() 65 delete(p.conns, connsKey{c.Endpoint().Address(), c.Endpoint().NodeID()}) 66 } 67 68 func (p *Pool) isClosed() bool { 69 select { 70 case <-p.done: 71 return true 72 default: 73 return false 74 } 75 } 76 77 func (p *Pool) Ban(ctx context.Context, cc Conn, cause error) { 78 if p.isClosed() { 79 return 80 } 81 82 e := cc.Endpoint().Copy() 83 84 p.mtx.RLock() 85 defer p.mtx.RUnlock() 86 87 cc, ok := p.conns[connsKey{e.Address(), e.NodeID()}] 88 if !ok { 89 return 90 } 91 92 trace.DriverOnConnBan( 93 p.config.Trace(), &ctx, 94 stack.FunctionID(""), 95 e, cc.GetState(), cause, 96 )(cc.SetState(ctx, Banned)) 97 } 98 99 func (p *Pool) Allow(ctx context.Context, cc Conn) { 100 if p.isClosed() { 101 return 102 } 103 104 e := cc.Endpoint().Copy() 105 106 p.mtx.RLock() 107 defer p.mtx.RUnlock() 108 109 cc, ok := p.conns[connsKey{e.Address(), e.NodeID()}] 110 if !ok { 111 return 112 } 113 114 trace.DriverOnConnAllow( 115 p.config.Trace(), &ctx, 116 stack.FunctionID(""), 117 e, cc.GetState(), 118 )(cc.Unban(ctx)) 119 } 120 121 func (p *Pool) Take(context.Context) error { 122 atomic.AddInt64(&p.usages, 1) 123 124 return nil 125 } 126 127 func (p *Pool) Release(ctx context.Context) (finalErr error) { 128 onDone := trace.DriverOnPoolRelease(p.config.Trace(), &ctx, stack.FunctionID("")) 129 defer func() { 130 onDone(finalErr) 131 }() 132 133 if atomic.AddInt64(&p.usages, -1) > 0 { 134 return nil 135 } 136 137 close(p.done) 138 139 var conns []closer.Closer 140 p.mtx.WithRLock(func() { 141 conns = make([]closer.Closer, 0, len(p.conns)) 142 for _, c := range p.conns { 143 conns = append(conns, c) 144 } 145 }) 146 147 var ( 148 errCh = make(chan error, len(conns)) 149 wg sync.WaitGroup 150 ) 151 152 wg.Add(len(conns)) 153 for _, c := range conns { 154 go func(c closer.Closer) { 155 defer wg.Done() 156 if err := c.Close(ctx); err != nil { 157 errCh <- err 158 } 159 }(c) 160 } 161 wg.Wait() 162 close(errCh) 163 164 issues := make([]error, 0, len(conns)) 165 for err := range errCh { 166 issues = append(issues, err) 167 } 168 169 if len(issues) > 0 { 170 return xerrors.WithStackTrace(xerrors.NewWithIssues("connection pool close failed", issues...)) 171 } 172 173 return nil 174 } 175 176 func (p *Pool) connParker(ctx context.Context, ttl, interval time.Duration) { 177 ticker := time.NewTicker(interval) 178 defer ticker.Stop() 179 for { 180 select { 181 case <-p.done: 182 return 183 case <-ticker.C: 184 for _, c := range p.collectConns() { 185 if time.Since(c.LastUsage()) > ttl { 186 switch c.GetState() { 187 case Online, Banned: 188 _ = c.park(ctx) 189 default: 190 // nop 191 } 192 } 193 } 194 } 195 } 196 } 197 198 func (p *Pool) collectConns() []*conn { 199 p.mtx.RLock() 200 defer p.mtx.RUnlock() 201 conns := make([]*conn, 0, len(p.conns)) 202 for _, c := range p.conns { 203 conns = append(conns, c) 204 } 205 206 return conns 207 } 208 209 func NewPool(ctx context.Context, config Config) *Pool { 210 onDone := trace.DriverOnPoolNew(config.Trace(), &ctx, stack.FunctionID("")) 211 defer onDone() 212 213 p := &Pool{ 214 usages: 1, 215 config: config, 216 opts: config.GrpcDialOptions(), 217 conns: make(map[connsKey]*conn), 218 done: make(chan struct{}), 219 } 220 if ttl := config.ConnectionTTL(); ttl > 0 { 221 go p.connParker(xcontext.WithoutDeadline(ctx), ttl, ttl/2) 222 } 223 224 return p 225 }