github.com/ydb-platform/ydb-go-sdk/v3@v3.57.0/internal/conn/conn.go (about)

     1  package conn
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"sync"
     7  	"sync/atomic"
     8  	"time"
     9  
    10  	"github.com/ydb-platform/ydb-go-genproto/protos/Ydb"
    11  	"google.golang.org/grpc"
    12  	"google.golang.org/grpc/connectivity"
    13  	"google.golang.org/grpc/metadata"
    14  	"google.golang.org/grpc/stats"
    15  
    16  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint"
    17  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/meta"
    18  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/response"
    19  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/stack"
    20  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext"
    21  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
    22  	"github.com/ydb-platform/ydb-go-sdk/v3/trace"
    23  )
    24  
    25  var (
    26  	// errOperationNotReady specified error when operation is not ready
    27  	errOperationNotReady = xerrors.Wrap(fmt.Errorf("operation is not ready yet"))
    28  
    29  	// errClosedConnection specified error when connection are closed early
    30  	errClosedConnection = xerrors.Wrap(fmt.Errorf("connection closed early"))
    31  
    32  	// errUnavailableConnection specified error when connection are closed early
    33  	errUnavailableConnection = xerrors.Wrap(fmt.Errorf("connection unavailable"))
    34  )
    35  
    36  type Conn interface {
    37  	grpc.ClientConnInterface
    38  
    39  	Endpoint() endpoint.Endpoint
    40  
    41  	LastUsage() time.Time
    42  
    43  	Ping(ctx context.Context) error
    44  	IsState(states ...State) bool
    45  	GetState() State
    46  	SetState(ctx context.Context, state State) State
    47  	Unban(ctx context.Context) State
    48  }
    49  
    50  type conn struct {
    51  	mtx               sync.RWMutex
    52  	config            Config // ro access
    53  	cc                *grpc.ClientConn
    54  	done              chan struct{}
    55  	endpoint          endpoint.Endpoint // ro access
    56  	closed            bool
    57  	state             atomic.Uint32
    58  	lastUsage         time.Time
    59  	onClose           []func(*conn)
    60  	onTransportErrors []func(ctx context.Context, cc Conn, cause error)
    61  }
    62  
    63  func (c *conn) Address() string {
    64  	return c.endpoint.Address()
    65  }
    66  
    67  func (c *conn) Ping(ctx context.Context) error {
    68  	cc, err := c.realConn(ctx)
    69  	if err != nil {
    70  		return c.wrapError(err)
    71  	}
    72  	if !isAvailable(cc) {
    73  		return c.wrapError(errUnavailableConnection)
    74  	}
    75  
    76  	return nil
    77  }
    78  
    79  func (c *conn) LastUsage() time.Time {
    80  	c.mtx.RLock()
    81  	defer c.mtx.RUnlock()
    82  
    83  	return c.lastUsage
    84  }
    85  
    86  func (c *conn) IsState(states ...State) bool {
    87  	state := State(c.state.Load())
    88  	for _, s := range states {
    89  		if s == state {
    90  			return true
    91  		}
    92  	}
    93  
    94  	return false
    95  }
    96  
    97  func (c *conn) park(ctx context.Context) (err error) {
    98  	onDone := trace.DriverOnConnPark(
    99  		c.config.Trace(), &ctx,
   100  		stack.FunctionID(""),
   101  		c.Endpoint(),
   102  	)
   103  	defer func() {
   104  		onDone(err)
   105  	}()
   106  
   107  	c.mtx.Lock()
   108  	defer c.mtx.Unlock()
   109  
   110  	if c.closed {
   111  		return nil
   112  	}
   113  
   114  	if c.cc == nil {
   115  		return nil
   116  	}
   117  
   118  	err = c.close(ctx)
   119  
   120  	if err != nil {
   121  		return c.wrapError(err)
   122  	}
   123  
   124  	return nil
   125  }
   126  
   127  func (c *conn) NodeID() uint32 {
   128  	if c != nil {
   129  		return c.endpoint.NodeID()
   130  	}
   131  
   132  	return 0
   133  }
   134  
   135  func (c *conn) Endpoint() endpoint.Endpoint {
   136  	if c != nil {
   137  		return c.endpoint
   138  	}
   139  
   140  	return nil
   141  }
   142  
   143  func (c *conn) SetState(ctx context.Context, s State) State {
   144  	return c.setState(ctx, s)
   145  }
   146  
   147  func (c *conn) setState(ctx context.Context, s State) State {
   148  	if state := State(c.state.Swap(uint32(s))); state != s {
   149  		trace.DriverOnConnStateChange(
   150  			c.config.Trace(), &ctx,
   151  			stack.FunctionID(""),
   152  			c.endpoint.Copy(), state,
   153  		)(s)
   154  	}
   155  
   156  	return s
   157  }
   158  
   159  func (c *conn) Unban(ctx context.Context) State {
   160  	var newState State
   161  	c.mtx.RLock()
   162  	cc := c.cc
   163  	c.mtx.RUnlock()
   164  	if isAvailable(cc) {
   165  		newState = Online
   166  	} else {
   167  		newState = Offline
   168  	}
   169  
   170  	c.setState(ctx, newState)
   171  
   172  	return newState
   173  }
   174  
   175  func (c *conn) GetState() (s State) {
   176  	return State(c.state.Load())
   177  }
   178  
   179  func (c *conn) realConn(ctx context.Context) (cc *grpc.ClientConn, err error) {
   180  	if c.isClosed() {
   181  		return nil, c.wrapError(errClosedConnection)
   182  	}
   183  
   184  	c.mtx.Lock()
   185  	defer c.mtx.Unlock()
   186  
   187  	if c.cc != nil {
   188  		return c.cc, nil
   189  	}
   190  
   191  	if dialTimeout := c.config.DialTimeout(); dialTimeout > 0 {
   192  		var cancel context.CancelFunc
   193  		ctx, cancel = xcontext.WithTimeout(ctx, dialTimeout)
   194  		defer cancel()
   195  	}
   196  
   197  	onDone := trace.DriverOnConnDial(
   198  		c.config.Trace(), &ctx,
   199  		stack.FunctionID(""),
   200  		c.endpoint.Copy(),
   201  	)
   202  	defer func() {
   203  		onDone(err)
   204  	}()
   205  
   206  	// prepend "ydb" scheme for grpc dns-resolver to find the proper scheme
   207  	// three slashes in "ydb:///" is ok. It needs for good parse scheme in grpc resolver.
   208  	address := "ydb:///" + c.endpoint.Address()
   209  
   210  	cc, err = grpc.DialContext(ctx, address, append(
   211  		[]grpc.DialOption{
   212  			grpc.WithStatsHandler(statsHandler{}),
   213  		}, c.config.GrpcDialOptions()...,
   214  	)...)
   215  	if err != nil {
   216  		defer func() {
   217  			c.onTransportError(ctx, err)
   218  		}()
   219  
   220  		err = xerrors.Transport(err,
   221  			xerrors.WithAddress(address),
   222  		)
   223  
   224  		return nil, c.wrapError(
   225  			xerrors.Retryable(err,
   226  				xerrors.WithName("realConn"),
   227  			),
   228  		)
   229  	}
   230  
   231  	c.cc = cc
   232  	c.setState(ctx, Online)
   233  
   234  	return c.cc, nil
   235  }
   236  
   237  func (c *conn) onTransportError(ctx context.Context, cause error) {
   238  	for _, onTransportError := range c.onTransportErrors {
   239  		onTransportError(ctx, c, cause)
   240  	}
   241  }
   242  
   243  func (c *conn) touchLastUsage() {
   244  	c.mtx.Lock()
   245  	defer c.mtx.Unlock()
   246  	c.lastUsage = time.Now()
   247  }
   248  
   249  func isAvailable(raw *grpc.ClientConn) bool {
   250  	return raw != nil && raw.GetState() == connectivity.Ready
   251  }
   252  
   253  // conn must be locked
   254  func (c *conn) close(ctx context.Context) (err error) {
   255  	if c.cc == nil {
   256  		return nil
   257  	}
   258  	err = c.cc.Close()
   259  	c.cc = nil
   260  	c.setState(ctx, Offline)
   261  
   262  	return c.wrapError(err)
   263  }
   264  
   265  func (c *conn) isClosed() bool {
   266  	c.mtx.RLock()
   267  	defer c.mtx.RUnlock()
   268  
   269  	return c.closed
   270  }
   271  
   272  func (c *conn) Close(ctx context.Context) (err error) {
   273  	c.mtx.Lock()
   274  	defer c.mtx.Unlock()
   275  
   276  	if c.closed {
   277  		return nil
   278  	}
   279  
   280  	onDone := trace.DriverOnConnClose(
   281  		c.config.Trace(), &ctx,
   282  		stack.FunctionID(""),
   283  		c.Endpoint(),
   284  	)
   285  	defer func() {
   286  		onDone(err)
   287  	}()
   288  
   289  	c.closed = true
   290  
   291  	err = c.close(ctx)
   292  
   293  	c.setState(ctx, Destroyed)
   294  
   295  	for _, onClose := range c.onClose {
   296  		onClose(c)
   297  	}
   298  
   299  	return c.wrapError(err)
   300  }
   301  
   302  func (c *conn) Invoke(
   303  	ctx context.Context,
   304  	method string,
   305  	req interface{},
   306  	res interface{},
   307  	opts ...grpc.CallOption,
   308  ) (err error) {
   309  	var (
   310  		opID        string
   311  		issues      []trace.Issue
   312  		useWrapping = UseWrapping(ctx)
   313  		onDone      = trace.DriverOnConnInvoke(
   314  			c.config.Trace(), &ctx,
   315  			stack.FunctionID(""),
   316  			c.endpoint, trace.Method(method),
   317  		)
   318  		cc *grpc.ClientConn
   319  		md = metadata.MD{}
   320  	)
   321  	defer func() {
   322  		meta.CallTrailerCallback(ctx, md)
   323  		onDone(err, issues, opID, c.GetState(), md)
   324  	}()
   325  
   326  	cc, err = c.realConn(ctx)
   327  	if err != nil {
   328  		return c.wrapError(err)
   329  	}
   330  
   331  	c.touchLastUsage()
   332  	defer c.touchLastUsage()
   333  
   334  	ctx, traceID, err := meta.TraceID(ctx)
   335  	if err != nil {
   336  		return xerrors.WithStackTrace(err)
   337  	}
   338  
   339  	ctx, sentMark := markContext(meta.WithTraceID(ctx, traceID))
   340  
   341  	err = cc.Invoke(ctx, method, req, res, append(opts, grpc.Trailer(&md))...)
   342  	if err != nil {
   343  		defer func() {
   344  			c.onTransportError(ctx, err)
   345  		}()
   346  
   347  		if useWrapping {
   348  			err = xerrors.Transport(err,
   349  				xerrors.WithAddress(c.Address()),
   350  				xerrors.WithTraceID(traceID),
   351  			)
   352  			if sentMark.canRetry() {
   353  				return c.wrapError(xerrors.Retryable(err, xerrors.WithName("Invoke")))
   354  			}
   355  
   356  			return c.wrapError(err)
   357  		}
   358  
   359  		return err
   360  	}
   361  
   362  	if o, ok := res.(response.Response); ok {
   363  		opID = o.GetOperation().GetId()
   364  		for _, issue := range o.GetOperation().GetIssues() {
   365  			issues = append(issues, issue)
   366  		}
   367  		if useWrapping {
   368  			switch {
   369  			case !o.GetOperation().GetReady():
   370  				return c.wrapError(errOperationNotReady)
   371  
   372  			case o.GetOperation().GetStatus() != Ydb.StatusIds_SUCCESS:
   373  				return c.wrapError(
   374  					xerrors.Operation(
   375  						xerrors.FromOperation(o.GetOperation()),
   376  						xerrors.WithAddress(c.Address()),
   377  						xerrors.WithTraceID(traceID),
   378  					),
   379  				)
   380  			}
   381  		}
   382  	}
   383  
   384  	return err
   385  }
   386  
   387  func (c *conn) NewStream(
   388  	ctx context.Context,
   389  	desc *grpc.StreamDesc,
   390  	method string,
   391  	opts ...grpc.CallOption,
   392  ) (_ grpc.ClientStream, err error) {
   393  	var (
   394  		streamRecv = trace.DriverOnConnNewStream(
   395  			c.config.Trace(), &ctx,
   396  			stack.FunctionID(""),
   397  			c.endpoint.Copy(), trace.Method(method),
   398  		)
   399  		useWrapping = UseWrapping(ctx)
   400  		cc          *grpc.ClientConn
   401  		s           grpc.ClientStream
   402  	)
   403  
   404  	defer func() {
   405  		if err != nil {
   406  			streamRecv(err)(err, c.GetState(), metadata.MD{})
   407  		}
   408  	}()
   409  
   410  	var cancel context.CancelFunc
   411  	ctx, cancel = xcontext.WithCancel(ctx)
   412  
   413  	defer func() {
   414  		if err != nil {
   415  			cancel()
   416  		}
   417  	}()
   418  
   419  	cc, err = c.realConn(ctx)
   420  	if err != nil {
   421  		return nil, c.wrapError(err)
   422  	}
   423  
   424  	c.touchLastUsage()
   425  	defer c.touchLastUsage()
   426  
   427  	ctx, traceID, err := meta.TraceID(ctx)
   428  	if err != nil {
   429  		return nil, xerrors.WithStackTrace(err)
   430  	}
   431  
   432  	ctx, sentMark := markContext(meta.WithTraceID(ctx, traceID))
   433  
   434  	s, err = cc.NewStream(ctx, desc, method, opts...)
   435  	if err != nil {
   436  		defer func() {
   437  			c.onTransportError(ctx, err)
   438  		}()
   439  
   440  		if useWrapping {
   441  			err = xerrors.Transport(err,
   442  				xerrors.WithAddress(c.Address()),
   443  				xerrors.WithTraceID(traceID),
   444  			)
   445  			if sentMark.canRetry() {
   446  				return s, c.wrapError(xerrors.Retryable(err, xerrors.WithName("NewStream")))
   447  			}
   448  
   449  			return s, c.wrapError(err)
   450  		}
   451  
   452  		return s, err
   453  	}
   454  
   455  	return &grpcClientStream{
   456  		ClientStream: s,
   457  		c:            c,
   458  		wrapping:     useWrapping,
   459  		traceID:      traceID,
   460  		sentMark:     sentMark,
   461  		onDone: func(ctx context.Context, md metadata.MD) {
   462  			cancel()
   463  			meta.CallTrailerCallback(ctx, md)
   464  		},
   465  		recv: streamRecv,
   466  	}, nil
   467  }
   468  
   469  func (c *conn) wrapError(err error) error {
   470  	if err == nil {
   471  		return nil
   472  	}
   473  	nodeErr := newConnError(c.endpoint.NodeID(), c.endpoint.Address(), err)
   474  
   475  	return xerrors.WithStackTrace(nodeErr, xerrors.WithSkipDepth(1))
   476  }
   477  
   478  type option func(c *conn)
   479  
   480  func withOnClose(onClose func(*conn)) option {
   481  	return func(c *conn) {
   482  		if onClose != nil {
   483  			c.onClose = append(c.onClose, onClose)
   484  		}
   485  	}
   486  }
   487  
   488  func withOnTransportError(onTransportError func(ctx context.Context, cc Conn, cause error)) option {
   489  	return func(c *conn) {
   490  		if onTransportError != nil {
   491  			c.onTransportErrors = append(c.onTransportErrors, onTransportError)
   492  		}
   493  	}
   494  }
   495  
   496  func newConn(e endpoint.Endpoint, config Config, opts ...option) *conn {
   497  	c := &conn{
   498  		endpoint: e,
   499  		config:   config,
   500  		done:     make(chan struct{}),
   501  	}
   502  	c.state.Store(uint32(Created))
   503  	for _, o := range opts {
   504  		if o != nil {
   505  			o(c)
   506  		}
   507  	}
   508  
   509  	return c
   510  }
   511  
   512  func New(e endpoint.Endpoint, config Config, opts ...option) Conn {
   513  	return newConn(e, config, opts...)
   514  }
   515  
   516  var _ stats.Handler = statsHandler{}
   517  
   518  type statsHandler struct{}
   519  
   520  func (statsHandler) TagRPC(ctx context.Context, _ *stats.RPCTagInfo) context.Context {
   521  	return ctx
   522  }
   523  
   524  func (statsHandler) HandleRPC(ctx context.Context, rpcStats stats.RPCStats) {
   525  	switch rpcStats.(type) {
   526  	case *stats.Begin, *stats.End:
   527  	default:
   528  		getContextMark(ctx).markDirty()
   529  	}
   530  }
   531  
   532  func (statsHandler) TagConn(ctx context.Context, _ *stats.ConnTagInfo) context.Context {
   533  	return ctx
   534  }
   535  
   536  func (statsHandler) HandleConn(context.Context, stats.ConnStats) {}
   537  
   538  type ctxHandleRPCKey struct{}
   539  
   540  var rpcKey = ctxHandleRPCKey{}
   541  
   542  func markContext(ctx context.Context) (context.Context, *modificationMark) {
   543  	mark := &modificationMark{}
   544  
   545  	return context.WithValue(ctx, rpcKey, mark), mark
   546  }
   547  
   548  func getContextMark(ctx context.Context) *modificationMark {
   549  	v := ctx.Value(rpcKey)
   550  	if v == nil {
   551  		return &modificationMark{}
   552  	}
   553  
   554  	return v.(*modificationMark)
   555  }
   556  
   557  type modificationMark struct {
   558  	dirty atomic.Bool
   559  }
   560  
   561  func (m *modificationMark) canRetry() bool {
   562  	return !m.dirty.Load()
   563  }
   564  
   565  func (m *modificationMark) markDirty() {
   566  	m.dirty.Store(true)
   567  }