github.com/ydb-platform/ydb-go-sdk/v3@v3.89.2/internal/conn/conn.go (about)

     1  package conn
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"sync"
     7  	"sync/atomic"
     8  	"time"
     9  
    10  	"github.com/ydb-platform/ydb-go-genproto/protos/Ydb"
    11  	"google.golang.org/grpc"
    12  	"google.golang.org/grpc/connectivity"
    13  	"google.golang.org/grpc/metadata"
    14  	"google.golang.org/grpc/stats"
    15  
    16  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint"
    17  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/meta"
    18  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/operation"
    19  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/stack"
    20  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext"
    21  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
    22  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xsync"
    23  	"github.com/ydb-platform/ydb-go-sdk/v3/trace"
    24  )
    25  
    26  var (
    27  	// errOperationNotReady specified error when operation is not ready
    28  	errOperationNotReady = xerrors.Wrap(fmt.Errorf("operation is not ready yet"))
    29  
    30  	// errClosedConnection specified error when connection are closed early
    31  	errClosedConnection = xerrors.Wrap(fmt.Errorf("connection closed early"))
    32  
    33  	// errUnavailableConnection specified error when connection are closed early
    34  	errUnavailableConnection = xerrors.Wrap(fmt.Errorf("connection unavailable"))
    35  )
    36  
    37  type Conn interface {
    38  	grpc.ClientConnInterface
    39  
    40  	Endpoint() endpoint.Endpoint
    41  
    42  	LastUsage() time.Time
    43  
    44  	Ping(ctx context.Context) error
    45  	IsState(states ...State) bool
    46  	GetState() State
    47  	SetState(ctx context.Context, state State) State
    48  	Unban(ctx context.Context) State
    49  }
    50  
    51  type conn struct {
    52  	mtx               sync.RWMutex
    53  	config            Config // ro access
    54  	grpcConn          *grpc.ClientConn
    55  	done              chan struct{}
    56  	endpoint          endpoint.Endpoint // ro access
    57  	closed            bool
    58  	state             atomic.Uint32
    59  	childStreams      *xcontext.CancelsGuard
    60  	lastUsage         xsync.LastUsage
    61  	onClose           []func(*conn)
    62  	onTransportErrors []func(ctx context.Context, cc Conn, cause error)
    63  }
    64  
    65  func (c *conn) Address() string {
    66  	return c.endpoint.Address()
    67  }
    68  
    69  func (c *conn) Ping(ctx context.Context) error {
    70  	cc, err := c.realConn(ctx)
    71  	if err != nil {
    72  		return xerrors.WithStackTrace(err)
    73  	}
    74  	if !isAvailable(cc) {
    75  		return xerrors.WithStackTrace(errUnavailableConnection)
    76  	}
    77  
    78  	return nil
    79  }
    80  
    81  func (c *conn) LastUsage() time.Time {
    82  	c.mtx.RLock()
    83  	defer c.mtx.RUnlock()
    84  
    85  	return c.lastUsage.Get()
    86  }
    87  
    88  func (c *conn) IsState(states ...State) bool {
    89  	state := State(c.state.Load())
    90  	for _, s := range states {
    91  		if s == state {
    92  			return true
    93  		}
    94  	}
    95  
    96  	return false
    97  }
    98  
    99  func (c *conn) NodeID() uint32 {
   100  	if c != nil {
   101  		return c.endpoint.NodeID()
   102  	}
   103  
   104  	return 0
   105  }
   106  
   107  func (c *conn) park(ctx context.Context) (err error) {
   108  	onDone := trace.DriverOnConnPark(
   109  		c.config.Trace(), &ctx,
   110  		stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/conn.(*conn).park"),
   111  		c.Endpoint(),
   112  	)
   113  	defer func() {
   114  		onDone(err)
   115  	}()
   116  
   117  	c.mtx.Lock()
   118  	defer c.mtx.Unlock()
   119  
   120  	if c.closed {
   121  		return nil
   122  	}
   123  
   124  	if c.grpcConn == nil {
   125  		return nil
   126  	}
   127  
   128  	err = c.close(ctx)
   129  	if err != nil {
   130  		return xerrors.WithStackTrace(err)
   131  	}
   132  
   133  	return nil
   134  }
   135  
   136  func (c *conn) Endpoint() endpoint.Endpoint {
   137  	if c != nil {
   138  		return c.endpoint
   139  	}
   140  
   141  	return nil
   142  }
   143  
   144  func (c *conn) SetState(ctx context.Context, s State) State {
   145  	return c.setState(ctx, s)
   146  }
   147  
   148  func (c *conn) setState(ctx context.Context, s State) State {
   149  	if state := State(c.state.Swap(uint32(s))); state != s {
   150  		trace.DriverOnConnStateChange(
   151  			c.config.Trace(), &ctx,
   152  			stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/conn.(*conn).setState"),
   153  			c.endpoint.Copy(), state,
   154  		)(s)
   155  	}
   156  
   157  	return s
   158  }
   159  
   160  func (c *conn) Unban(ctx context.Context) State {
   161  	var newState State
   162  	c.mtx.RLock()
   163  	cc := c.grpcConn //nolint:ifshort
   164  	c.mtx.RUnlock()
   165  	if isAvailable(cc) {
   166  		newState = Online
   167  	} else {
   168  		newState = Offline
   169  	}
   170  
   171  	c.setState(ctx, newState)
   172  
   173  	return newState
   174  }
   175  
   176  func (c *conn) GetState() (s State) {
   177  	return State(c.state.Load())
   178  }
   179  
   180  func makeDialOption(overrideHost string) []grpc.DialOption {
   181  	dialOption := []grpc.DialOption{
   182  		grpc.WithStatsHandler(statsHandler{}),
   183  	}
   184  
   185  	if len(overrideHost) != 0 {
   186  		dialOption = append(dialOption, grpc.WithAuthority(overrideHost))
   187  	}
   188  
   189  	return dialOption
   190  }
   191  
   192  func (c *conn) realConn(ctx context.Context) (cc *grpc.ClientConn, err error) {
   193  	if c.isClosed() {
   194  		return nil, xerrors.WithStackTrace(errClosedConnection)
   195  	}
   196  
   197  	c.mtx.Lock()
   198  	defer c.mtx.Unlock()
   199  
   200  	if c.grpcConn != nil {
   201  		return c.grpcConn, nil
   202  	}
   203  
   204  	if dialTimeout := c.config.DialTimeout(); dialTimeout > 0 {
   205  		var cancel context.CancelFunc
   206  		ctx, cancel = xcontext.WithTimeout(ctx, dialTimeout)
   207  		defer cancel()
   208  	}
   209  
   210  	onDone := trace.DriverOnConnDial(
   211  		c.config.Trace(), &ctx,
   212  		stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/conn.(*conn).realConn"),
   213  		c.endpoint.Copy(),
   214  	)
   215  	defer func() {
   216  		onDone(err)
   217  	}()
   218  
   219  	// prepend "ydb" scheme for grpc dns-resolver to find the proper scheme
   220  	// three slashes in "ydb:///" is ok. It needs for good parse scheme in grpc resolver.
   221  	address := "ydb:///" + c.endpoint.Address()
   222  
   223  	dialOption := makeDialOption(c.endpoint.OverrideHost())
   224  
   225  	cc, err = grpc.DialContext(ctx, address, append( //nolint:staticcheck,nolintlint
   226  		dialOption,
   227  		c.config.GrpcDialOptions()...,
   228  	)...)
   229  	if err != nil {
   230  		if xerrors.IsContextError(err) {
   231  			return nil, xerrors.WithStackTrace(err)
   232  		}
   233  
   234  		defer func() {
   235  			c.onTransportError(ctx, err)
   236  		}()
   237  
   238  		return nil, xerrors.WithStackTrace(
   239  			xerrors.Retryable(
   240  				xerrors.Transport(err),
   241  				xerrors.WithName("realConn"),
   242  			),
   243  		)
   244  	}
   245  
   246  	c.grpcConn = cc
   247  	c.setState(ctx, Online)
   248  
   249  	return c.grpcConn, nil
   250  }
   251  
   252  func (c *conn) onTransportError(ctx context.Context, cause error) {
   253  	for _, onTransportError := range c.onTransportErrors {
   254  		onTransportError(ctx, c, cause)
   255  	}
   256  }
   257  
   258  func isAvailable(raw *grpc.ClientConn) bool {
   259  	return raw != nil && raw.GetState() == connectivity.Ready
   260  }
   261  
   262  // conn must be locked
   263  func (c *conn) close(ctx context.Context) (err error) {
   264  	if c.grpcConn == nil {
   265  		return nil
   266  	}
   267  
   268  	defer func() {
   269  		c.grpcConn = nil
   270  		c.setState(ctx, Offline)
   271  	}()
   272  
   273  	err = c.grpcConn.Close()
   274  	if err == nil || !UseWrapping(ctx) {
   275  		return err
   276  	}
   277  
   278  	return xerrors.WithStackTrace(err)
   279  }
   280  
   281  func (c *conn) isClosed() bool {
   282  	c.mtx.RLock()
   283  	defer c.mtx.RUnlock()
   284  
   285  	return c.closed
   286  }
   287  
   288  func (c *conn) Close(ctx context.Context) (err error) {
   289  	c.mtx.Lock()
   290  	defer c.mtx.Unlock()
   291  
   292  	if c.closed {
   293  		return nil
   294  	}
   295  
   296  	onDone := trace.DriverOnConnClose(
   297  		c.config.Trace(), &ctx,
   298  		stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/conn.(*conn).Close"),
   299  		c.Endpoint(),
   300  	)
   301  	defer func() {
   302  		c.closed = true
   303  
   304  		c.setState(ctx, Destroyed)
   305  
   306  		for _, onClose := range c.onClose {
   307  			onClose(c)
   308  		}
   309  
   310  		onDone(err)
   311  	}()
   312  
   313  	err = c.close(ctx)
   314  
   315  	if !UseWrapping(ctx) {
   316  		return err
   317  	}
   318  
   319  	return xerrors.WithStackTrace(xerrors.Transport(err,
   320  		xerrors.WithAddress(c.Address()),
   321  		xerrors.WithNodeID(c.NodeID()),
   322  	))
   323  }
   324  
   325  var onTransportErrorStub = func(ctx context.Context, err error) {}
   326  
   327  func replyWrapper(reply any) (opID string, issues []trace.Issue) {
   328  	switch t := reply.(type) {
   329  	case operation.Response:
   330  		opID = t.GetOperation().GetId()
   331  		for _, issue := range t.GetOperation().GetIssues() {
   332  			issues = append(issues, issue)
   333  		}
   334  	case operation.Status:
   335  		for _, issue := range t.GetIssues() {
   336  			issues = append(issues, issue)
   337  		}
   338  	}
   339  
   340  	return opID, issues
   341  }
   342  
   343  //nolint:funlen
   344  func invoke(
   345  	ctx context.Context,
   346  	method string,
   347  	req, reply any,
   348  	cc grpc.ClientConnInterface,
   349  	onTransportError func(context.Context, error),
   350  	address string,
   351  	nodeID uint32,
   352  	opts ...grpc.CallOption,
   353  ) (
   354  	opID string,
   355  	issues []trace.Issue,
   356  	_ error,
   357  ) {
   358  	useWrapping := UseWrapping(ctx)
   359  
   360  	ctx, traceID, err := meta.TraceID(ctx)
   361  	if err != nil {
   362  		return opID, issues, xerrors.WithStackTrace(err)
   363  	}
   364  
   365  	ctx, sentMark := markContext(meta.WithTraceID(ctx, traceID))
   366  
   367  	if onTransportError == nil {
   368  		onTransportError = onTransportErrorStub
   369  	}
   370  
   371  	err = cc.Invoke(ctx, method, req, reply, opts...)
   372  	if err != nil {
   373  		if xerrors.IsContextError(err) {
   374  			return opID, issues, xerrors.WithStackTrace(err)
   375  		}
   376  
   377  		defer onTransportError(ctx, err)
   378  
   379  		if !useWrapping {
   380  			return opID, issues, err
   381  		}
   382  
   383  		if sentMark.canRetry() {
   384  			return opID, issues, xerrors.WithStackTrace(xerrors.Retryable(
   385  				xerrors.Transport(err,
   386  					xerrors.WithTraceID(traceID),
   387  				),
   388  				xerrors.WithName("Invoke"),
   389  			))
   390  		}
   391  
   392  		return opID, issues, xerrors.WithStackTrace(xerrors.Transport(err,
   393  			xerrors.WithAddress(address),
   394  			xerrors.WithNodeID(nodeID),
   395  			xerrors.WithTraceID(traceID),
   396  		))
   397  	}
   398  
   399  	opID, issues = replyWrapper(reply)
   400  
   401  	if !useWrapping {
   402  		return opID, issues, nil
   403  	}
   404  
   405  	switch t := reply.(type) {
   406  	case operation.Response:
   407  		switch {
   408  		case !t.GetOperation().GetReady():
   409  			return opID, issues, xerrors.WithStackTrace(errOperationNotReady)
   410  
   411  		case t.GetOperation().GetStatus() != Ydb.StatusIds_SUCCESS:
   412  			return opID, issues, xerrors.WithStackTrace(
   413  				xerrors.Operation(
   414  					xerrors.FromOperation(t.GetOperation()),
   415  					xerrors.WithAddress(address),
   416  					xerrors.WithNodeID(nodeID),
   417  					xerrors.WithTraceID(traceID),
   418  				),
   419  			)
   420  		}
   421  	case operation.Status:
   422  		if t.GetStatus() != Ydb.StatusIds_SUCCESS {
   423  			return opID, issues, xerrors.WithStackTrace(
   424  				xerrors.Operation(
   425  					xerrors.FromOperation(t),
   426  					xerrors.WithAddress(address),
   427  					xerrors.WithNodeID(nodeID),
   428  					xerrors.WithTraceID(traceID),
   429  				),
   430  			)
   431  		}
   432  	}
   433  
   434  	return opID, issues, nil
   435  }
   436  
   437  func (c *conn) Invoke(
   438  	ctx context.Context,
   439  	method string,
   440  	req interface{},
   441  	res interface{},
   442  	opts ...grpc.CallOption,
   443  ) (err error) {
   444  	var (
   445  		opID   string
   446  		issues []trace.Issue
   447  		onDone = trace.DriverOnConnInvoke(
   448  			c.config.Trace(), &ctx,
   449  			stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/conn.(*conn).Invoke"),
   450  			c.endpoint, trace.Method(method),
   451  		)
   452  		cc *grpc.ClientConn
   453  		md = metadata.MD{}
   454  	)
   455  	defer func() {
   456  		meta.CallTrailerCallback(ctx, md)
   457  		onDone(err, issues, opID, c.GetState(), md)
   458  	}()
   459  
   460  	cc, err = c.realConn(ctx)
   461  	if err != nil {
   462  		return xerrors.WithStackTrace(err)
   463  	}
   464  
   465  	stop := c.lastUsage.Start()
   466  	defer stop()
   467  
   468  	opID, issues, err = invoke(
   469  		ctx,
   470  		method,
   471  		req,
   472  		res,
   473  		cc,
   474  		c.onTransportError,
   475  		c.Address(),
   476  		c.NodeID(),
   477  		append(opts, grpc.Trailer(&md))...,
   478  	)
   479  
   480  	return err
   481  }
   482  
   483  //nolint:funlen
   484  func (c *conn) NewStream(
   485  	ctx context.Context,
   486  	desc *grpc.StreamDesc,
   487  	method string,
   488  	opts ...grpc.CallOption,
   489  ) (_ grpc.ClientStream, finalErr error) {
   490  	var (
   491  		onDone = trace.DriverOnConnNewStream(
   492  			c.config.Trace(), &ctx,
   493  			stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/conn.(*conn).NewStream"),
   494  			c.endpoint.Copy(), trace.Method(method),
   495  		)
   496  		useWrapping = UseWrapping(ctx)
   497  	)
   498  
   499  	defer func() {
   500  		onDone(finalErr, c.GetState())
   501  	}()
   502  
   503  	cc, err := c.realConn(ctx)
   504  	if err != nil {
   505  		return nil, xerrors.WithStackTrace(err)
   506  	}
   507  
   508  	stop := c.lastUsage.Start()
   509  	defer stop()
   510  
   511  	ctx, traceID, err := meta.TraceID(ctx)
   512  	if err != nil {
   513  		return nil, xerrors.WithStackTrace(err)
   514  	}
   515  
   516  	ctx, sentMark := markContext(meta.WithTraceID(ctx, traceID))
   517  
   518  	ctx, cancel := c.childStreams.WithCancel(ctx)
   519  	defer func() {
   520  		if finalErr != nil {
   521  			cancel()
   522  		}
   523  	}()
   524  
   525  	s := &grpcClientStream{
   526  		parentConn:   c,
   527  		streamCtx:    ctx,
   528  		streamCancel: cancel,
   529  		wrapping:     useWrapping,
   530  		traceID:      traceID,
   531  		sentMark:     sentMark,
   532  	}
   533  
   534  	s.stream, err = cc.NewStream(ctx, desc, method, append(opts, grpc.OnFinish(s.finish))...)
   535  	if err != nil {
   536  		if xerrors.IsContextError(err) {
   537  			return nil, xerrors.WithStackTrace(err)
   538  		}
   539  
   540  		defer func() {
   541  			c.onTransportError(ctx, err)
   542  		}()
   543  
   544  		if !useWrapping {
   545  			return nil, err
   546  		}
   547  
   548  		if sentMark.canRetry() {
   549  			return nil, xerrors.WithStackTrace(xerrors.Retryable(
   550  				xerrors.Transport(err,
   551  					xerrors.WithTraceID(traceID),
   552  				),
   553  				xerrors.WithName("NewStream"),
   554  			))
   555  		}
   556  
   557  		return nil, xerrors.WithStackTrace(xerrors.Transport(err,
   558  			xerrors.WithAddress(c.Address()),
   559  			xerrors.WithTraceID(traceID),
   560  		))
   561  	}
   562  
   563  	return s, nil
   564  }
   565  
   566  type option func(c *conn)
   567  
   568  func withOnClose(onClose func(*conn)) option {
   569  	return func(c *conn) {
   570  		if onClose != nil {
   571  			c.onClose = append(c.onClose, onClose)
   572  		}
   573  	}
   574  }
   575  
   576  func withOnTransportError(onTransportError func(ctx context.Context, cc Conn, cause error)) option {
   577  	return func(c *conn) {
   578  		if onTransportError != nil {
   579  			c.onTransportErrors = append(c.onTransportErrors, onTransportError)
   580  		}
   581  	}
   582  }
   583  
   584  func newConn(e endpoint.Endpoint, config Config, opts ...option) *conn {
   585  	c := &conn{
   586  		endpoint:     e,
   587  		config:       config,
   588  		done:         make(chan struct{}),
   589  		lastUsage:    xsync.NewLastUsage(),
   590  		childStreams: xcontext.NewCancelsGuard(),
   591  		onClose: []func(*conn){
   592  			func(c *conn) {
   593  				c.childStreams.Cancel()
   594  			},
   595  		},
   596  	}
   597  	c.state.Store(uint32(Created))
   598  	for _, opt := range opts {
   599  		if opt != nil {
   600  			opt(c)
   601  		}
   602  	}
   603  
   604  	return c
   605  }
   606  
   607  func New(e endpoint.Endpoint, config Config, opts ...option) Conn {
   608  	return newConn(e, config, opts...)
   609  }
   610  
   611  var _ stats.Handler = statsHandler{}
   612  
   613  type statsHandler struct{}
   614  
   615  func (statsHandler) TagRPC(ctx context.Context, _ *stats.RPCTagInfo) context.Context {
   616  	return ctx
   617  }
   618  
   619  func (statsHandler) HandleRPC(ctx context.Context, rpcStats stats.RPCStats) {
   620  	switch rpcStats.(type) {
   621  	case *stats.Begin, *stats.End:
   622  	default:
   623  		getContextMark(ctx).markDirty()
   624  	}
   625  }
   626  
   627  func (statsHandler) TagConn(ctx context.Context, _ *stats.ConnTagInfo) context.Context {
   628  	return ctx
   629  }
   630  
   631  func (statsHandler) HandleConn(context.Context, stats.ConnStats) {}
   632  
   633  type ctxHandleRPCKey struct{}
   634  
   635  var rpcKey = ctxHandleRPCKey{}
   636  
   637  func markContext(ctx context.Context) (context.Context, *modificationMark) {
   638  	mark := &modificationMark{}
   639  
   640  	return context.WithValue(ctx, rpcKey, mark), mark
   641  }
   642  
   643  func getContextMark(ctx context.Context) *modificationMark {
   644  	v := ctx.Value(rpcKey)
   645  	if v == nil {
   646  		return &modificationMark{}
   647  	}
   648  
   649  	val, ok := v.(*modificationMark)
   650  	if !ok {
   651  		panic(fmt.Sprintf("unsupported type conversion from %T to *modificationMark", val))
   652  	}
   653  
   654  	return val
   655  }
   656  
   657  type modificationMark struct {
   658  	dirty atomic.Bool
   659  }
   660  
   661  func (m *modificationMark) canRetry() bool {
   662  	return !m.dirty.Load()
   663  }
   664  
   665  func (m *modificationMark) markDirty() {
   666  	m.dirty.Store(true)
   667  }