github.com/ydb-platform/ydb-go-sdk/v3@v3.89.2/internal/conn/conn.go (about) 1 package conn 2 3 import ( 4 "context" 5 "fmt" 6 "sync" 7 "sync/atomic" 8 "time" 9 10 "github.com/ydb-platform/ydb-go-genproto/protos/Ydb" 11 "google.golang.org/grpc" 12 "google.golang.org/grpc/connectivity" 13 "google.golang.org/grpc/metadata" 14 "google.golang.org/grpc/stats" 15 16 "github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint" 17 "github.com/ydb-platform/ydb-go-sdk/v3/internal/meta" 18 "github.com/ydb-platform/ydb-go-sdk/v3/internal/operation" 19 "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack" 20 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext" 21 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" 22 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xsync" 23 "github.com/ydb-platform/ydb-go-sdk/v3/trace" 24 ) 25 26 var ( 27 // errOperationNotReady specified error when operation is not ready 28 errOperationNotReady = xerrors.Wrap(fmt.Errorf("operation is not ready yet")) 29 30 // errClosedConnection specified error when connection are closed early 31 errClosedConnection = xerrors.Wrap(fmt.Errorf("connection closed early")) 32 33 // errUnavailableConnection specified error when connection are closed early 34 errUnavailableConnection = xerrors.Wrap(fmt.Errorf("connection unavailable")) 35 ) 36 37 type Conn interface { 38 grpc.ClientConnInterface 39 40 Endpoint() endpoint.Endpoint 41 42 LastUsage() time.Time 43 44 Ping(ctx context.Context) error 45 IsState(states ...State) bool 46 GetState() State 47 SetState(ctx context.Context, state State) State 48 Unban(ctx context.Context) State 49 } 50 51 type conn struct { 52 mtx sync.RWMutex 53 config Config // ro access 54 grpcConn *grpc.ClientConn 55 done chan struct{} 56 endpoint endpoint.Endpoint // ro access 57 closed bool 58 state atomic.Uint32 59 childStreams *xcontext.CancelsGuard 60 lastUsage xsync.LastUsage 61 onClose []func(*conn) 62 onTransportErrors []func(ctx context.Context, cc Conn, cause error) 63 } 64 65 func (c *conn) Address() string { 66 return c.endpoint.Address() 67 } 68 69 func (c *conn) Ping(ctx context.Context) error { 70 cc, err := c.realConn(ctx) 71 if err != nil { 72 return xerrors.WithStackTrace(err) 73 } 74 if !isAvailable(cc) { 75 return xerrors.WithStackTrace(errUnavailableConnection) 76 } 77 78 return nil 79 } 80 81 func (c *conn) LastUsage() time.Time { 82 c.mtx.RLock() 83 defer c.mtx.RUnlock() 84 85 return c.lastUsage.Get() 86 } 87 88 func (c *conn) IsState(states ...State) bool { 89 state := State(c.state.Load()) 90 for _, s := range states { 91 if s == state { 92 return true 93 } 94 } 95 96 return false 97 } 98 99 func (c *conn) NodeID() uint32 { 100 if c != nil { 101 return c.endpoint.NodeID() 102 } 103 104 return 0 105 } 106 107 func (c *conn) park(ctx context.Context) (err error) { 108 onDone := trace.DriverOnConnPark( 109 c.config.Trace(), &ctx, 110 stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/conn.(*conn).park"), 111 c.Endpoint(), 112 ) 113 defer func() { 114 onDone(err) 115 }() 116 117 c.mtx.Lock() 118 defer c.mtx.Unlock() 119 120 if c.closed { 121 return nil 122 } 123 124 if c.grpcConn == nil { 125 return nil 126 } 127 128 err = c.close(ctx) 129 if err != nil { 130 return xerrors.WithStackTrace(err) 131 } 132 133 return nil 134 } 135 136 func (c *conn) Endpoint() endpoint.Endpoint { 137 if c != nil { 138 return c.endpoint 139 } 140 141 return nil 142 } 143 144 func (c *conn) SetState(ctx context.Context, s State) State { 145 return c.setState(ctx, s) 146 } 147 148 func (c *conn) setState(ctx context.Context, s State) State { 149 if state := State(c.state.Swap(uint32(s))); state != s { 150 trace.DriverOnConnStateChange( 151 c.config.Trace(), &ctx, 152 stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/conn.(*conn).setState"), 153 c.endpoint.Copy(), state, 154 )(s) 155 } 156 157 return s 158 } 159 160 func (c *conn) Unban(ctx context.Context) State { 161 var newState State 162 c.mtx.RLock() 163 cc := c.grpcConn //nolint:ifshort 164 c.mtx.RUnlock() 165 if isAvailable(cc) { 166 newState = Online 167 } else { 168 newState = Offline 169 } 170 171 c.setState(ctx, newState) 172 173 return newState 174 } 175 176 func (c *conn) GetState() (s State) { 177 return State(c.state.Load()) 178 } 179 180 func makeDialOption(overrideHost string) []grpc.DialOption { 181 dialOption := []grpc.DialOption{ 182 grpc.WithStatsHandler(statsHandler{}), 183 } 184 185 if len(overrideHost) != 0 { 186 dialOption = append(dialOption, grpc.WithAuthority(overrideHost)) 187 } 188 189 return dialOption 190 } 191 192 func (c *conn) realConn(ctx context.Context) (cc *grpc.ClientConn, err error) { 193 if c.isClosed() { 194 return nil, xerrors.WithStackTrace(errClosedConnection) 195 } 196 197 c.mtx.Lock() 198 defer c.mtx.Unlock() 199 200 if c.grpcConn != nil { 201 return c.grpcConn, nil 202 } 203 204 if dialTimeout := c.config.DialTimeout(); dialTimeout > 0 { 205 var cancel context.CancelFunc 206 ctx, cancel = xcontext.WithTimeout(ctx, dialTimeout) 207 defer cancel() 208 } 209 210 onDone := trace.DriverOnConnDial( 211 c.config.Trace(), &ctx, 212 stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/conn.(*conn).realConn"), 213 c.endpoint.Copy(), 214 ) 215 defer func() { 216 onDone(err) 217 }() 218 219 // prepend "ydb" scheme for grpc dns-resolver to find the proper scheme 220 // three slashes in "ydb:///" is ok. It needs for good parse scheme in grpc resolver. 221 address := "ydb:///" + c.endpoint.Address() 222 223 dialOption := makeDialOption(c.endpoint.OverrideHost()) 224 225 cc, err = grpc.DialContext(ctx, address, append( //nolint:staticcheck,nolintlint 226 dialOption, 227 c.config.GrpcDialOptions()..., 228 )...) 229 if err != nil { 230 if xerrors.IsContextError(err) { 231 return nil, xerrors.WithStackTrace(err) 232 } 233 234 defer func() { 235 c.onTransportError(ctx, err) 236 }() 237 238 return nil, xerrors.WithStackTrace( 239 xerrors.Retryable( 240 xerrors.Transport(err), 241 xerrors.WithName("realConn"), 242 ), 243 ) 244 } 245 246 c.grpcConn = cc 247 c.setState(ctx, Online) 248 249 return c.grpcConn, nil 250 } 251 252 func (c *conn) onTransportError(ctx context.Context, cause error) { 253 for _, onTransportError := range c.onTransportErrors { 254 onTransportError(ctx, c, cause) 255 } 256 } 257 258 func isAvailable(raw *grpc.ClientConn) bool { 259 return raw != nil && raw.GetState() == connectivity.Ready 260 } 261 262 // conn must be locked 263 func (c *conn) close(ctx context.Context) (err error) { 264 if c.grpcConn == nil { 265 return nil 266 } 267 268 defer func() { 269 c.grpcConn = nil 270 c.setState(ctx, Offline) 271 }() 272 273 err = c.grpcConn.Close() 274 if err == nil || !UseWrapping(ctx) { 275 return err 276 } 277 278 return xerrors.WithStackTrace(err) 279 } 280 281 func (c *conn) isClosed() bool { 282 c.mtx.RLock() 283 defer c.mtx.RUnlock() 284 285 return c.closed 286 } 287 288 func (c *conn) Close(ctx context.Context) (err error) { 289 c.mtx.Lock() 290 defer c.mtx.Unlock() 291 292 if c.closed { 293 return nil 294 } 295 296 onDone := trace.DriverOnConnClose( 297 c.config.Trace(), &ctx, 298 stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/conn.(*conn).Close"), 299 c.Endpoint(), 300 ) 301 defer func() { 302 c.closed = true 303 304 c.setState(ctx, Destroyed) 305 306 for _, onClose := range c.onClose { 307 onClose(c) 308 } 309 310 onDone(err) 311 }() 312 313 err = c.close(ctx) 314 315 if !UseWrapping(ctx) { 316 return err 317 } 318 319 return xerrors.WithStackTrace(xerrors.Transport(err, 320 xerrors.WithAddress(c.Address()), 321 xerrors.WithNodeID(c.NodeID()), 322 )) 323 } 324 325 var onTransportErrorStub = func(ctx context.Context, err error) {} 326 327 func replyWrapper(reply any) (opID string, issues []trace.Issue) { 328 switch t := reply.(type) { 329 case operation.Response: 330 opID = t.GetOperation().GetId() 331 for _, issue := range t.GetOperation().GetIssues() { 332 issues = append(issues, issue) 333 } 334 case operation.Status: 335 for _, issue := range t.GetIssues() { 336 issues = append(issues, issue) 337 } 338 } 339 340 return opID, issues 341 } 342 343 //nolint:funlen 344 func invoke( 345 ctx context.Context, 346 method string, 347 req, reply any, 348 cc grpc.ClientConnInterface, 349 onTransportError func(context.Context, error), 350 address string, 351 nodeID uint32, 352 opts ...grpc.CallOption, 353 ) ( 354 opID string, 355 issues []trace.Issue, 356 _ error, 357 ) { 358 useWrapping := UseWrapping(ctx) 359 360 ctx, traceID, err := meta.TraceID(ctx) 361 if err != nil { 362 return opID, issues, xerrors.WithStackTrace(err) 363 } 364 365 ctx, sentMark := markContext(meta.WithTraceID(ctx, traceID)) 366 367 if onTransportError == nil { 368 onTransportError = onTransportErrorStub 369 } 370 371 err = cc.Invoke(ctx, method, req, reply, opts...) 372 if err != nil { 373 if xerrors.IsContextError(err) { 374 return opID, issues, xerrors.WithStackTrace(err) 375 } 376 377 defer onTransportError(ctx, err) 378 379 if !useWrapping { 380 return opID, issues, err 381 } 382 383 if sentMark.canRetry() { 384 return opID, issues, xerrors.WithStackTrace(xerrors.Retryable( 385 xerrors.Transport(err, 386 xerrors.WithTraceID(traceID), 387 ), 388 xerrors.WithName("Invoke"), 389 )) 390 } 391 392 return opID, issues, xerrors.WithStackTrace(xerrors.Transport(err, 393 xerrors.WithAddress(address), 394 xerrors.WithNodeID(nodeID), 395 xerrors.WithTraceID(traceID), 396 )) 397 } 398 399 opID, issues = replyWrapper(reply) 400 401 if !useWrapping { 402 return opID, issues, nil 403 } 404 405 switch t := reply.(type) { 406 case operation.Response: 407 switch { 408 case !t.GetOperation().GetReady(): 409 return opID, issues, xerrors.WithStackTrace(errOperationNotReady) 410 411 case t.GetOperation().GetStatus() != Ydb.StatusIds_SUCCESS: 412 return opID, issues, xerrors.WithStackTrace( 413 xerrors.Operation( 414 xerrors.FromOperation(t.GetOperation()), 415 xerrors.WithAddress(address), 416 xerrors.WithNodeID(nodeID), 417 xerrors.WithTraceID(traceID), 418 ), 419 ) 420 } 421 case operation.Status: 422 if t.GetStatus() != Ydb.StatusIds_SUCCESS { 423 return opID, issues, xerrors.WithStackTrace( 424 xerrors.Operation( 425 xerrors.FromOperation(t), 426 xerrors.WithAddress(address), 427 xerrors.WithNodeID(nodeID), 428 xerrors.WithTraceID(traceID), 429 ), 430 ) 431 } 432 } 433 434 return opID, issues, nil 435 } 436 437 func (c *conn) Invoke( 438 ctx context.Context, 439 method string, 440 req interface{}, 441 res interface{}, 442 opts ...grpc.CallOption, 443 ) (err error) { 444 var ( 445 opID string 446 issues []trace.Issue 447 onDone = trace.DriverOnConnInvoke( 448 c.config.Trace(), &ctx, 449 stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/conn.(*conn).Invoke"), 450 c.endpoint, trace.Method(method), 451 ) 452 cc *grpc.ClientConn 453 md = metadata.MD{} 454 ) 455 defer func() { 456 meta.CallTrailerCallback(ctx, md) 457 onDone(err, issues, opID, c.GetState(), md) 458 }() 459 460 cc, err = c.realConn(ctx) 461 if err != nil { 462 return xerrors.WithStackTrace(err) 463 } 464 465 stop := c.lastUsage.Start() 466 defer stop() 467 468 opID, issues, err = invoke( 469 ctx, 470 method, 471 req, 472 res, 473 cc, 474 c.onTransportError, 475 c.Address(), 476 c.NodeID(), 477 append(opts, grpc.Trailer(&md))..., 478 ) 479 480 return err 481 } 482 483 //nolint:funlen 484 func (c *conn) NewStream( 485 ctx context.Context, 486 desc *grpc.StreamDesc, 487 method string, 488 opts ...grpc.CallOption, 489 ) (_ grpc.ClientStream, finalErr error) { 490 var ( 491 onDone = trace.DriverOnConnNewStream( 492 c.config.Trace(), &ctx, 493 stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/conn.(*conn).NewStream"), 494 c.endpoint.Copy(), trace.Method(method), 495 ) 496 useWrapping = UseWrapping(ctx) 497 ) 498 499 defer func() { 500 onDone(finalErr, c.GetState()) 501 }() 502 503 cc, err := c.realConn(ctx) 504 if err != nil { 505 return nil, xerrors.WithStackTrace(err) 506 } 507 508 stop := c.lastUsage.Start() 509 defer stop() 510 511 ctx, traceID, err := meta.TraceID(ctx) 512 if err != nil { 513 return nil, xerrors.WithStackTrace(err) 514 } 515 516 ctx, sentMark := markContext(meta.WithTraceID(ctx, traceID)) 517 518 ctx, cancel := c.childStreams.WithCancel(ctx) 519 defer func() { 520 if finalErr != nil { 521 cancel() 522 } 523 }() 524 525 s := &grpcClientStream{ 526 parentConn: c, 527 streamCtx: ctx, 528 streamCancel: cancel, 529 wrapping: useWrapping, 530 traceID: traceID, 531 sentMark: sentMark, 532 } 533 534 s.stream, err = cc.NewStream(ctx, desc, method, append(opts, grpc.OnFinish(s.finish))...) 535 if err != nil { 536 if xerrors.IsContextError(err) { 537 return nil, xerrors.WithStackTrace(err) 538 } 539 540 defer func() { 541 c.onTransportError(ctx, err) 542 }() 543 544 if !useWrapping { 545 return nil, err 546 } 547 548 if sentMark.canRetry() { 549 return nil, xerrors.WithStackTrace(xerrors.Retryable( 550 xerrors.Transport(err, 551 xerrors.WithTraceID(traceID), 552 ), 553 xerrors.WithName("NewStream"), 554 )) 555 } 556 557 return nil, xerrors.WithStackTrace(xerrors.Transport(err, 558 xerrors.WithAddress(c.Address()), 559 xerrors.WithTraceID(traceID), 560 )) 561 } 562 563 return s, nil 564 } 565 566 type option func(c *conn) 567 568 func withOnClose(onClose func(*conn)) option { 569 return func(c *conn) { 570 if onClose != nil { 571 c.onClose = append(c.onClose, onClose) 572 } 573 } 574 } 575 576 func withOnTransportError(onTransportError func(ctx context.Context, cc Conn, cause error)) option { 577 return func(c *conn) { 578 if onTransportError != nil { 579 c.onTransportErrors = append(c.onTransportErrors, onTransportError) 580 } 581 } 582 } 583 584 func newConn(e endpoint.Endpoint, config Config, opts ...option) *conn { 585 c := &conn{ 586 endpoint: e, 587 config: config, 588 done: make(chan struct{}), 589 lastUsage: xsync.NewLastUsage(), 590 childStreams: xcontext.NewCancelsGuard(), 591 onClose: []func(*conn){ 592 func(c *conn) { 593 c.childStreams.Cancel() 594 }, 595 }, 596 } 597 c.state.Store(uint32(Created)) 598 for _, opt := range opts { 599 if opt != nil { 600 opt(c) 601 } 602 } 603 604 return c 605 } 606 607 func New(e endpoint.Endpoint, config Config, opts ...option) Conn { 608 return newConn(e, config, opts...) 609 } 610 611 var _ stats.Handler = statsHandler{} 612 613 type statsHandler struct{} 614 615 func (statsHandler) TagRPC(ctx context.Context, _ *stats.RPCTagInfo) context.Context { 616 return ctx 617 } 618 619 func (statsHandler) HandleRPC(ctx context.Context, rpcStats stats.RPCStats) { 620 switch rpcStats.(type) { 621 case *stats.Begin, *stats.End: 622 default: 623 getContextMark(ctx).markDirty() 624 } 625 } 626 627 func (statsHandler) TagConn(ctx context.Context, _ *stats.ConnTagInfo) context.Context { 628 return ctx 629 } 630 631 func (statsHandler) HandleConn(context.Context, stats.ConnStats) {} 632 633 type ctxHandleRPCKey struct{} 634 635 var rpcKey = ctxHandleRPCKey{} 636 637 func markContext(ctx context.Context) (context.Context, *modificationMark) { 638 mark := &modificationMark{} 639 640 return context.WithValue(ctx, rpcKey, mark), mark 641 } 642 643 func getContextMark(ctx context.Context) *modificationMark { 644 v := ctx.Value(rpcKey) 645 if v == nil { 646 return &modificationMark{} 647 } 648 649 val, ok := v.(*modificationMark) 650 if !ok { 651 panic(fmt.Sprintf("unsupported type conversion from %T to *modificationMark", val)) 652 } 653 654 return val 655 } 656 657 type modificationMark struct { 658 dirty atomic.Bool 659 } 660 661 func (m *modificationMark) canRetry() bool { 662 return !m.dirty.Load() 663 } 664 665 func (m *modificationMark) markDirty() { 666 m.dirty.Store(true) 667 }