github.com/ydb-platform/ydb-go-sdk/v3@v3.57.0/internal/conn/conn.go (about) 1 package conn 2 3 import ( 4 "context" 5 "fmt" 6 "sync" 7 "sync/atomic" 8 "time" 9 10 "github.com/ydb-platform/ydb-go-genproto/protos/Ydb" 11 "google.golang.org/grpc" 12 "google.golang.org/grpc/connectivity" 13 "google.golang.org/grpc/metadata" 14 "google.golang.org/grpc/stats" 15 16 "github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint" 17 "github.com/ydb-platform/ydb-go-sdk/v3/internal/meta" 18 "github.com/ydb-platform/ydb-go-sdk/v3/internal/response" 19 "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack" 20 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext" 21 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" 22 "github.com/ydb-platform/ydb-go-sdk/v3/trace" 23 ) 24 25 var ( 26 // errOperationNotReady specified error when operation is not ready 27 errOperationNotReady = xerrors.Wrap(fmt.Errorf("operation is not ready yet")) 28 29 // errClosedConnection specified error when connection are closed early 30 errClosedConnection = xerrors.Wrap(fmt.Errorf("connection closed early")) 31 32 // errUnavailableConnection specified error when connection are closed early 33 errUnavailableConnection = xerrors.Wrap(fmt.Errorf("connection unavailable")) 34 ) 35 36 type Conn interface { 37 grpc.ClientConnInterface 38 39 Endpoint() endpoint.Endpoint 40 41 LastUsage() time.Time 42 43 Ping(ctx context.Context) error 44 IsState(states ...State) bool 45 GetState() State 46 SetState(ctx context.Context, state State) State 47 Unban(ctx context.Context) State 48 } 49 50 type conn struct { 51 mtx sync.RWMutex 52 config Config // ro access 53 cc *grpc.ClientConn 54 done chan struct{} 55 endpoint endpoint.Endpoint // ro access 56 closed bool 57 state atomic.Uint32 58 lastUsage time.Time 59 onClose []func(*conn) 60 onTransportErrors []func(ctx context.Context, cc Conn, cause error) 61 } 62 63 func (c *conn) Address() string { 64 return c.endpoint.Address() 65 } 66 67 func (c *conn) Ping(ctx context.Context) error { 68 cc, err := c.realConn(ctx) 69 if err != nil { 70 return c.wrapError(err) 71 } 72 if !isAvailable(cc) { 73 return c.wrapError(errUnavailableConnection) 74 } 75 76 return nil 77 } 78 79 func (c *conn) LastUsage() time.Time { 80 c.mtx.RLock() 81 defer c.mtx.RUnlock() 82 83 return c.lastUsage 84 } 85 86 func (c *conn) IsState(states ...State) bool { 87 state := State(c.state.Load()) 88 for _, s := range states { 89 if s == state { 90 return true 91 } 92 } 93 94 return false 95 } 96 97 func (c *conn) park(ctx context.Context) (err error) { 98 onDone := trace.DriverOnConnPark( 99 c.config.Trace(), &ctx, 100 stack.FunctionID(""), 101 c.Endpoint(), 102 ) 103 defer func() { 104 onDone(err) 105 }() 106 107 c.mtx.Lock() 108 defer c.mtx.Unlock() 109 110 if c.closed { 111 return nil 112 } 113 114 if c.cc == nil { 115 return nil 116 } 117 118 err = c.close(ctx) 119 120 if err != nil { 121 return c.wrapError(err) 122 } 123 124 return nil 125 } 126 127 func (c *conn) NodeID() uint32 { 128 if c != nil { 129 return c.endpoint.NodeID() 130 } 131 132 return 0 133 } 134 135 func (c *conn) Endpoint() endpoint.Endpoint { 136 if c != nil { 137 return c.endpoint 138 } 139 140 return nil 141 } 142 143 func (c *conn) SetState(ctx context.Context, s State) State { 144 return c.setState(ctx, s) 145 } 146 147 func (c *conn) setState(ctx context.Context, s State) State { 148 if state := State(c.state.Swap(uint32(s))); state != s { 149 trace.DriverOnConnStateChange( 150 c.config.Trace(), &ctx, 151 stack.FunctionID(""), 152 c.endpoint.Copy(), state, 153 )(s) 154 } 155 156 return s 157 } 158 159 func (c *conn) Unban(ctx context.Context) State { 160 var newState State 161 c.mtx.RLock() 162 cc := c.cc 163 c.mtx.RUnlock() 164 if isAvailable(cc) { 165 newState = Online 166 } else { 167 newState = Offline 168 } 169 170 c.setState(ctx, newState) 171 172 return newState 173 } 174 175 func (c *conn) GetState() (s State) { 176 return State(c.state.Load()) 177 } 178 179 func (c *conn) realConn(ctx context.Context) (cc *grpc.ClientConn, err error) { 180 if c.isClosed() { 181 return nil, c.wrapError(errClosedConnection) 182 } 183 184 c.mtx.Lock() 185 defer c.mtx.Unlock() 186 187 if c.cc != nil { 188 return c.cc, nil 189 } 190 191 if dialTimeout := c.config.DialTimeout(); dialTimeout > 0 { 192 var cancel context.CancelFunc 193 ctx, cancel = xcontext.WithTimeout(ctx, dialTimeout) 194 defer cancel() 195 } 196 197 onDone := trace.DriverOnConnDial( 198 c.config.Trace(), &ctx, 199 stack.FunctionID(""), 200 c.endpoint.Copy(), 201 ) 202 defer func() { 203 onDone(err) 204 }() 205 206 // prepend "ydb" scheme for grpc dns-resolver to find the proper scheme 207 // three slashes in "ydb:///" is ok. It needs for good parse scheme in grpc resolver. 208 address := "ydb:///" + c.endpoint.Address() 209 210 cc, err = grpc.DialContext(ctx, address, append( 211 []grpc.DialOption{ 212 grpc.WithStatsHandler(statsHandler{}), 213 }, c.config.GrpcDialOptions()..., 214 )...) 215 if err != nil { 216 defer func() { 217 c.onTransportError(ctx, err) 218 }() 219 220 err = xerrors.Transport(err, 221 xerrors.WithAddress(address), 222 ) 223 224 return nil, c.wrapError( 225 xerrors.Retryable(err, 226 xerrors.WithName("realConn"), 227 ), 228 ) 229 } 230 231 c.cc = cc 232 c.setState(ctx, Online) 233 234 return c.cc, nil 235 } 236 237 func (c *conn) onTransportError(ctx context.Context, cause error) { 238 for _, onTransportError := range c.onTransportErrors { 239 onTransportError(ctx, c, cause) 240 } 241 } 242 243 func (c *conn) touchLastUsage() { 244 c.mtx.Lock() 245 defer c.mtx.Unlock() 246 c.lastUsage = time.Now() 247 } 248 249 func isAvailable(raw *grpc.ClientConn) bool { 250 return raw != nil && raw.GetState() == connectivity.Ready 251 } 252 253 // conn must be locked 254 func (c *conn) close(ctx context.Context) (err error) { 255 if c.cc == nil { 256 return nil 257 } 258 err = c.cc.Close() 259 c.cc = nil 260 c.setState(ctx, Offline) 261 262 return c.wrapError(err) 263 } 264 265 func (c *conn) isClosed() bool { 266 c.mtx.RLock() 267 defer c.mtx.RUnlock() 268 269 return c.closed 270 } 271 272 func (c *conn) Close(ctx context.Context) (err error) { 273 c.mtx.Lock() 274 defer c.mtx.Unlock() 275 276 if c.closed { 277 return nil 278 } 279 280 onDone := trace.DriverOnConnClose( 281 c.config.Trace(), &ctx, 282 stack.FunctionID(""), 283 c.Endpoint(), 284 ) 285 defer func() { 286 onDone(err) 287 }() 288 289 c.closed = true 290 291 err = c.close(ctx) 292 293 c.setState(ctx, Destroyed) 294 295 for _, onClose := range c.onClose { 296 onClose(c) 297 } 298 299 return c.wrapError(err) 300 } 301 302 func (c *conn) Invoke( 303 ctx context.Context, 304 method string, 305 req interface{}, 306 res interface{}, 307 opts ...grpc.CallOption, 308 ) (err error) { 309 var ( 310 opID string 311 issues []trace.Issue 312 useWrapping = UseWrapping(ctx) 313 onDone = trace.DriverOnConnInvoke( 314 c.config.Trace(), &ctx, 315 stack.FunctionID(""), 316 c.endpoint, trace.Method(method), 317 ) 318 cc *grpc.ClientConn 319 md = metadata.MD{} 320 ) 321 defer func() { 322 meta.CallTrailerCallback(ctx, md) 323 onDone(err, issues, opID, c.GetState(), md) 324 }() 325 326 cc, err = c.realConn(ctx) 327 if err != nil { 328 return c.wrapError(err) 329 } 330 331 c.touchLastUsage() 332 defer c.touchLastUsage() 333 334 ctx, traceID, err := meta.TraceID(ctx) 335 if err != nil { 336 return xerrors.WithStackTrace(err) 337 } 338 339 ctx, sentMark := markContext(meta.WithTraceID(ctx, traceID)) 340 341 err = cc.Invoke(ctx, method, req, res, append(opts, grpc.Trailer(&md))...) 342 if err != nil { 343 defer func() { 344 c.onTransportError(ctx, err) 345 }() 346 347 if useWrapping { 348 err = xerrors.Transport(err, 349 xerrors.WithAddress(c.Address()), 350 xerrors.WithTraceID(traceID), 351 ) 352 if sentMark.canRetry() { 353 return c.wrapError(xerrors.Retryable(err, xerrors.WithName("Invoke"))) 354 } 355 356 return c.wrapError(err) 357 } 358 359 return err 360 } 361 362 if o, ok := res.(response.Response); ok { 363 opID = o.GetOperation().GetId() 364 for _, issue := range o.GetOperation().GetIssues() { 365 issues = append(issues, issue) 366 } 367 if useWrapping { 368 switch { 369 case !o.GetOperation().GetReady(): 370 return c.wrapError(errOperationNotReady) 371 372 case o.GetOperation().GetStatus() != Ydb.StatusIds_SUCCESS: 373 return c.wrapError( 374 xerrors.Operation( 375 xerrors.FromOperation(o.GetOperation()), 376 xerrors.WithAddress(c.Address()), 377 xerrors.WithTraceID(traceID), 378 ), 379 ) 380 } 381 } 382 } 383 384 return err 385 } 386 387 func (c *conn) NewStream( 388 ctx context.Context, 389 desc *grpc.StreamDesc, 390 method string, 391 opts ...grpc.CallOption, 392 ) (_ grpc.ClientStream, err error) { 393 var ( 394 streamRecv = trace.DriverOnConnNewStream( 395 c.config.Trace(), &ctx, 396 stack.FunctionID(""), 397 c.endpoint.Copy(), trace.Method(method), 398 ) 399 useWrapping = UseWrapping(ctx) 400 cc *grpc.ClientConn 401 s grpc.ClientStream 402 ) 403 404 defer func() { 405 if err != nil { 406 streamRecv(err)(err, c.GetState(), metadata.MD{}) 407 } 408 }() 409 410 var cancel context.CancelFunc 411 ctx, cancel = xcontext.WithCancel(ctx) 412 413 defer func() { 414 if err != nil { 415 cancel() 416 } 417 }() 418 419 cc, err = c.realConn(ctx) 420 if err != nil { 421 return nil, c.wrapError(err) 422 } 423 424 c.touchLastUsage() 425 defer c.touchLastUsage() 426 427 ctx, traceID, err := meta.TraceID(ctx) 428 if err != nil { 429 return nil, xerrors.WithStackTrace(err) 430 } 431 432 ctx, sentMark := markContext(meta.WithTraceID(ctx, traceID)) 433 434 s, err = cc.NewStream(ctx, desc, method, opts...) 435 if err != nil { 436 defer func() { 437 c.onTransportError(ctx, err) 438 }() 439 440 if useWrapping { 441 err = xerrors.Transport(err, 442 xerrors.WithAddress(c.Address()), 443 xerrors.WithTraceID(traceID), 444 ) 445 if sentMark.canRetry() { 446 return s, c.wrapError(xerrors.Retryable(err, xerrors.WithName("NewStream"))) 447 } 448 449 return s, c.wrapError(err) 450 } 451 452 return s, err 453 } 454 455 return &grpcClientStream{ 456 ClientStream: s, 457 c: c, 458 wrapping: useWrapping, 459 traceID: traceID, 460 sentMark: sentMark, 461 onDone: func(ctx context.Context, md metadata.MD) { 462 cancel() 463 meta.CallTrailerCallback(ctx, md) 464 }, 465 recv: streamRecv, 466 }, nil 467 } 468 469 func (c *conn) wrapError(err error) error { 470 if err == nil { 471 return nil 472 } 473 nodeErr := newConnError(c.endpoint.NodeID(), c.endpoint.Address(), err) 474 475 return xerrors.WithStackTrace(nodeErr, xerrors.WithSkipDepth(1)) 476 } 477 478 type option func(c *conn) 479 480 func withOnClose(onClose func(*conn)) option { 481 return func(c *conn) { 482 if onClose != nil { 483 c.onClose = append(c.onClose, onClose) 484 } 485 } 486 } 487 488 func withOnTransportError(onTransportError func(ctx context.Context, cc Conn, cause error)) option { 489 return func(c *conn) { 490 if onTransportError != nil { 491 c.onTransportErrors = append(c.onTransportErrors, onTransportError) 492 } 493 } 494 } 495 496 func newConn(e endpoint.Endpoint, config Config, opts ...option) *conn { 497 c := &conn{ 498 endpoint: e, 499 config: config, 500 done: make(chan struct{}), 501 } 502 c.state.Store(uint32(Created)) 503 for _, o := range opts { 504 if o != nil { 505 o(c) 506 } 507 } 508 509 return c 510 } 511 512 func New(e endpoint.Endpoint, config Config, opts ...option) Conn { 513 return newConn(e, config, opts...) 514 } 515 516 var _ stats.Handler = statsHandler{} 517 518 type statsHandler struct{} 519 520 func (statsHandler) TagRPC(ctx context.Context, _ *stats.RPCTagInfo) context.Context { 521 return ctx 522 } 523 524 func (statsHandler) HandleRPC(ctx context.Context, rpcStats stats.RPCStats) { 525 switch rpcStats.(type) { 526 case *stats.Begin, *stats.End: 527 default: 528 getContextMark(ctx).markDirty() 529 } 530 } 531 532 func (statsHandler) TagConn(ctx context.Context, _ *stats.ConnTagInfo) context.Context { 533 return ctx 534 } 535 536 func (statsHandler) HandleConn(context.Context, stats.ConnStats) {} 537 538 type ctxHandleRPCKey struct{} 539 540 var rpcKey = ctxHandleRPCKey{} 541 542 func markContext(ctx context.Context) (context.Context, *modificationMark) { 543 mark := &modificationMark{} 544 545 return context.WithValue(ctx, rpcKey, mark), mark 546 } 547 548 func getContextMark(ctx context.Context) *modificationMark { 549 v := ctx.Value(rpcKey) 550 if v == nil { 551 return &modificationMark{} 552 } 553 554 return v.(*modificationMark) 555 } 556 557 type modificationMark struct { 558 dirty atomic.Bool 559 } 560 561 func (m *modificationMark) canRetry() bool { 562 return !m.dirty.Load() 563 } 564 565 func (m *modificationMark) markDirty() { 566 m.dirty.Store(true) 567 }