github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/rpc/context.go

// Copyright 2015 The Cockroach Authors.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

package rpc

import (
	"bytes"
	"context"
	"encoding/binary"
	"fmt"
	"io"
	"math"
	"net"
	"sync"
	"sync/atomic"
	"time"

	circuit "github.com/cockroachdb/circuitbreaker"
	"github.com/cockroachdb/cockroach/pkg/base"
	"github.com/cockroachdb/cockroach/pkg/roachpb"
	"github.com/cockroachdb/cockroach/pkg/security"
	"github.com/cockroachdb/cockroach/pkg/settings/cluster"
	"github.com/cockroachdb/cockroach/pkg/util/contextutil"
	"github.com/cockroachdb/cockroach/pkg/util/envutil"
	"github.com/cockroachdb/cockroach/pkg/util/growstack"
	"github.com/cockroachdb/cockroach/pkg/util/grpcutil"
	"github.com/cockroachdb/cockroach/pkg/util/hlc"
	"github.com/cockroachdb/cockroach/pkg/util/log"
	"github.com/cockroachdb/cockroach/pkg/util/netutil"
	"github.com/cockroachdb/cockroach/pkg/util/stop"
	"github.com/cockroachdb/cockroach/pkg/util/syncutil"
	"github.com/cockroachdb/cockroach/pkg/util/timeutil"
	"github.com/cockroachdb/cockroach/pkg/util/tracing"
	"github.com/cockroachdb/errors"
	"github.com/grpc-ecosystem/grpc-opentracing/go/otgrpc"
	opentracing "github.com/opentracing/opentracing-go"
	"golang.org/x/sync/syncmap"
	"google.golang.org/grpc"
	"google.golang.org/grpc/backoff"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/encoding"
	encodingproto "google.golang.org/grpc/encoding/proto"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/peer"
)

func init() {
	// Disable GRPC tracing. This retains a subset of messages for
	// display on /debug/requests, which is very expensive for
	// snapshots. Until we can be more selective about what is retained
	// in traces, we must disable tracing entirely.
	// https://github.com/grpc/grpc-go/issues/695
	grpc.EnableTracing = false
}

const (
	// The coefficient by which the maximum offset is multiplied to determine the
	// maximum acceptable measurement latency.
	maximumPingDurationMult = 2
)

const (
	defaultWindowSize     = 65535
	initialWindowSize     = defaultWindowSize * 32 // for an RPC
	initialConnWindowSize = initialWindowSize * 16 // for a connection
)

// sourceAddr is the environment-provided local address for outgoing
// connections.
var sourceAddr = func() net.Addr {
	const envKey = "COCKROACH_SOURCE_IP_ADDRESS"
	if sourceAddr, ok := envutil.EnvString(envKey, 0); ok {
		sourceIP := net.ParseIP(sourceAddr)
		if sourceIP == nil {
			panic(fmt.Sprintf("unable to parse %s '%s' as IP address", envKey, sourceAddr))
		}
		return &net.TCPAddr{
			IP: sourceIP,
		}
	}
	return nil
}()

var enableRPCCompression = envutil.EnvOrDefaultBool("COCKROACH_ENABLE_RPC_COMPRESSION", true)
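
// Environment knob sketch (illustrative only; the command line below is an
// assumption, not taken from this file): both variables above are read once at
// process start, e.g.
//
//	COCKROACH_SOURCE_IP_ADDRESS=10.0.0.5 COCKROACH_ENABLE_RPC_COMPRESSION=false ./cockroach start ...
//
// would bind 10.0.0.5 as the local address of outgoing RPC connections and
// disable snappy compression on RPC payloads.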

// spanInclusionFuncForServer is used as a SpanInclusionFunc for the server-side
// of RPCs, deciding for which operations the gRPC opentracing interceptor should
// create a span.
func spanInclusionFuncForServer(
	t *tracing.Tracer, parentSpanCtx opentracing.SpanContext, method string, req, resp interface{},
) bool {
	// Is the client tracing?
	return (parentSpanCtx != nil && !tracing.IsNoopContext(parentSpanCtx)) ||
		// Should we trace regardless of the client? This is useful for calls coming
		// through the HTTP->RPC gateway (i.e. the AdminUI), where the client is
		// never tracing.
		t.AlwaysTrace()
}

// spanInclusionFuncForClient is used as a SpanInclusionFunc for the client-side
// of RPCs, deciding for which operations the gRPC opentracing interceptor should
// create a span.
func spanInclusionFuncForClient(
	parentSpanCtx opentracing.SpanContext, method string, req, resp interface{},
) bool {
	return parentSpanCtx != nil && !tracing.IsNoopContext(parentSpanCtx)
}

func requireSuperUser(ctx context.Context) error {
	// TODO(marc): grpc's authentication model (which gives credential access in
	// the request handler) doesn't really fit with the current design of the
	// security package (which assumes that TLS state is only given at connection
	// time) - that should be fixed.
	if grpcutil.IsLocalRequestContext(ctx) {
		// This is an in-process request. Bypass authentication check.
	} else if peer, ok := peer.FromContext(ctx); ok {
		if tlsInfo, ok := peer.AuthInfo.(credentials.TLSInfo); ok {
			certUsers, err := security.GetCertificateUsers(&tlsInfo.State)
			if err != nil {
				return err
			}
			// TODO(benesch): the vast majority of RPCs should be limited to just
			// NodeUser. This is not a security concern, as RootUser has access to
			// read and write all data, merely good hygiene. For example, there is
			// no reason to permit the root user to send raw Raft RPCs.
			if !security.ContainsUser(security.NodeUser, certUsers) &&
				!security.ContainsUser(security.RootUser, certUsers) {
				return errors.Errorf("user %s is not allowed to perform this RPC", certUsers)
			}
		}
	} else {
		return errors.New("internal authentication error: TLSInfo is not available in request context")
	}
	return nil
}

// NewServer is a thin wrapper around grpc.NewServer that registers a heartbeat
// service.
func NewServer(ctx *Context) *grpc.Server {
	return NewServerWithInterceptor(ctx, nil)
}

// NewServerWithInterceptor is like NewServer, but accepts an additional
// interceptor which is called before streaming and unary RPCs and may inject an
// error.
func NewServerWithInterceptor(
	ctx *Context, interceptor func(fullMethod string) error,
) *grpc.Server {
	opts := []grpc.ServerOption{
		// The limiting factor for lowering the max message size is the fact
		// that a single large kv can be sent over the network in one message.
		// Our maximum kv size is unlimited, so we need this to be very large.
		//
		// TODO(peter,tamird): need tests before lowering.
		grpc.MaxRecvMsgSize(math.MaxInt32),
		grpc.MaxSendMsgSize(math.MaxInt32),
		// Adjust the stream and connection window sizes. The gRPC defaults are too
		// low for high latency connections.
		grpc.InitialWindowSize(initialWindowSize),
		grpc.InitialConnWindowSize(initialConnWindowSize),
		// The default number of concurrent streams/requests on a client connection
		// is 100, while the server is unlimited. The client setting can only be
		// controlled by adjusting the server value.
		// Set a very large value for the server so that we have no fixed limit
		// on the number of concurrent streams/requests on either the client or
		// server.
		grpc.MaxConcurrentStreams(math.MaxInt32),
		grpc.KeepaliveParams(serverKeepalive),
		grpc.KeepaliveEnforcementPolicy(serverEnforcement),
		// A stats handler to measure server network stats.
		grpc.StatsHandler(&ctx.stats),
	}
	if !ctx.Insecure {
		tlsConfig, err := ctx.GetServerTLSConfig()
		if err != nil {
			panic(err)
		}
		opts = append(opts, grpc.Creds(credentials.NewTLS(tlsConfig)))
	}

	var unaryInterceptor grpc.UnaryServerInterceptor
	var streamInterceptor grpc.StreamServerInterceptor

	if tracer := ctx.AmbientCtx.Tracer; tracer != nil {
		// We use a SpanInclusionFunc to save a bit of unnecessary work when
		// tracing is disabled.
		unaryInterceptor = otgrpc.OpenTracingServerInterceptor(
			tracer,
			otgrpc.IncludingSpans(otgrpc.SpanInclusionFunc(
				func(
					parentSpanCtx opentracing.SpanContext,
					method string,
					req, resp interface{}) bool {
					// This anonymous func serves to bind the tracer for
					// spanInclusionFuncForServer.
					return spanInclusionFuncForServer(
						tracer.(*tracing.Tracer), parentSpanCtx, method, req, resp)
				})),
		)
		// TODO(tschottdorf): should set up tracing for stream-based RPCs as
		// well. The otgrpc package has no such facility, but there's also this:
		//
		// https://github.com/grpc-ecosystem/go-grpc-middleware/tree/master/tracing/opentracing
	}

	// TODO(tschottdorf): when setting up the interceptors below, could make the
	// functions a wee bit more performant by hoisting some of the nil checks
	// out. Doubt measurements can tell the difference though.
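
	// Each block below wraps whatever interceptor has been built so far, so
	// when every feature is enabled the composed unary chain runs, outermost
	// first:
	//
	//	requireSuperUser -> interceptor(fullMethod) -> opentracing -> handler
	//
	// i.e. authentication is checked before the caller-supplied filter, which
	// in turn runs before span creation.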
	if interceptor != nil {
		prevUnaryInterceptor := unaryInterceptor
		unaryInterceptor = func(
			ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler,
		) (interface{}, error) {
			if err := interceptor(info.FullMethod); err != nil {
				return nil, err
			}
			if prevUnaryInterceptor != nil {
				return prevUnaryInterceptor(ctx, req, info, handler)
			}
			return handler(ctx, req)
		}
	}

	if interceptor != nil {
		prevStreamInterceptor := streamInterceptor
		streamInterceptor = func(
			srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler,
		) error {
			if err := interceptor(info.FullMethod); err != nil {
				return err
			}
			if prevStreamInterceptor != nil {
				return prevStreamInterceptor(srv, stream, info, handler)
			}
			return handler(srv, stream)
		}
	}

	if !ctx.Insecure {
		prevUnaryInterceptor := unaryInterceptor
		unaryInterceptor = func(
			ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler,
		) (interface{}, error) {
			if err := requireSuperUser(ctx); err != nil {
				return nil, err
			}
			if prevUnaryInterceptor != nil {
				return prevUnaryInterceptor(ctx, req, info, handler)
			}
			return handler(ctx, req)
		}
		prevStreamInterceptor := streamInterceptor
		streamInterceptor = func(
			srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler,
		) error {
			if err := requireSuperUser(stream.Context()); err != nil {
				return err
			}
			if prevStreamInterceptor != nil {
				return prevStreamInterceptor(srv, stream, info, handler)
			}
			return handler(srv, stream)
		}
	}

	if unaryInterceptor != nil {
		opts = append(opts, grpc.UnaryInterceptor(unaryInterceptor))
	}
	if streamInterceptor != nil {
		opts = append(opts, grpc.StreamInterceptor(streamInterceptor))
	}

	s := grpc.NewServer(opts...)
	RegisterHeartbeatServer(s, &HeartbeatService{
		clock:                                 ctx.LocalClock,
		remoteClockMonitor:                    ctx.RemoteClocks,
		clusterName:                           ctx.clusterName,
		disableClusterNameVerification:        ctx.disableClusterNameVerification,
		clusterID:                             &ctx.ClusterID,
		nodeID:                                &ctx.NodeID,
		settings:                              ctx.settings,
		testingAllowNamedRPCToAnonymousServer: ctx.TestingAllowNamedRPCToAnonymousServer,
	})
	return s
}

type heartbeatResult struct {
	everSucceeded bool  // true if the heartbeat has ever succeeded
	err           error // heartbeat error, initialized to ErrNotHeartbeated
}

// state is a helper to return the heartbeatState implied by a heartbeatResult.
func (hr heartbeatResult) state() (s heartbeatState) {
	switch {
	case !hr.everSucceeded && hr.err != nil:
		s = heartbeatInitializing
	case hr.everSucceeded && hr.err == nil:
		s = heartbeatNominal
	case hr.everSucceeded && hr.err != nil:
		s = heartbeatFailed
	}
	return s
}

// Connection is a wrapper around grpc.ClientConn. It prevents the underlying
// connection from being used until it has been validated via heartbeat.
type Connection struct {
	grpcConn             *grpc.ClientConn
	dialErr              error         // error while dialing; if set, connection is unusable
	heartbeatResult      atomic.Value  // result of latest heartbeat
	initialHeartbeatDone chan struct{} // closed after first heartbeat
	stopper              *stop.Stopper

	// remoteNodeID implies checking the remote node ID. 0 when unknown,
	// non-zero to check with remote node. This is constant throughout
	// the lifetime of a Connection object.
	remoteNodeID roachpb.NodeID

	initOnce sync.Once
}

func newConnectionToNodeID(stopper *stop.Stopper, remoteNodeID roachpb.NodeID) *Connection {
	c := &Connection{
		initialHeartbeatDone: make(chan struct{}),
		stopper:              stopper,
		remoteNodeID:         remoteNodeID,
	}
	c.heartbeatResult.Store(heartbeatResult{err: ErrNotHeartbeated})
	return c
}

// Connect returns the underlying grpc.ClientConn after it has been validated,
// or an error if dialing or validation fails.
func (c *Connection) Connect(ctx context.Context) (*grpc.ClientConn, error) {
	if c.dialErr != nil {
		return nil, c.dialErr
	}

	// Wait for initial heartbeat.
	select {
	case <-c.initialHeartbeatDone:
	case <-c.stopper.ShouldStop():
		return nil, errors.Errorf("stopped")
	case <-ctx.Done():
		return nil, ctx.Err()
	}

	// If connection is invalid, return latest heartbeat error.
	h := c.heartbeatResult.Load().(heartbeatResult)
	if !h.everSucceeded {
		// If we've never succeeded, h.err will be ErrNotHeartbeated.
		return nil, netutil.NewInitialHeartBeatFailedError(h.err)
	}
	return c.grpcConn, nil
}

// Health returns an error indicating the success or failure of the
// connection's latest heartbeat. Returns ErrNotHeartbeated if the
// first heartbeat has not completed.
func (c *Connection) Health() error {
	return c.heartbeatResult.Load().(heartbeatResult).err
}

// Context contains the fields required by the rpc framework.
type Context struct {
	*base.Config

	AmbientCtx   log.AmbientContext
	LocalClock   *hlc.Clock
	breakerClock breakerClock
	Stopper      *stop.Stopper
	RemoteClocks *RemoteClockMonitor
	masterCtx    context.Context

	heartbeatInterval time.Duration
	heartbeatTimeout  time.Duration
	HeartbeatCB       func()

	rpcCompression bool

	localInternalClient roachpb.InternalClient

	conns syncmap.Map

	stats StatsHandler

	ClusterID base.ClusterIDContainer
	NodeID    base.NodeIDContainer
	settings  *cluster.Settings

	clusterName                    string
	disableClusterNameVerification bool

	metrics Metrics

	// For unittesting.
	BreakerFactory  func() *circuit.Breaker
	testingDialOpts []grpc.DialOption
	testingKnobs    ContextTestingKnobs

	// For testing. See the comment on the same field in HeartbeatService.
	TestingAllowNamedRPCToAnonymousServer bool
}

// connKey is used as key in the Context.conns map.
// Connections which carry a different class but share a target and nodeID
// will always use distinct connections. Different remote node IDs get
// distinct *Connection objects to ensure that we don't mis-route RPC
// requests in the face of address reuse. Gossip connections and other
// non-Internal users of the Context are free to dial nodes without
// specifying a node ID (see GRPCUnvalidatedDial()); however, later calls to
// Dial with the same target and class but with a node ID will create a new
// underlying connection. The inverse, however, is not true: a connection
// dialed without a node ID will use an existing connection to a matching
// (targetAddr, class) pair.
type connKey struct {
	targetAddr string
	nodeID     roachpb.NodeID
	class      ConnectionClass
}

// NewContext creates an rpc Context with the supplied values.
func NewContext(
	ambient log.AmbientContext,
	baseCtx *base.Config,
	hlcClock *hlc.Clock,
	stopper *stop.Stopper,
	st *cluster.Settings,
) *Context {
	return NewContextWithTestingKnobs(ambient, baseCtx, hlcClock, stopper, st,
		ContextTestingKnobs{})
}

// NewContextWithTestingKnobs creates an rpc Context with the supplied values.
func NewContextWithTestingKnobs(
	ambient log.AmbientContext,
	baseCtx *base.Config,
	hlcClock *hlc.Clock,
	stopper *stop.Stopper,
	st *cluster.Settings,
	knobs ContextTestingKnobs,
) *Context {
	if hlcClock == nil {
		panic("nil clock is forbidden")
	}
	ctx := &Context{
		AmbientCtx: ambient,
		Config:     baseCtx,
		LocalClock: hlcClock,
		breakerClock: breakerClock{
			clock: hlcClock,
		},
		rpcCompression:                 enableRPCCompression,
		settings:                       st,
		clusterName:                    baseCtx.ClusterName,
		disableClusterNameVerification: baseCtx.DisableClusterNameVerification,
		testingKnobs:                   knobs,
	}
	var cancel context.CancelFunc
	ctx.masterCtx, cancel = context.WithCancel(ambient.AnnotateCtx(context.Background()))
	ctx.Stopper = stopper
	ctx.heartbeatInterval = baseCtx.RPCHeartbeatInterval
	ctx.RemoteClocks = newRemoteClockMonitor(
		ctx.LocalClock, 10*ctx.heartbeatInterval, baseCtx.HistogramWindowInterval())
	ctx.heartbeatTimeout = 2 * ctx.heartbeatInterval
	ctx.metrics = makeMetrics()

	stopper.RunWorker(ctx.masterCtx, func(context.Context) {
		<-stopper.ShouldQuiesce()

		cancel()
		ctx.conns.Range(func(k, v interface{}) bool {
			conn := v.(*Connection)
			conn.initOnce.Do(func() {
				// Make sure initialization is not in progress when we're removing the
				// conn. We need to set the error in case we win the race against the
				// real initialization code.
				if conn.dialErr == nil {
					conn.dialErr = &roachpb.NodeUnavailableError{}
				}
			})
			ctx.removeConn(conn, k.(connKey))
			return true
		})
	})
	if knobs.ClusterID != nil {
		ctx.ClusterID.Set(ctx.masterCtx, *knobs.ClusterID)
	}
	return ctx
}

// ClusterName retrieves the configured cluster name.
func (ctx *Context) ClusterName() string {
	if ctx == nil {
		// This is used in tests.
		return "<MISSING RPC CONTEXT>"
	}
	return ctx.clusterName
}

// GetStatsMap returns a map of network statistics maintained by the
// internal stats handler. The map is from the remote network address
// (in string form) to an rpc.Stats object.
func (ctx *Context) GetStatsMap() *syncmap.Map {
	return &ctx.stats.stats
}

// Metrics returns the Context's Metrics struct.
func (ctx *Context) Metrics() *Metrics {
	return &ctx.metrics
}
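
// A minimal construction sketch (hypothetical, for illustration only; the
// insecure config, clock parameters, and testing settings below are
// assumptions, not taken from this file):
//
//	stopper := stop.NewStopper()
//	defer stopper.Stop(context.Background())
//	clock := hlc.NewClock(hlc.UnixNano, 500*time.Millisecond)
//	rpcCtx := NewContext(
//		log.AmbientContext{}, &base.Config{Insecure: true},
//		clock, stopper, cluster.MakeTestingClusterSettings(),
//	)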

// GetLocalInternalClientForAddr returns the context's internal batch client
// for target, if it exists.
func (ctx *Context) GetLocalInternalClientForAddr(
	target string, nodeID roachpb.NodeID,
) roachpb.InternalClient {
	if target == ctx.AdvertiseAddr && nodeID == ctx.NodeID.Get() {
		return ctx.localInternalClient
	}
	return nil
}

type internalClientAdapter struct {
	roachpb.InternalServer
}

func (a internalClientAdapter) Batch(
	ctx context.Context, ba *roachpb.BatchRequest, _ ...grpc.CallOption,
) (*roachpb.BatchResponse, error) {
	return a.InternalServer.Batch(ctx, ba)
}

type rangeFeedClientAdapter struct {
	ctx    context.Context
	eventC chan *roachpb.RangeFeedEvent
	errC   chan error
}

// roachpb.Internal_RangeFeedClient methods.
func (a rangeFeedClientAdapter) Recv() (*roachpb.RangeFeedEvent, error) {
	// Prioritize eventC. Both channels are buffered and the only guarantee we
	// have is that once an error is sent on errC no other events will be sent
	// on eventC again.
	select {
	case e := <-a.eventC:
		return e, nil
	case err := <-a.errC:
		select {
		case e := <-a.eventC:
			a.errC <- err
			return e, nil
		default:
			return nil, err
		}
	}
}

// roachpb.Internal_RangeFeedServer methods.
func (a rangeFeedClientAdapter) Send(e *roachpb.RangeFeedEvent) error {
	select {
	case a.eventC <- e:
		return nil
	case <-a.ctx.Done():
		return a.ctx.Err()
	}
}

// grpc.ClientStream methods.
func (rangeFeedClientAdapter) Header() (metadata.MD, error) { panic("unimplemented") }
func (rangeFeedClientAdapter) Trailer() metadata.MD         { panic("unimplemented") }
func (rangeFeedClientAdapter) CloseSend() error             { panic("unimplemented") }

// grpc.ServerStream methods.
func (rangeFeedClientAdapter) SetHeader(metadata.MD) error  { panic("unimplemented") }
func (rangeFeedClientAdapter) SendHeader(metadata.MD) error { panic("unimplemented") }
func (rangeFeedClientAdapter) SetTrailer(metadata.MD)       { panic("unimplemented") }

// grpc.Stream methods.
func (a rangeFeedClientAdapter) Context() context.Context { return a.ctx }
func (rangeFeedClientAdapter) SendMsg(m interface{}) error { panic("unimplemented") }
func (rangeFeedClientAdapter) RecvMsg(m interface{}) error { panic("unimplemented") }

var _ roachpb.Internal_RangeFeedClient = rangeFeedClientAdapter{}
var _ roachpb.Internal_RangeFeedServer = rangeFeedClientAdapter{}

func (a internalClientAdapter) RangeFeed(
	ctx context.Context, args *roachpb.RangeFeedRequest, _ ...grpc.CallOption,
) (roachpb.Internal_RangeFeedClient, error) {
	ctx, cancel := context.WithCancel(ctx)
	rfAdapter := rangeFeedClientAdapter{
		ctx:    ctx,
		eventC: make(chan *roachpb.RangeFeedEvent, 128),
		errC:   make(chan error, 1),
	}

	go func() {
		defer cancel()
		err := a.InternalServer.RangeFeed(args, rfAdapter)
		if err == nil {
			err = io.EOF
		}
		rfAdapter.errC <- err
	}()

	return rfAdapter, nil
}

var _ roachpb.InternalClient = internalClientAdapter{}

// IsLocal returns true if the given InternalClient is local.
func IsLocal(iface roachpb.InternalClient) bool {
	_, ok := iface.(internalClientAdapter)
	return ok // internalClientAdapter is used for local connections.
}
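
// Hypothetical dispatch sketch (not part of the original file; the addr,
// nodeID, and ba values are assumptions): a caller holding the rpc Context can
// short-circuit gRPC for same-node requests:
//
//	if client := rpcCtx.GetLocalInternalClientForAddr(addr, nodeID); client != nil {
//		// In-process call: no network round trip and no proto marshaling.
//		return client.Batch(ctx, &ba)
//	}
//	// Otherwise fall back to a real connection via GRPCDialNode + Connect.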

// SetLocalInternalServer sets the context's local internal batch server.
func (ctx *Context) SetLocalInternalServer(internalServer roachpb.InternalServer) {
	ctx.localInternalClient = internalClientAdapter{internalServer}
}

// removeConn removes the given connection from the pool. The supplied connKeys
// must represent *all* the keys under which the connection was shared.
func (ctx *Context) removeConn(conn *Connection, keys ...connKey) {
	for _, key := range keys {
		ctx.conns.Delete(key)
	}
	if log.V(1) {
		log.Infof(ctx.masterCtx, "closing %+v", keys)
	}
	if grpcConn := conn.grpcConn; grpcConn != nil {
		if err := grpcConn.Close(); err != nil && !grpcutil.IsClosedConnection(err) {
			if log.V(1) {
				log.Errorf(ctx.masterCtx, "failed to close client connection: %v", err)
			}
		}
	}
}

// GRPCDialOptions returns the minimal `grpc.DialOption`s necessary to connect
// to a server created with `NewServer`.
//
// At the time of writing, this is being used for making net.Pipe-based
// connections, so only those options that affect semantics are included. In
// particular, performance tuning options are omitted. Decompression is
// necessarily included to support compression-enabled servers, and compression
// is included for symmetry. These choices are admittedly subjective.
func (ctx *Context) GRPCDialOptions() ([]grpc.DialOption, error) {
	return ctx.grpcDialOptions("", DefaultClass)
}

// grpcDialOptions extends GRPCDialOptions to support a connection class for use
// with TestingKnobs.
func (ctx *Context) grpcDialOptions(
	target string, class ConnectionClass,
) ([]grpc.DialOption, error) {
	var dialOpts []grpc.DialOption
	if ctx.Insecure {
		dialOpts = append(dialOpts, grpc.WithInsecure())
	} else {
		tlsConfig, err := ctx.GetClientTLSConfig()
		if err != nil {
			return nil, err
		}
		dialOpts = append(dialOpts, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)))
	}

	// The limiting factor for lowering the max message size is the fact
	// that a single large kv can be sent over the network in one message.
	// Our maximum kv size is unlimited, so we need this to be very large.
	//
	// TODO(peter,tamird): need tests before lowering.
	dialOpts = append(dialOpts, grpc.WithDefaultCallOptions(
		grpc.MaxCallRecvMsgSize(math.MaxInt32),
		grpc.MaxCallSendMsgSize(math.MaxInt32),
	))

	// Compression is enabled separately from decompression to allow staged
	// rollout.
	if ctx.rpcCompression {
		dialOpts = append(dialOpts, grpc.WithDefaultCallOptions(grpc.UseCompressor((snappyCompressor{}).Name())))
	}

	var unaryInterceptors []grpc.UnaryClientInterceptor

	if tracer := ctx.AmbientCtx.Tracer; tracer != nil {
		unaryInterceptors = append(unaryInterceptors,
			otgrpc.OpenTracingClientInterceptor(tracer,
				// We use a SpanInclusionFunc to circumvent the interceptor's work when
				// tracing is disabled. Otherwise, the interceptor causes an increase in
				// the number of packets (even with an empty context!). See #17177.
				otgrpc.IncludingSpans(otgrpc.SpanInclusionFunc(spanInclusionFuncForClient)),
				// We use a decorator to set the "node" tag. All other spans get the
				// node tag from context log tags.
				//
				// Unfortunately we cannot use the corresponding interceptor on the
				// server-side of gRPC to set this tag on server spans because that
				// interceptor runs too late - after a traced RPC's recording had
				// already been collected. So, on the server-side, the equivalent code
				// is in setupSpanForIncomingRPC().
				otgrpc.SpanDecorator(func(span opentracing.Span, _ string, _, _ interface{}, _ error) {
					span.SetTag("node", ctx.NodeID.String())
				})))
	}
	if ctx.testingKnobs.UnaryClientInterceptor != nil {
		testingUnaryInterceptor := ctx.testingKnobs.UnaryClientInterceptor(target, class)
		if testingUnaryInterceptor != nil {
			unaryInterceptors = append(unaryInterceptors, testingUnaryInterceptor)
		}
	}
	dialOpts = append(dialOpts, grpc.WithChainUnaryInterceptor(unaryInterceptors...))
	if ctx.testingKnobs.StreamClientInterceptor != nil {
		testingStreamInterceptor := ctx.testingKnobs.StreamClientInterceptor(target, class)
		if testingStreamInterceptor != nil {
			dialOpts = append(dialOpts, grpc.WithStreamInterceptor(testingStreamInterceptor))
		}
	}
	return dialOpts, nil
}

// growStackCodec wraps the default grpc/encoding/proto codec to detect
// BatchRequest rpcs and grow the stack prior to Unmarshaling.
type growStackCodec struct {
	encoding.Codec
}

// Unmarshal detects BatchRequests and calls growstack.Grow before calling
// through to the underlying codec.
func (c growStackCodec) Unmarshal(data []byte, v interface{}) error {
	if _, ok := v.(*roachpb.BatchRequest); ok {
		growstack.Grow()
	}
	return c.Codec.Unmarshal(data, v)
}

// Install the growStackCodec over the default proto codec in order to grow the
// stack for BatchRequest RPCs prior to unmarshaling.
func init() {
	protoCodec := encoding.GetCodec(encodingproto.Name)
	encoding.RegisterCodec(growStackCodec{Codec: protoCodec})
}

// onlyOnceDialer implements the grpc.WithDialer interface but only
// allows a single connection attempt. If a reconnection is attempted,
// redialChan is closed to signal a higher-level retry loop. This
// ensures that our initial heartbeat (and its version/clusterID
// validation) occurs on every new connection.
type onlyOnceDialer struct {
	syncutil.Mutex
	dialed     bool
	closed     bool
	redialChan chan struct{}
}

func (ood *onlyOnceDialer) dial(ctx context.Context, addr string) (net.Conn, error) {
	ood.Lock()
	defer ood.Unlock()
	if !ood.dialed {
		ood.dialed = true
		dialer := net.Dialer{
			LocalAddr: sourceAddr,
		}
		return dialer.DialContext(ctx, "tcp", addr)
	} else if !ood.closed {
		ood.closed = true
		close(ood.redialChan)
	}
	return nil, grpcutil.ErrCannotReuseClientConn
}

type dialerFunc func(context.Context, string) (net.Conn, error)

type artificialLatencyDialer struct {
	dialerFunc dialerFunc
	latencyMS  int
}

func (ald *artificialLatencyDialer) dial(ctx context.Context, addr string) (net.Conn, error) {
	conn, err := ald.dialerFunc(ctx, addr)
	if err != nil {
		return conn, err
	}
	return delayingConn{
		Conn:    conn,
		latency: time.Duration(ald.latencyMS) * time.Millisecond,
		readBuf: new(bytes.Buffer),
	}, nil
}
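
// Hypothetical test-only sketch (the addresses and latency values are made
// up): the latency dialer above is wired in via
// ContextTestingKnobs.ArtificialLatencyMap, keyed by target address with
// latencies in milliseconds:
//
//	knobs := ContextTestingKnobs{
//		ArtificialLatencyMap: map[string]int{
//			"127.0.0.1:26258": 10, // simulate a 10ms one-way delay
//		},
//	}
//	rpcCtx := NewContextWithTestingKnobs(ambient, cfg, clock, stopper, st, knobs)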

type delayingListener struct {
	net.Listener
}

// NewDelayingListener creates a net.Listener that introduces a set delay on its connections.
func NewDelayingListener(l net.Listener) net.Listener {
	return delayingListener{Listener: l}
}

func (d delayingListener) Accept() (net.Conn, error) {
	c, err := d.Listener.Accept()
	if err != nil {
		return nil, err
	}
	return delayingConn{
		Conn: c,
		// Put a default latency as the server's conn. This value will get populated
		// as packets are exchanged across the delayingConnections.
		latency: time.Duration(0) * time.Millisecond,
		readBuf: new(bytes.Buffer),
	}, nil
}

type delayingConn struct {
	net.Conn
	latency     time.Duration
	lastSendEnd time.Time
	readBuf     *bytes.Buffer
}

func (d delayingConn) Write(b []byte) (n int, err error) {
	tNow := timeutil.Now()
	if d.lastSendEnd.Before(tNow) {
		d.lastSendEnd = tNow
	}
	hdr := delayingHeader{
		Magic:    magic,
		ReadTime: d.lastSendEnd.Add(d.latency).UnixNano(),
		Sz:       int32(len(b)),
		DelayMS:  int32(d.latency / time.Millisecond),
	}
	if err := binary.Write(d.Conn, binary.BigEndian, hdr); err != nil {
		return n, err
	}
	x, err := d.Conn.Write(b)
	n += x
	return n, err
}

func (d delayingConn) Read(b []byte) (n int, err error) {
	if d.readBuf.Len() == 0 {
		var hdr delayingHeader
		if err := binary.Read(d.Conn, binary.BigEndian, &hdr); err != nil {
			return 0, err
		}
		// If we somehow don't get our expected magic, throw an error.
		if hdr.Magic != magic {
			panic(errors.New("didn't get expected magic bytes header"))
			// TODO (rohany): I can't get this to work. I suspect the problem is
			// that the improperly parsed struct is not written back out in the
			// same binary format it was read in. I tried this with sending the
			// magic integer over first and saw the same thing.
		} else {
			d.latency = time.Duration(hdr.DelayMS) * time.Millisecond
			defer func() {
				time.Sleep(timeutil.Until(timeutil.Unix(0, hdr.ReadTime)))
			}()
			if _, err := io.CopyN(d.readBuf, d.Conn, int64(hdr.Sz)); err != nil {
				return 0, err
			}
		}
	}
	return d.readBuf.Read(b)
}

const magic = 0xfeedfeed

type delayingHeader struct {
	Magic    int64
	ReadTime int64
	Sz       int32
	DelayMS  int32
}
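
// Wire format note (descriptive; the payload size and delay are illustrative):
// every Write on a delayingConn prefixes the payload with a fixed-size
// big-endian delayingHeader, so a 100-byte payload sent with 10ms of simulated
// latency is framed roughly as
//
//	| Magic=0xfeedfeed | ReadTime=<unix nanos> | Sz=100 | DelayMS=10 | payload |
//
// The Read side strips the header, adopts DelayMS as its own latency for
// subsequent writes, and sleeps until ReadTime before handing the payload up.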

// GRPCDialRaw calls grpc.Dial with options appropriate for the context.
// Unlike GRPCDialNode, it does not start an RPC heartbeat to validate the
// connection. This connection will not be reconnected automatically;
// the returned channel is closed when a reconnection is attempted.
// This method implies a DefaultClass ConnectionClass for the returned
// ClientConn.
func (ctx *Context) GRPCDialRaw(target string) (*grpc.ClientConn, <-chan struct{}, error) {
	return ctx.grpcDialRaw(target, 0, DefaultClass)
}

func (ctx *Context) grpcDialRaw(
	target string, remoteNodeID roachpb.NodeID, class ConnectionClass,
) (*grpc.ClientConn, <-chan struct{}, error) {
	dialOpts, err := ctx.grpcDialOptions(target, class)
	if err != nil {
		return nil, nil, err
	}

	// Add a stats handler to measure client network stats.
	dialOpts = append(dialOpts, grpc.WithStatsHandler(ctx.stats.newClient(target)))

	// Lower the MaxBackoff (which defaults to ~minutes) to something in the
	// ~second range.
	backoffConfig := backoff.DefaultConfig
	backoffConfig.MaxDelay = maxBackoff
	dialOpts = append(dialOpts, grpc.WithConnectParams(grpc.ConnectParams{Backoff: backoffConfig}))
	dialOpts = append(dialOpts, grpc.WithKeepaliveParams(clientKeepalive))
	dialOpts = append(dialOpts,
		grpc.WithInitialWindowSize(initialWindowSize),
		grpc.WithInitialConnWindowSize(initialConnWindowSize))

	dialer := onlyOnceDialer{
		redialChan: make(chan struct{}),
	}
	dialerFunc := dialer.dial
	if ctx.testingKnobs.ArtificialLatencyMap != nil {
		latency := ctx.testingKnobs.ArtificialLatencyMap[target]
		log.VEventf(ctx.masterCtx, 1, "Connecting to node %s (%d) with simulated latency %dms", target, remoteNodeID,
			latency)
		dialer := artificialLatencyDialer{
			dialerFunc: dialerFunc,
			latencyMS:  latency,
		}
		dialerFunc = dialer.dial
	}
	dialOpts = append(dialOpts, grpc.WithContextDialer(dialerFunc))

	// add testingDialOpts after our dialer because one of our tests
	// uses a custom dialer (this disables the only-one-connection
	// behavior and redialChan will never be closed).
	dialOpts = append(dialOpts, ctx.testingDialOpts...)

	if log.V(1) {
		log.Infof(ctx.masterCtx, "dialing %s", target)
	}
	conn, err := grpc.DialContext(ctx.masterCtx, target, dialOpts...)
	return conn, dialer.redialChan, err
}

// GRPCUnvalidatedDial uses GRPCDialNode and disables validation of the
// node ID between client and server. This function should only be
// used with the gossip client and CLI commands which can talk to any
// node. This method implies a SystemClass.
func (ctx *Context) GRPCUnvalidatedDial(target string) *Connection {
	return ctx.grpcDialNodeInternal(target, 0, SystemClass)
}

// GRPCDialNode calls grpc.Dial with options appropriate for the
// context and class (see the comment on ConnectionClass).
//
// The remoteNodeID becomes a constraint on the expected node ID of
// the remote node; this is checked during heartbeats. The caller is
// responsible for ensuring the remote node ID is known prior to using
// this function.
func (ctx *Context) GRPCDialNode(
	target string, remoteNodeID roachpb.NodeID, class ConnectionClass,
) *Connection {
	if remoteNodeID == 0 && !ctx.TestingAllowNamedRPCToAnonymousServer {
		log.Fatalf(context.TODO(), "invalid node ID 0 in GRPCDialNode()")
	}
	return ctx.grpcDialNodeInternal(target, remoteNodeID, class)
}
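
// Hypothetical usage sketch (the address and node ID are made up, not from
// this file): dial a specific node and block until the first heartbeat has
// validated the connection:
//
//	conn := rpcCtx.GRPCDialNode("10.0.0.2:26257", 2, DefaultClass)
//	cc, err := conn.Connect(ctx)
//	if err != nil {
//		return err // dial or initial-heartbeat failure; the connection is unusable
//	}
//	client := roachpb.NewInternalClient(cc)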

func (ctx *Context) grpcDialNodeInternal(
	target string, remoteNodeID roachpb.NodeID, class ConnectionClass,
) *Connection {
	thisConnKeys := []connKey{{target, remoteNodeID, class}}
	value, ok := ctx.conns.Load(thisConnKeys[0])
	if !ok {
		value, _ = ctx.conns.LoadOrStore(thisConnKeys[0], newConnectionToNodeID(ctx.Stopper, remoteNodeID))
		if remoteNodeID != 0 {
			// If the first connection established at a target address is
			// for a specific node ID, then we want to reuse that connection
			// also for other dials (eg for gossip) which don't require a
			// specific node ID. (We do this as an optimization to reduce
			// the number of TCP connections alive between nodes. This is
			// not strictly required for correctness.) This LoadOrStore will
			// ensure we're registering the connection we just created for
			// future use by these other dials.
			//
			// We need to be careful to unregister both connKeys when the
			// connection breaks. Otherwise, we leak the entry below which
			// "simulates" a hard network partition for anyone dialing without
			// the nodeID (gossip).
			//
			// See:
			// https://github.com/cockroachdb/cockroach/issues/37200
			otherKey := connKey{target, 0, class}
			if _, loaded := ctx.conns.LoadOrStore(otherKey, value); !loaded {
				thisConnKeys = append(thisConnKeys, otherKey)
			}
		}
	}

	conn := value.(*Connection)
	conn.initOnce.Do(func() {
		// Either we kick off the heartbeat loop (and clean up when it's done),
		// or we clean up the connKey entries immediately.
		var redialChan <-chan struct{}
		conn.grpcConn, redialChan, conn.dialErr = ctx.grpcDialRaw(target, remoteNodeID, class)
		if conn.dialErr == nil {
			if err := ctx.Stopper.RunTask(
				ctx.masterCtx, "rpc.Context: grpc heartbeat", func(masterCtx context.Context) {
					ctx.Stopper.RunWorker(masterCtx, func(masterCtx context.Context) {
						err := ctx.runHeartbeat(conn, target, redialChan)
						if err != nil && !grpcutil.IsClosedConnection(err) {
							log.Errorf(masterCtx, "removing connection to %s due to error: %s", target, err)
						}
						ctx.removeConn(conn, thisConnKeys...)
					})
				}); err != nil {
				conn.dialErr = err
			}
		}
		if conn.dialErr != nil {
			ctx.removeConn(conn, thisConnKeys...)
		}
	})

	return conn
}

// NewBreaker creates a new circuit breaker properly configured for RPC
// connections. name is used internally for logging state changes of the
// returned breaker.
func (ctx *Context) NewBreaker(name string) *circuit.Breaker {
	if ctx.BreakerFactory != nil {
		return ctx.BreakerFactory()
	}
	return newBreaker(ctx.masterCtx, name, &ctx.breakerClock)
}

// ErrNotHeartbeated is returned by ConnHealth when we have not yet performed
// the first heartbeat.
var ErrNotHeartbeated = errors.New("not yet heartbeated")

func (ctx *Context) runHeartbeat(
	conn *Connection, target string, redialChan <-chan struct{},
) (retErr error) {
	ctx.metrics.HeartbeatLoopsStarted.Inc(1)
	// setInitialHeartbeatDone is idempotent and is critical to notify Connect
	// callers of the failure in the case where no heartbeat is ever sent.
	state := updateHeartbeatState(&ctx.metrics, heartbeatNotRunning, heartbeatInitializing)
	initialHeartbeatDone := false
	setInitialHeartbeatDone := func() {
		if !initialHeartbeatDone {
			close(conn.initialHeartbeatDone)
			initialHeartbeatDone = true
		}
	}
	defer func() {
		if retErr != nil {
			ctx.metrics.HeartbeatLoopsExited.Inc(1)
		}
		updateHeartbeatState(&ctx.metrics, state, heartbeatNotRunning)
		setInitialHeartbeatDone()
	}()
	maxOffset := ctx.LocalClock.MaxOffset()
	maxOffsetNanos := maxOffset.Nanoseconds()

	heartbeatClient := NewHeartbeatClient(conn.grpcConn)

	var heartbeatTimer timeutil.Timer
	defer heartbeatTimer.Stop()

	// Give the first iteration a wait-free heartbeat attempt.
	heartbeatTimer.Reset(0)
	everSucceeded := false
	for {
		select {
		case <-redialChan:
			return grpcutil.ErrCannotReuseClientConn
		case <-ctx.Stopper.ShouldQuiesce():
			return nil
		case <-heartbeatTimer.C:
			heartbeatTimer.Read = true
		}

		if err := ctx.Stopper.RunTaskWithErr(ctx.masterCtx, "rpc heartbeat", func(goCtx context.Context) error {
			// We re-mint the PingRequest to pick up any asynchronous update to clusterID.
			clusterID := ctx.ClusterID.Get()
			request := &PingRequest{
				Addr:           ctx.Addr,
				MaxOffsetNanos: maxOffsetNanos,
				ClusterID:      &clusterID,
				NodeID:         conn.remoteNodeID,
				ServerVersion:  ctx.settings.Version.BinaryVersion(),
			}

			var response *PingResponse
			sendTime := ctx.LocalClock.PhysicalTime()
			ping := func(goCtx context.Context) (err error) {
				// NB: We want the request to fail-fast (the default), otherwise we won't
				// be notified of transport failures.
				response, err = heartbeatClient.Ping(goCtx, request)
				return err
			}
			var err error
			if ctx.heartbeatTimeout > 0 {
				err = contextutil.RunWithTimeout(goCtx, "rpc heartbeat", ctx.heartbeatTimeout, ping)
			} else {
				err = ping(goCtx)
			}

			if err == nil {
				// We verify the cluster name on the initiator side (instead
				// of the heartbeat service side, as done for the cluster ID
				// and node ID checks) so that the operator who is starting a
				// new node in a cluster and mistakenly joins the wrong
				// cluster gets a chance to see the error message on their
				// management console.
				if !ctx.disableClusterNameVerification && !response.DisableClusterNameVerification {
					err = errors.Wrap(
						checkClusterName(ctx.clusterName, response.ClusterName),
						"cluster name check failed on ping response")
				}
			}

			if err == nil {
				err = errors.Wrap(
					checkVersion(goCtx, ctx.settings, response.ServerVersion),
					"version compatibility check failed on ping response")
			}

			if err == nil {
				everSucceeded = true
				receiveTime := ctx.LocalClock.PhysicalTime()

				// Only update the clock offset measurement if we actually got a
				// successful response from the server.
				pingDuration := receiveTime.Sub(sendTime)
				maxOffset := ctx.LocalClock.MaxOffset()
				if pingDuration > maximumPingDurationMult*maxOffset {
					request.Offset.Reset()
				} else {
					// Offset and error are measured using the remote clock reading
					// technique described in
					// http://se.inf.tu-dresden.de/pubs/papers/SRDS1994.pdf, page 6.
					// However, we assume that drift and min message delay are 0, for
					// now.
					request.Offset.MeasuredAt = receiveTime.UnixNano()
					request.Offset.Uncertainty = (pingDuration / 2).Nanoseconds()
					remoteTimeNow := timeutil.Unix(0, response.ServerTime).Add(pingDuration / 2)
					request.Offset.Offset = remoteTimeNow.Sub(receiveTime).Nanoseconds()
				}
				ctx.RemoteClocks.UpdateOffset(ctx.masterCtx, target, request.Offset, pingDuration)

				if cb := ctx.HeartbeatCB; cb != nil {
					cb()
				}
			}

			hr := heartbeatResult{
				everSucceeded: everSucceeded,
				err:           err,
			}
			state = updateHeartbeatState(&ctx.metrics, state, hr.state())
			conn.heartbeatResult.Store(hr)
			setInitialHeartbeatDone()
			return nil
		}); err != nil {
			return err
		}

		heartbeatTimer.Reset(ctx.heartbeatInterval)
	}
}
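
// Offset bookkeeping sketch (illustrative numbers, not from the source): with
// sendTime t, server clock reading s from the ping response, and receiveTime
// t+rtt, the loop above records
//
//	remoteTimeNow = s + rtt/2
//	Offset        = remoteTimeNow - receiveTime
//	Uncertainty   = rtt/2
//
// so an rtt of 2ms against a server clock running 5ms ahead yields an offset
// of roughly +5ms with ±1ms of uncertainty. Measurements whose rtt exceeds
// maximumPingDurationMult*MaxOffset are discarded via request.Offset.Reset().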