github.com/anycable/anycable-go@v1.5.1/rpc/rpc.go (about)

     1  package rpc
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"log/slog"
     8  	"math"
     9  	"sync"
    10  	"sync/atomic"
    11  	"time"
    12  
    13  	"github.com/anycable/anycable-go/common"
    14  	"github.com/anycable/anycable-go/metrics"
    15  	"github.com/anycable/anycable-go/protocol"
    16  	"github.com/anycable/anycable-go/utils"
    17  	"github.com/joomcode/errorx"
    18  
    19  	pb "github.com/anycable/anycable-go/protos"
    20  	"google.golang.org/grpc"
    21  	"google.golang.org/grpc/codes"
    22  	"google.golang.org/grpc/connectivity"
    23  	"google.golang.org/grpc/credentials"
    24  	"google.golang.org/grpc/credentials/insecure"
    25  	"google.golang.org/grpc/keepalive"
    26  	"google.golang.org/grpc/metadata"
    27  	"google.golang.org/grpc/peer"
    28  	"google.golang.org/grpc/stats"
    29  	"google.golang.org/grpc/status"
    30  )
    31  
    32  const (
    33  	// ProtoVersions contains a comma-seprated list of compatible RPC protos versions
    34  	// (we pass it as request meta to notify clients)
    35  	ProtoVersions = "v1"
    36  	invokeTimeout = 3000
    37  
    38  	retryExhaustedInterval   = 10
    39  	retryUnavailableInterval = 100
    40  
    41  	refreshMetricsInterval = time.Duration(10) * time.Second
    42  
    43  	metricsRPCCalls        = "rpc_call_total"
    44  	metricsRPCRetries      = "rpc_retries_total"
    45  	metricsRPCFailures     = "rpc_error_total"
    46  	metricsRPCPending      = "rpc_pending_num"
    47  	metricsRPCCapacity     = "rpc_capacity_num"
    48  	metricsGRPCActiveConns = "grpc_active_conn_num"
    49  
    50  	secretKeyPhrase = "rpc-cable"
    51  )
    52  
    53  type grpcClientHelper struct {
    54  	conn       *grpc.ClientConn
    55  	recovering bool
    56  	mu         sync.Mutex
    57  
    58  	log    *slog.Logger
    59  	active int64
    60  }
    61  
    62  // Returns nil if connection in the READY/IDLE/CONNECTING state.
    63  // If connection is in the TransientFailure state, we try to re-connect immediately
    64  // once.
    65  // See https://github.com/grpc/grpc/blob/master/doc/connectivity-semantics-and-api.md
    66  // and https://github.com/grpc/grpc/blob/master/doc/connection-backoff.md
    67  // See also https://github.com/cockroachdb/cockroach/blob/master/pkg/util/grpcutil/grpc_util.go
    68  func (st *grpcClientHelper) Ready() error {
    69  	s := st.conn.GetState()
    70  
    71  	if s == connectivity.Shutdown {
    72  		return errors.New("grpc connection is closed")
    73  	}
    74  
    75  	if s == connectivity.TransientFailure {
    76  		return st.tryRecover()
    77  	}
    78  
    79  	if st.recovering {
    80  		st.reset()
    81  	}
    82  
    83  	return nil
    84  }
    85  
    86  func (st *grpcClientHelper) Close() {
    87  	st.conn.Close()
    88  }
    89  
    90  func (st *grpcClientHelper) ActiveConns() int {
    91  	return int(atomic.LoadInt64(&st.active))
    92  }
    93  
    94  func (st *grpcClientHelper) SupportsActiveConns() bool {
    95  	return true
    96  }
    97  
    98  func (st *grpcClientHelper) HandleConn(ctx context.Context, stat stats.ConnStats) {
    99  	var addr string
   100  
   101  	if p, ok := peer.FromContext(ctx); ok {
   102  		addr = p.Addr.String()
   103  	}
   104  
   105  	if _, ok := stat.(*stats.ConnBegin); ok {
   106  		st.log.Debug("connected", "addr", addr)
   107  		atomic.AddInt64(&st.active, 1)
   108  	}
   109  
   110  	if _, ok := stat.(*stats.ConnEnd); ok {
   111  		st.log.Debug("disconnected", "addr", addr)
   112  		atomic.AddInt64(&st.active, -1)
   113  	}
   114  }
   115  
   116  func (st *grpcClientHelper) HandleRPC(ctx context.Context, stat stats.RPCStats) {
   117  	// no-op
   118  }
   119  
   120  func (st *grpcClientHelper) TagConn(ctx context.Context, stat *stats.ConnTagInfo) context.Context {
   121  	return ctx
   122  }
   123  
   124  func (st *grpcClientHelper) TagRPC(ctx context.Context, stat *stats.RPCTagInfo) context.Context {
   125  	return ctx
   126  }
   127  
   128  func (st *grpcClientHelper) tryRecover() error {
   129  	st.mu.Lock()
   130  	defer st.mu.Unlock()
   131  
   132  	if st.recovering {
   133  		return errors.New("grpc connection is not ready")
   134  	}
   135  
   136  	st.recovering = true
   137  	st.conn.ResetConnectBackoff()
   138  
   139  	st.log.Warn("connection is lost, trying to reconnect immediately")
   140  
   141  	return nil
   142  }
   143  
   144  func (st *grpcClientHelper) reset() {
   145  	st.mu.Lock()
   146  	defer st.mu.Unlock()
   147  
   148  	if st.recovering {
   149  		st.recovering = false
   150  		st.log.Info("connection is restored")
   151  	}
   152  }
   153  
   154  // Controller implements node.Controller interface for gRPC
   155  type Controller struct {
   156  	config      *Config
   157  	barrier     Barrier
   158  	client      pb.RPCClient
   159  	metrics     metrics.Instrumenter
   160  	log         *slog.Logger
   161  	clientState ClientHelper
   162  
   163  	timerMu      sync.Mutex
   164  	metricsTimer *time.Timer
   165  }
   166  
   167  // NewController builds new Controller
   168  func NewController(metrics metrics.Instrumenter, config *Config, l *slog.Logger) (*Controller, error) {
   169  	metrics.RegisterCounter(metricsRPCCalls, "The total number of RPC calls")
   170  	metrics.RegisterCounter(metricsRPCRetries, "The total number of RPC call retries")
   171  	metrics.RegisterCounter(metricsRPCFailures, "The total number of failed RPC calls")
   172  	metrics.RegisterGauge(metricsRPCPending, "The number of pending RPC calls")
   173  
   174  	capacity := config.Concurrency
   175  	if capacity <= 0 {
   176  		capacity = defaultRPCConcurrency
   177  		l.Warn("RPC concurrency must be positive, reverted to the default value")
   178  	}
   179  	barrier, err := NewFixedSizeBarrier(capacity)
   180  
   181  	if err != nil {
   182  		return nil, err
   183  	}
   184  
   185  	if barrier.HasDynamicCapacity() {
   186  		metrics.RegisterGauge(metricsRPCCapacity, "The max number of concurrent RPC calls allowed")
   187  		metrics.GaugeSet(metricsRPCCapacity, uint64(barrier.Capacity()))
   188  	}
   189  
   190  	if config.Impl() == "grpc" {
   191  		metrics.RegisterGauge(metricsGRPCActiveConns, "The number of active HTTP connections used by gRPC")
   192  	}
   193  
   194  	return &Controller{log: l.With("context", "rpc"), metrics: metrics, config: config, barrier: barrier}, nil
   195  }
   196  
   197  // Start initializes RPC connection pool
   198  func (c *Controller) Start() error {
   199  	host := c.config.Host
   200  	enableTLS := c.config.TLSEnabled()
   201  	impl := c.config.Impl()
   202  
   203  	dialer := c.config.DialFun
   204  
   205  	if dialer == nil {
   206  		switch impl {
   207  		case "http":
   208  			var err error
   209  
   210  			if c.config.Secret == "" && c.config.SecretBase != "" {
   211  				secret, verr := utils.NewMessageVerifier(c.config.SecretBase).Sign([]byte(secretKeyPhrase))
   212  
   213  				if verr != nil {
   214  					verr = errorx.Decorate(verr, "failed to auto-generate authentication key for HTTP RPC")
   215  					return verr
   216  				}
   217  
   218  				c.log.Info("auto-generated authorization secret from the application secret")
   219  				c.config.Secret = string(secret)
   220  			}
   221  
   222  			dialer, err = NewHTTPDialer(c.config)
   223  			if err != nil {
   224  				return err
   225  			}
   226  		case "grpc":
   227  			dialer = defaultDialer
   228  		default:
   229  			return fmt.Errorf("unknown RPC implementation: %s", impl)
   230  		}
   231  	}
   232  
   233  	client, state, err := dialer(c.config, c.log)
   234  
   235  	if err == nil {
   236  		c.log.Info(fmt.Sprintf("RPC controller initialized: %s (concurrency: %s, impl: %s, enable_tls: %t, proto_versions: %s)", host, c.barrier.CapacityInfo(), impl, enableTLS, ProtoVersions))
   237  	} else {
   238  		return err
   239  	}
   240  
   241  	c.client = client
   242  	c.clientState = state
   243  
   244  	if c.barrier.HasDynamicCapacity() || state.SupportsActiveConns() {
   245  		c.metricsTimer = time.AfterFunc(refreshMetricsInterval, c.refreshMetrics)
   246  	}
   247  
   248  	c.barrier.Start()
   249  
   250  	return nil
   251  }
   252  
   253  // Shutdown closes connections
   254  func (c *Controller) Shutdown() error {
   255  	if c.clientState == nil {
   256  		return nil
   257  	}
   258  
   259  	c.timerMu.Lock()
   260  	if c.metricsTimer != nil {
   261  		c.metricsTimer.Stop()
   262  	}
   263  	c.timerMu.Unlock()
   264  
   265  	defer c.clientState.Close()
   266  
   267  	busy := c.busy()
   268  
   269  	if busy > 0 {
   270  		c.log.Info("waiting for active RPC calls to finish", "num", busy)
   271  	}
   272  
   273  	// Wait for active connections
   274  	_, err := c.retry("", func() (interface{}, error) {
   275  		busy := c.busy()
   276  
   277  		if busy > 0 {
   278  			return false, fmt.Errorf("terminated while completing active RPC calls: %d", busy)
   279  		}
   280  
   281  		c.log.Info("all active RPC calls finished")
   282  		return true, nil
   283  	})
   284  
   285  	c.barrier.Stop()
   286  
   287  	return err
   288  }
   289  
   290  // Authenticate performs Connect RPC call
   291  func (c *Controller) Authenticate(sid string, env *common.SessionEnv) (*common.ConnectResult, error) {
   292  	c.metrics.GaugeIncrement(metricsRPCPending)
   293  	c.barrier.Acquire()
   294  	c.metrics.GaugeDecrement(metricsRPCPending)
   295  
   296  	defer c.barrier.Release()
   297  
   298  	op := func() (interface{}, error) {
   299  		return c.client.Connect(
   300  			newContext(sid),
   301  			protocol.NewConnectMessage(env),
   302  		)
   303  	}
   304  
   305  	c.metrics.CounterIncrement(metricsRPCCalls)
   306  
   307  	response, err := c.retry(sid, op)
   308  
   309  	if err != nil {
   310  		c.metrics.CounterIncrement(metricsRPCFailures)
   311  
   312  		return nil, err
   313  	}
   314  
   315  	if r, ok := response.(*pb.ConnectionResponse); ok {
   316  		reply, err := protocol.ParseConnectResponse(r)
   317  
   318  		return reply, err
   319  	}
   320  
   321  	c.metrics.CounterIncrement(metricsRPCFailures)
   322  
   323  	return nil, errors.New("failed to deserialize connection response")
   324  }
   325  
   326  // Subscribe performs Command RPC call with "subscribe" command
   327  func (c *Controller) Subscribe(sid string, env *common.SessionEnv, id string, channel string) (*common.CommandResult, error) {
   328  	c.metrics.GaugeIncrement(metricsRPCPending)
   329  	c.barrier.Acquire()
   330  	c.metrics.GaugeDecrement(metricsRPCPending)
   331  
   332  	defer c.barrier.Release()
   333  
   334  	op := func() (interface{}, error) {
   335  		return c.client.Command(
   336  			newContext(sid),
   337  			protocol.NewCommandMessage(env, "subscribe", channel, id, ""),
   338  		)
   339  	}
   340  
   341  	response, err := c.retry(sid, op)
   342  
   343  	return c.parseCommandResponse(sid, response, err)
   344  }
   345  
   346  // Unsubscribe performs Command RPC call with "unsubscribe" command
   347  func (c *Controller) Unsubscribe(sid string, env *common.SessionEnv, id string, channel string) (*common.CommandResult, error) {
   348  	c.metrics.GaugeIncrement(metricsRPCPending)
   349  	c.barrier.Acquire()
   350  	c.metrics.GaugeDecrement(metricsRPCPending)
   351  
   352  	defer c.barrier.Release()
   353  
   354  	op := func() (interface{}, error) {
   355  		return c.client.Command(
   356  			newContext(sid),
   357  			protocol.NewCommandMessage(env, "unsubscribe", channel, id, ""),
   358  		)
   359  	}
   360  
   361  	response, err := c.retry(sid, op)
   362  
   363  	return c.parseCommandResponse(sid, response, err)
   364  }
   365  
   366  // Perform performs Command RPC call with "perform" command
   367  func (c *Controller) Perform(sid string, env *common.SessionEnv, id string, channel string, data string) (*common.CommandResult, error) {
   368  	c.metrics.GaugeIncrement(metricsRPCPending)
   369  	c.barrier.Acquire()
   370  	c.metrics.GaugeDecrement(metricsRPCPending)
   371  
   372  	defer c.barrier.Release()
   373  
   374  	op := func() (interface{}, error) {
   375  		return c.client.Command(
   376  			newContext(sid),
   377  			protocol.NewCommandMessage(env, "message", channel, id, data),
   378  		)
   379  	}
   380  
   381  	response, err := c.retry(sid, op)
   382  
   383  	return c.parseCommandResponse(sid, response, err)
   384  }
   385  
   386  // Disconnect performs disconnect RPC call
   387  func (c *Controller) Disconnect(sid string, env *common.SessionEnv, id string, subscriptions []string) error {
   388  	c.metrics.GaugeIncrement(metricsRPCPending)
   389  	c.barrier.Acquire()
   390  	c.metrics.GaugeDecrement(metricsRPCPending)
   391  
   392  	defer c.barrier.Release()
   393  
   394  	op := func() (interface{}, error) {
   395  		return c.client.Disconnect(
   396  			newContext(sid),
   397  			protocol.NewDisconnectMessage(env, id, subscriptions),
   398  		)
   399  	}
   400  
   401  	c.metrics.CounterIncrement(metricsRPCCalls)
   402  
   403  	response, err := c.retry(sid, op)
   404  
   405  	if err != nil {
   406  		c.metrics.CounterIncrement(metricsRPCFailures)
   407  		return err
   408  	}
   409  
   410  	if r, ok := response.(*pb.DisconnectResponse); ok {
   411  		err = protocol.ParseDisconnectResponse(r)
   412  
   413  		if err != nil {
   414  			c.metrics.CounterIncrement(metricsRPCFailures)
   415  		}
   416  
   417  		return err
   418  	}
   419  
   420  	return errors.New("failed to deserialize disconnect response")
   421  }
   422  
   423  func (c *Controller) parseCommandResponse(sid string, response interface{}, err error) (*common.CommandResult, error) {
   424  	c.metrics.CounterIncrement(metricsRPCCalls)
   425  
   426  	if err != nil {
   427  		c.metrics.CounterIncrement(metricsRPCFailures)
   428  
   429  		return nil, err
   430  	}
   431  
   432  	if r, ok := response.(*pb.CommandResponse); ok {
   433  		res, err := protocol.ParseCommandResponse(r)
   434  
   435  		return res, err
   436  	}
   437  
   438  	c.metrics.CounterIncrement(metricsRPCFailures)
   439  
   440  	return nil, errors.New("failed to deserialize command response")
   441  }
   442  
   443  func (c *Controller) busy() int {
   444  	return c.barrier.BusyCount()
   445  }
   446  
   447  func (c *Controller) retry(sid string, callback func() (interface{}, error)) (res interface{}, err error) {
   448  	retryAge := 0
   449  	attempt := 0
   450  	wasExhausted := false
   451  
   452  	for {
   453  		if stErr := c.clientState.Ready(); stErr != nil {
   454  			return nil, stErr
   455  		}
   456  
   457  		res, err = callback()
   458  
   459  		if err == nil {
   460  			return res, nil
   461  		}
   462  
   463  		if retryAge > invokeTimeout {
   464  			return nil, err
   465  		}
   466  
   467  		st, ok := status.FromError(err)
   468  		if !ok {
   469  			return nil, err
   470  		}
   471  
   472  		code := st.Code()
   473  
   474  		if !(code == codes.ResourceExhausted || code == codes.Unavailable) {
   475  			return nil, err
   476  		}
   477  
   478  		c.log.With("sid", sid).Debug("RPC failed", "code", st.Code(), "error", st.Message())
   479  
   480  		interval := retryUnavailableInterval
   481  
   482  		if st.Code() == codes.ResourceExhausted {
   483  			interval = retryExhaustedInterval
   484  			if !wasExhausted {
   485  				attempt = 0
   486  				wasExhausted = true
   487  			}
   488  			c.barrier.Exhausted()
   489  		} else if wasExhausted {
   490  			wasExhausted = false
   491  			attempt = 0
   492  		}
   493  
   494  		delayMS := int(math.Pow(2, float64(attempt))) * interval
   495  		delay := time.Duration(delayMS)
   496  
   497  		retryAge += delayMS
   498  
   499  		c.metrics.CounterIncrement(metricsRPCRetries)
   500  
   501  		time.Sleep(delay * time.Millisecond)
   502  
   503  		attempt++
   504  	}
   505  }
   506  
   507  func newContext(sessionID string) context.Context {
   508  	md := metadata.Pairs("sid", sessionID, "protov", ProtoVersions)
   509  	return metadata.NewOutgoingContext(context.Background(), md)
   510  }
   511  
   512  func defaultDialer(conf *Config, l *slog.Logger) (pb.RPCClient, ClientHelper, error) {
   513  	host := conf.Host
   514  	enableTLS := conf.TLSEnabled()
   515  
   516  	kacp := keepalive.ClientParameters{
   517  		Time:                10 * time.Second, // send pings every 10 seconds if there is no activity
   518  		PermitWithoutStream: true,             // send pings even without active streams
   519  	}
   520  
   521  	const grpcServiceConfig = `{"loadBalancingPolicy":"round_robin"}`
   522  
   523  	state := &grpcClientHelper{log: l.With("impl", "grpc")}
   524  
   525  	dialOptions := []grpc.DialOption{
   526  		grpc.WithKeepaliveParams(kacp),
   527  		grpc.WithDefaultServiceConfig(grpcServiceConfig),
   528  		grpc.WithStatsHandler(state),
   529  	}
   530  
   531  	if enableTLS {
   532  		tlsConfig, error := conf.TLSConfig()
   533  		if error != nil {
   534  			return nil, nil, error
   535  		}
   536  
   537  		dialOptions = append(dialOptions, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)))
   538  	} else {
   539  		dialOptions = append(dialOptions, grpc.WithTransportCredentials(insecure.NewCredentials()))
   540  	}
   541  
   542  	var callOptions = []grpc.CallOption{}
   543  
   544  	// Zero is the default
   545  	if conf.MaxRecvSize != 0 {
   546  		callOptions = append(callOptions, grpc.MaxCallRecvMsgSize(conf.MaxRecvSize))
   547  	}
   548  
   549  	if conf.MaxSendSize != 0 {
   550  		callOptions = append(callOptions, grpc.MaxCallSendMsgSize(conf.MaxSendSize))
   551  	}
   552  
   553  	if len(callOptions) > 0 {
   554  		dialOptions = append(dialOptions, grpc.WithDefaultCallOptions(callOptions...))
   555  	}
   556  
   557  	conn, err := grpc.Dial(
   558  		host,
   559  		dialOptions...,
   560  	)
   561  
   562  	if err != nil {
   563  		return nil, nil, err
   564  	}
   565  
   566  	client := pb.NewRPCClient(conn)
   567  	state.conn = conn
   568  
   569  	return client, state, nil
   570  }
   571  
   572  func (c *Controller) refreshMetrics() {
   573  	if c.clientState.SupportsActiveConns() {
   574  		c.metrics.GaugeSet(metricsGRPCActiveConns, uint64(c.clientState.ActiveConns()))
   575  	}
   576  
   577  	if c.barrier.HasDynamicCapacity() {
   578  		c.metrics.GaugeSet(metricsRPCCapacity, uint64(c.barrier.Capacity()))
   579  	}
   580  
   581  	c.timerMu.Lock()
   582  	defer c.timerMu.Unlock()
   583  
   584  	c.metricsTimer = time.AfterFunc(refreshMetricsInterval, c.refreshMetrics)
   585  }