github.com/anycable/anycable-go@v1.5.1/rpc/rpc.go (about) 1 package rpc 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "log/slog" 8 "math" 9 "sync" 10 "sync/atomic" 11 "time" 12 13 "github.com/anycable/anycable-go/common" 14 "github.com/anycable/anycable-go/metrics" 15 "github.com/anycable/anycable-go/protocol" 16 "github.com/anycable/anycable-go/utils" 17 "github.com/joomcode/errorx" 18 19 pb "github.com/anycable/anycable-go/protos" 20 "google.golang.org/grpc" 21 "google.golang.org/grpc/codes" 22 "google.golang.org/grpc/connectivity" 23 "google.golang.org/grpc/credentials" 24 "google.golang.org/grpc/credentials/insecure" 25 "google.golang.org/grpc/keepalive" 26 "google.golang.org/grpc/metadata" 27 "google.golang.org/grpc/peer" 28 "google.golang.org/grpc/stats" 29 "google.golang.org/grpc/status" 30 ) 31 32 const ( 33 // ProtoVersions contains a comma-seprated list of compatible RPC protos versions 34 // (we pass it as request meta to notify clients) 35 ProtoVersions = "v1" 36 invokeTimeout = 3000 37 38 retryExhaustedInterval = 10 39 retryUnavailableInterval = 100 40 41 refreshMetricsInterval = time.Duration(10) * time.Second 42 43 metricsRPCCalls = "rpc_call_total" 44 metricsRPCRetries = "rpc_retries_total" 45 metricsRPCFailures = "rpc_error_total" 46 metricsRPCPending = "rpc_pending_num" 47 metricsRPCCapacity = "rpc_capacity_num" 48 metricsGRPCActiveConns = "grpc_active_conn_num" 49 50 secretKeyPhrase = "rpc-cable" 51 ) 52 53 type grpcClientHelper struct { 54 conn *grpc.ClientConn 55 recovering bool 56 mu sync.Mutex 57 58 log *slog.Logger 59 active int64 60 } 61 62 // Returns nil if connection in the READY/IDLE/CONNECTING state. 63 // If connection is in the TransientFailure state, we try to re-connect immediately 64 // once. 65 // See https://github.com/grpc/grpc/blob/master/doc/connectivity-semantics-and-api.md 66 // and https://github.com/grpc/grpc/blob/master/doc/connection-backoff.md 67 // See also https://github.com/cockroachdb/cockroach/blob/master/pkg/util/grpcutil/grpc_util.go 68 func (st *grpcClientHelper) Ready() error { 69 s := st.conn.GetState() 70 71 if s == connectivity.Shutdown { 72 return errors.New("grpc connection is closed") 73 } 74 75 if s == connectivity.TransientFailure { 76 return st.tryRecover() 77 } 78 79 if st.recovering { 80 st.reset() 81 } 82 83 return nil 84 } 85 86 func (st *grpcClientHelper) Close() { 87 st.conn.Close() 88 } 89 90 func (st *grpcClientHelper) ActiveConns() int { 91 return int(atomic.LoadInt64(&st.active)) 92 } 93 94 func (st *grpcClientHelper) SupportsActiveConns() bool { 95 return true 96 } 97 98 func (st *grpcClientHelper) HandleConn(ctx context.Context, stat stats.ConnStats) { 99 var addr string 100 101 if p, ok := peer.FromContext(ctx); ok { 102 addr = p.Addr.String() 103 } 104 105 if _, ok := stat.(*stats.ConnBegin); ok { 106 st.log.Debug("connected", "addr", addr) 107 atomic.AddInt64(&st.active, 1) 108 } 109 110 if _, ok := stat.(*stats.ConnEnd); ok { 111 st.log.Debug("disconnected", "addr", addr) 112 atomic.AddInt64(&st.active, -1) 113 } 114 } 115 116 func (st *grpcClientHelper) HandleRPC(ctx context.Context, stat stats.RPCStats) { 117 // no-op 118 } 119 120 func (st *grpcClientHelper) TagConn(ctx context.Context, stat *stats.ConnTagInfo) context.Context { 121 return ctx 122 } 123 124 func (st *grpcClientHelper) TagRPC(ctx context.Context, stat *stats.RPCTagInfo) context.Context { 125 return ctx 126 } 127 128 func (st *grpcClientHelper) tryRecover() error { 129 st.mu.Lock() 130 defer st.mu.Unlock() 131 132 if st.recovering { 133 return errors.New("grpc connection is not ready") 134 } 135 136 st.recovering = true 137 st.conn.ResetConnectBackoff() 138 139 st.log.Warn("connection is lost, trying to reconnect immediately") 140 141 return nil 142 } 143 144 func (st *grpcClientHelper) reset() { 145 st.mu.Lock() 146 defer st.mu.Unlock() 147 148 if st.recovering { 149 st.recovering = false 150 st.log.Info("connection is restored") 151 } 152 } 153 154 // Controller implements node.Controller interface for gRPC 155 type Controller struct { 156 config *Config 157 barrier Barrier 158 client pb.RPCClient 159 metrics metrics.Instrumenter 160 log *slog.Logger 161 clientState ClientHelper 162 163 timerMu sync.Mutex 164 metricsTimer *time.Timer 165 } 166 167 // NewController builds new Controller 168 func NewController(metrics metrics.Instrumenter, config *Config, l *slog.Logger) (*Controller, error) { 169 metrics.RegisterCounter(metricsRPCCalls, "The total number of RPC calls") 170 metrics.RegisterCounter(metricsRPCRetries, "The total number of RPC call retries") 171 metrics.RegisterCounter(metricsRPCFailures, "The total number of failed RPC calls") 172 metrics.RegisterGauge(metricsRPCPending, "The number of pending RPC calls") 173 174 capacity := config.Concurrency 175 if capacity <= 0 { 176 capacity = defaultRPCConcurrency 177 l.Warn("RPC concurrency must be positive, reverted to the default value") 178 } 179 barrier, err := NewFixedSizeBarrier(capacity) 180 181 if err != nil { 182 return nil, err 183 } 184 185 if barrier.HasDynamicCapacity() { 186 metrics.RegisterGauge(metricsRPCCapacity, "The max number of concurrent RPC calls allowed") 187 metrics.GaugeSet(metricsRPCCapacity, uint64(barrier.Capacity())) 188 } 189 190 if config.Impl() == "grpc" { 191 metrics.RegisterGauge(metricsGRPCActiveConns, "The number of active HTTP connections used by gRPC") 192 } 193 194 return &Controller{log: l.With("context", "rpc"), metrics: metrics, config: config, barrier: barrier}, nil 195 } 196 197 // Start initializes RPC connection pool 198 func (c *Controller) Start() error { 199 host := c.config.Host 200 enableTLS := c.config.TLSEnabled() 201 impl := c.config.Impl() 202 203 dialer := c.config.DialFun 204 205 if dialer == nil { 206 switch impl { 207 case "http": 208 var err error 209 210 if c.config.Secret == "" && c.config.SecretBase != "" { 211 secret, verr := utils.NewMessageVerifier(c.config.SecretBase).Sign([]byte(secretKeyPhrase)) 212 213 if verr != nil { 214 verr = errorx.Decorate(verr, "failed to auto-generate authentication key for HTTP RPC") 215 return verr 216 } 217 218 c.log.Info("auto-generated authorization secret from the application secret") 219 c.config.Secret = string(secret) 220 } 221 222 dialer, err = NewHTTPDialer(c.config) 223 if err != nil { 224 return err 225 } 226 case "grpc": 227 dialer = defaultDialer 228 default: 229 return fmt.Errorf("unknown RPC implementation: %s", impl) 230 } 231 } 232 233 client, state, err := dialer(c.config, c.log) 234 235 if err == nil { 236 c.log.Info(fmt.Sprintf("RPC controller initialized: %s (concurrency: %s, impl: %s, enable_tls: %t, proto_versions: %s)", host, c.barrier.CapacityInfo(), impl, enableTLS, ProtoVersions)) 237 } else { 238 return err 239 } 240 241 c.client = client 242 c.clientState = state 243 244 if c.barrier.HasDynamicCapacity() || state.SupportsActiveConns() { 245 c.metricsTimer = time.AfterFunc(refreshMetricsInterval, c.refreshMetrics) 246 } 247 248 c.barrier.Start() 249 250 return nil 251 } 252 253 // Shutdown closes connections 254 func (c *Controller) Shutdown() error { 255 if c.clientState == nil { 256 return nil 257 } 258 259 c.timerMu.Lock() 260 if c.metricsTimer != nil { 261 c.metricsTimer.Stop() 262 } 263 c.timerMu.Unlock() 264 265 defer c.clientState.Close() 266 267 busy := c.busy() 268 269 if busy > 0 { 270 c.log.Info("waiting for active RPC calls to finish", "num", busy) 271 } 272 273 // Wait for active connections 274 _, err := c.retry("", func() (interface{}, error) { 275 busy := c.busy() 276 277 if busy > 0 { 278 return false, fmt.Errorf("terminated while completing active RPC calls: %d", busy) 279 } 280 281 c.log.Info("all active RPC calls finished") 282 return true, nil 283 }) 284 285 c.barrier.Stop() 286 287 return err 288 } 289 290 // Authenticate performs Connect RPC call 291 func (c *Controller) Authenticate(sid string, env *common.SessionEnv) (*common.ConnectResult, error) { 292 c.metrics.GaugeIncrement(metricsRPCPending) 293 c.barrier.Acquire() 294 c.metrics.GaugeDecrement(metricsRPCPending) 295 296 defer c.barrier.Release() 297 298 op := func() (interface{}, error) { 299 return c.client.Connect( 300 newContext(sid), 301 protocol.NewConnectMessage(env), 302 ) 303 } 304 305 c.metrics.CounterIncrement(metricsRPCCalls) 306 307 response, err := c.retry(sid, op) 308 309 if err != nil { 310 c.metrics.CounterIncrement(metricsRPCFailures) 311 312 return nil, err 313 } 314 315 if r, ok := response.(*pb.ConnectionResponse); ok { 316 reply, err := protocol.ParseConnectResponse(r) 317 318 return reply, err 319 } 320 321 c.metrics.CounterIncrement(metricsRPCFailures) 322 323 return nil, errors.New("failed to deserialize connection response") 324 } 325 326 // Subscribe performs Command RPC call with "subscribe" command 327 func (c *Controller) Subscribe(sid string, env *common.SessionEnv, id string, channel string) (*common.CommandResult, error) { 328 c.metrics.GaugeIncrement(metricsRPCPending) 329 c.barrier.Acquire() 330 c.metrics.GaugeDecrement(metricsRPCPending) 331 332 defer c.barrier.Release() 333 334 op := func() (interface{}, error) { 335 return c.client.Command( 336 newContext(sid), 337 protocol.NewCommandMessage(env, "subscribe", channel, id, ""), 338 ) 339 } 340 341 response, err := c.retry(sid, op) 342 343 return c.parseCommandResponse(sid, response, err) 344 } 345 346 // Unsubscribe performs Command RPC call with "unsubscribe" command 347 func (c *Controller) Unsubscribe(sid string, env *common.SessionEnv, id string, channel string) (*common.CommandResult, error) { 348 c.metrics.GaugeIncrement(metricsRPCPending) 349 c.barrier.Acquire() 350 c.metrics.GaugeDecrement(metricsRPCPending) 351 352 defer c.barrier.Release() 353 354 op := func() (interface{}, error) { 355 return c.client.Command( 356 newContext(sid), 357 protocol.NewCommandMessage(env, "unsubscribe", channel, id, ""), 358 ) 359 } 360 361 response, err := c.retry(sid, op) 362 363 return c.parseCommandResponse(sid, response, err) 364 } 365 366 // Perform performs Command RPC call with "perform" command 367 func (c *Controller) Perform(sid string, env *common.SessionEnv, id string, channel string, data string) (*common.CommandResult, error) { 368 c.metrics.GaugeIncrement(metricsRPCPending) 369 c.barrier.Acquire() 370 c.metrics.GaugeDecrement(metricsRPCPending) 371 372 defer c.barrier.Release() 373 374 op := func() (interface{}, error) { 375 return c.client.Command( 376 newContext(sid), 377 protocol.NewCommandMessage(env, "message", channel, id, data), 378 ) 379 } 380 381 response, err := c.retry(sid, op) 382 383 return c.parseCommandResponse(sid, response, err) 384 } 385 386 // Disconnect performs disconnect RPC call 387 func (c *Controller) Disconnect(sid string, env *common.SessionEnv, id string, subscriptions []string) error { 388 c.metrics.GaugeIncrement(metricsRPCPending) 389 c.barrier.Acquire() 390 c.metrics.GaugeDecrement(metricsRPCPending) 391 392 defer c.barrier.Release() 393 394 op := func() (interface{}, error) { 395 return c.client.Disconnect( 396 newContext(sid), 397 protocol.NewDisconnectMessage(env, id, subscriptions), 398 ) 399 } 400 401 c.metrics.CounterIncrement(metricsRPCCalls) 402 403 response, err := c.retry(sid, op) 404 405 if err != nil { 406 c.metrics.CounterIncrement(metricsRPCFailures) 407 return err 408 } 409 410 if r, ok := response.(*pb.DisconnectResponse); ok { 411 err = protocol.ParseDisconnectResponse(r) 412 413 if err != nil { 414 c.metrics.CounterIncrement(metricsRPCFailures) 415 } 416 417 return err 418 } 419 420 return errors.New("failed to deserialize disconnect response") 421 } 422 423 func (c *Controller) parseCommandResponse(sid string, response interface{}, err error) (*common.CommandResult, error) { 424 c.metrics.CounterIncrement(metricsRPCCalls) 425 426 if err != nil { 427 c.metrics.CounterIncrement(metricsRPCFailures) 428 429 return nil, err 430 } 431 432 if r, ok := response.(*pb.CommandResponse); ok { 433 res, err := protocol.ParseCommandResponse(r) 434 435 return res, err 436 } 437 438 c.metrics.CounterIncrement(metricsRPCFailures) 439 440 return nil, errors.New("failed to deserialize command response") 441 } 442 443 func (c *Controller) busy() int { 444 return c.barrier.BusyCount() 445 } 446 447 func (c *Controller) retry(sid string, callback func() (interface{}, error)) (res interface{}, err error) { 448 retryAge := 0 449 attempt := 0 450 wasExhausted := false 451 452 for { 453 if stErr := c.clientState.Ready(); stErr != nil { 454 return nil, stErr 455 } 456 457 res, err = callback() 458 459 if err == nil { 460 return res, nil 461 } 462 463 if retryAge > invokeTimeout { 464 return nil, err 465 } 466 467 st, ok := status.FromError(err) 468 if !ok { 469 return nil, err 470 } 471 472 code := st.Code() 473 474 if !(code == codes.ResourceExhausted || code == codes.Unavailable) { 475 return nil, err 476 } 477 478 c.log.With("sid", sid).Debug("RPC failed", "code", st.Code(), "error", st.Message()) 479 480 interval := retryUnavailableInterval 481 482 if st.Code() == codes.ResourceExhausted { 483 interval = retryExhaustedInterval 484 if !wasExhausted { 485 attempt = 0 486 wasExhausted = true 487 } 488 c.barrier.Exhausted() 489 } else if wasExhausted { 490 wasExhausted = false 491 attempt = 0 492 } 493 494 delayMS := int(math.Pow(2, float64(attempt))) * interval 495 delay := time.Duration(delayMS) 496 497 retryAge += delayMS 498 499 c.metrics.CounterIncrement(metricsRPCRetries) 500 501 time.Sleep(delay * time.Millisecond) 502 503 attempt++ 504 } 505 } 506 507 func newContext(sessionID string) context.Context { 508 md := metadata.Pairs("sid", sessionID, "protov", ProtoVersions) 509 return metadata.NewOutgoingContext(context.Background(), md) 510 } 511 512 func defaultDialer(conf *Config, l *slog.Logger) (pb.RPCClient, ClientHelper, error) { 513 host := conf.Host 514 enableTLS := conf.TLSEnabled() 515 516 kacp := keepalive.ClientParameters{ 517 Time: 10 * time.Second, // send pings every 10 seconds if there is no activity 518 PermitWithoutStream: true, // send pings even without active streams 519 } 520 521 const grpcServiceConfig = `{"loadBalancingPolicy":"round_robin"}` 522 523 state := &grpcClientHelper{log: l.With("impl", "grpc")} 524 525 dialOptions := []grpc.DialOption{ 526 grpc.WithKeepaliveParams(kacp), 527 grpc.WithDefaultServiceConfig(grpcServiceConfig), 528 grpc.WithStatsHandler(state), 529 } 530 531 if enableTLS { 532 tlsConfig, error := conf.TLSConfig() 533 if error != nil { 534 return nil, nil, error 535 } 536 537 dialOptions = append(dialOptions, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) 538 } else { 539 dialOptions = append(dialOptions, grpc.WithTransportCredentials(insecure.NewCredentials())) 540 } 541 542 var callOptions = []grpc.CallOption{} 543 544 // Zero is the default 545 if conf.MaxRecvSize != 0 { 546 callOptions = append(callOptions, grpc.MaxCallRecvMsgSize(conf.MaxRecvSize)) 547 } 548 549 if conf.MaxSendSize != 0 { 550 callOptions = append(callOptions, grpc.MaxCallSendMsgSize(conf.MaxSendSize)) 551 } 552 553 if len(callOptions) > 0 { 554 dialOptions = append(dialOptions, grpc.WithDefaultCallOptions(callOptions...)) 555 } 556 557 conn, err := grpc.Dial( 558 host, 559 dialOptions..., 560 ) 561 562 if err != nil { 563 return nil, nil, err 564 } 565 566 client := pb.NewRPCClient(conn) 567 state.conn = conn 568 569 return client, state, nil 570 } 571 572 func (c *Controller) refreshMetrics() { 573 if c.clientState.SupportsActiveConns() { 574 c.metrics.GaugeSet(metricsGRPCActiveConns, uint64(c.clientState.ActiveConns())) 575 } 576 577 if c.barrier.HasDynamicCapacity() { 578 c.metrics.GaugeSet(metricsRPCCapacity, uint64(c.barrier.Capacity())) 579 } 580 581 c.timerMu.Lock() 582 defer c.timerMu.Unlock() 583 584 c.metricsTimer = time.AfterFunc(refreshMetricsInterval, c.refreshMetrics) 585 }