// google.golang.org/grpc@v1.72.2/server.go
/*
 *
 * Copyright 2014 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package grpc

import (
	"context"
	"errors"
	"fmt"
	"io"
	"math"
	"net"
	"net/http"
	"reflect"
	"runtime"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/encoding"
	"google.golang.org/grpc/encoding/proto"
	estats "google.golang.org/grpc/experimental/stats"
	"google.golang.org/grpc/grpclog"
	"google.golang.org/grpc/internal"
	"google.golang.org/grpc/internal/binarylog"
	"google.golang.org/grpc/internal/channelz"
	"google.golang.org/grpc/internal/grpcsync"
	"google.golang.org/grpc/internal/grpcutil"
	istats "google.golang.org/grpc/internal/stats"
	"google.golang.org/grpc/internal/transport"
	"google.golang.org/grpc/keepalive"
	"google.golang.org/grpc/mem"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/peer"
	"google.golang.org/grpc/stats"
	"google.golang.org/grpc/status"
	"google.golang.org/grpc/tap"
)

const (
	// Default receive limit: 4 MiB.
	defaultServerMaxReceiveMessageSize = 1024 * 1024 * 4
	// Default send limit: effectively unbounded (math.MaxInt32).
	defaultServerMaxSendMessageSize = math.MaxInt32

	// Server transports are tracked in a map which is keyed on listener
	// address. For regular gRPC traffic, connections are accepted in Serve()
	// through a call to Accept(), and we use the actual listener address as key
	// when we add it to the map. But for connections received through
	// ServeHTTP(), we do not have a listener and hence use this dummy value.
	listenerAddressForServeHTTP = "listenerAddressForServeHTTP"
)

// init installs function hooks in the internal package that expose unexported
// Server state and helpers (transport credentials, method registration lookup,
// global server options, binary logging, buffer pools, and metrics recording)
// to other gRPC packages that import internal.
func init() {
	internal.GetServerCredentials = func(srv *Server) credentials.TransportCredentials {
		return srv.opts.creds
	}
	internal.IsRegisteredMethod = func(srv *Server, method string) bool {
		return srv.isRegisteredMethod(method)
	}
	internal.ServerFromContext = serverFromContext
	// Global options are applied by NewServer before the per-call options.
	internal.AddGlobalServerOptions = func(opt ...ServerOption) {
		globalServerOptions = append(globalServerOptions, opt...)
	}
	internal.ClearGlobalServerOptions = func() {
		globalServerOptions = nil
	}
	internal.BinaryLogger = binaryLogger
	internal.JoinServerOptions = newJoinServerOption
	internal.BufferPool = bufferPool
	internal.MetricsRecorderForServer = func(srv *Server) estats.MetricsRecorder {
		return istats.NewMetricsRecorderList(srv.opts.statsHandlers)
	}
}

// statusOK is a shared, pre-built status with code OK and an empty message.
var statusOK = status.New(codes.OK, "")

// logger is the "core" grpclog component used throughout this file.
var logger = grpclog.Component("core")

// MethodHandler is a function type that processes a unary RPC method call.
type MethodHandler func(srv any, ctx context.Context, dec func(any) error, interceptor UnaryServerInterceptor) (any, error)

// MethodDesc represents an RPC service's method specification.
type MethodDesc struct {
	MethodName string
	Handler    MethodHandler
}

// ServiceDesc represents an RPC service's specification.
type ServiceDesc struct {
	ServiceName string
	// The pointer to the service interface. Used to check whether the user
	// provided implementation satisfies the interface requirements.
	HandlerType any
	Methods     []MethodDesc
	Streams     []StreamDesc
	Metadata    any
}

// serviceInfo wraps information about a service. It is very similar to
// ServiceDesc and is constructed from it for internal purposes. The methods
// and streams maps are keyed by method/stream name and point into the
// ServiceDesc's own Methods/Streams slices.
type serviceInfo struct {
	// Contains the implementation for the methods in this service.
	serviceImpl any
	methods     map[string]*MethodDesc
	streams     map[string]*StreamDesc
	mdata       any
}

// Server is a gRPC server to serve RPC requests. Create one with NewServer,
// register services with RegisterService, then call Serve.
type Server struct {
	opts serverOptions

	mu  sync.Mutex // guards following
	lis map[net.Listener]bool
	// conns contains all active server transports. It is a map keyed on a
	// listener address with the value being the set of active transports
	// belonging to that listener.
	conns    map[string]map[transport.ServerTransport]bool
	serve    bool
	drain    bool
	cv       *sync.Cond              // signaled when connections close for GracefulStop
	services map[string]*serviceInfo // service name -> service info
	events   traceEventLog

	quit               *grpcsync.Event
	done               *grpcsync.Event
	channelzRemoveOnce sync.Once
	serveWG            sync.WaitGroup // counts active Serve goroutines for Stop/GracefulStop
	handlersWG         sync.WaitGroup // counts active method handler goroutines

	channelz *channelz.Server

	serverWorkerChannel      chan func()
	serverWorkerChannelClose func()
}

// serverOptions is the configuration assembled by NewServer: a copy of
// defaultServerOptions with the global and then the per-call ServerOptions
// applied to it.
type serverOptions struct {
	creds                 credentials.TransportCredentials
	codec                 baseCodec
	cp                    Compressor
	dc                    Decompressor
	unaryInt              UnaryServerInterceptor
	streamInt             StreamServerInterceptor
	chainUnaryInts        []UnaryServerInterceptor
	chainStreamInts       []StreamServerInterceptor
	binaryLogger          binarylog.Logger
	inTapHandle           tap.ServerInHandle
	statsHandlers         []stats.Handler
	maxConcurrentStreams  uint32
	maxReceiveMessageSize int
	maxSendMessageSize    int
	unknownStreamDesc     *StreamDesc
	keepaliveParams       keepalive.ServerParameters
	keepalivePolicy       keepalive.EnforcementPolicy
	initialWindowSize     int32
	initialConnWindowSize int32
	writeBufferSize       int
	readBufferSize        int
	sharedWriteBuffer     bool
	connectionTimeout     time.Duration
	maxHeaderListSize     *uint32
	headerTableSize       *uint32
	numServerWorkers      uint32
	bufferPool            mem.BufferPool
	waitForHandlers       bool
}

// defaultServerOptions is the baseline that NewServer copies before applying
// any ServerOption. Fields not listed here start at their zero value.
var defaultServerOptions = serverOptions{
	maxConcurrentStreams:  math.MaxUint32,
	maxReceiveMessageSize: defaultServerMaxReceiveMessageSize,
	maxSendMessageSize:    defaultServerMaxSendMessageSize,
	connectionTimeout:     120 * time.Second,
	writeBufferSize:       defaultWriteBufSize,
	readBufferSize:        defaultReadBufSize,
	bufferPool:            mem.DefaultBufferPool(),
}

// globalServerOptions are applied to every server created by NewServer, before
// the options passed to NewServer itself. They are managed through the
// internal.AddGlobalServerOptions / internal.ClearGlobalServerOptions hooks.
var globalServerOptions []ServerOption

// A ServerOption sets options such as credentials, codec and keepalive parameters, etc.
type ServerOption interface {
	apply(*serverOptions)
}

// EmptyServerOption does not alter the server configuration. It can be embedded
// in another structure to build custom server options.
//
// # Experimental
//
// Notice: This type is EXPERIMENTAL and may be changed or removed in a
// later release.
type EmptyServerOption struct{}

func (EmptyServerOption) apply(*serverOptions) {}

// funcServerOption wraps a function that modifies serverOptions into an
// implementation of the ServerOption interface.
type funcServerOption struct {
	f func(*serverOptions)
}

// apply invokes the wrapped function on do.
func (fdo *funcServerOption) apply(do *serverOptions) {
	fdo.f(do)
}

// newFuncServerOption adapts f into a ServerOption.
func newFuncServerOption(f func(*serverOptions)) *funcServerOption {
	return &funcServerOption{
		f: f,
	}
}

// joinServerOption provides a way to combine arbitrary number of server
// options into one.
229 type joinServerOption struct { 230 opts []ServerOption 231 } 232 233 func (mdo *joinServerOption) apply(do *serverOptions) { 234 for _, opt := range mdo.opts { 235 opt.apply(do) 236 } 237 } 238 239 func newJoinServerOption(opts ...ServerOption) ServerOption { 240 return &joinServerOption{opts: opts} 241 } 242 243 // SharedWriteBuffer allows reusing per-connection transport write buffer. 244 // If this option is set to true every connection will release the buffer after 245 // flushing the data on the wire. 246 // 247 // # Experimental 248 // 249 // Notice: This API is EXPERIMENTAL and may be changed or removed in a 250 // later release. 251 func SharedWriteBuffer(val bool) ServerOption { 252 return newFuncServerOption(func(o *serverOptions) { 253 o.sharedWriteBuffer = val 254 }) 255 } 256 257 // WriteBufferSize determines how much data can be batched before doing a write 258 // on the wire. The default value for this buffer is 32KB. Zero or negative 259 // values will disable the write buffer such that each write will be on underlying 260 // connection. Note: A Send call may not directly translate to a write. 261 func WriteBufferSize(s int) ServerOption { 262 return newFuncServerOption(func(o *serverOptions) { 263 o.writeBufferSize = s 264 }) 265 } 266 267 // ReadBufferSize lets you set the size of read buffer, this determines how much 268 // data can be read at most for one read syscall. The default value for this 269 // buffer is 32KB. Zero or negative values will disable read buffer for a 270 // connection so data framer can access the underlying conn directly. 271 func ReadBufferSize(s int) ServerOption { 272 return newFuncServerOption(func(o *serverOptions) { 273 o.readBufferSize = s 274 }) 275 } 276 277 // InitialWindowSize returns a ServerOption that sets window size for stream. 278 // The lower bound for window size is 64K and any value smaller than that will be ignored. 
279 func InitialWindowSize(s int32) ServerOption { 280 return newFuncServerOption(func(o *serverOptions) { 281 o.initialWindowSize = s 282 }) 283 } 284 285 // InitialConnWindowSize returns a ServerOption that sets window size for a connection. 286 // The lower bound for window size is 64K and any value smaller than that will be ignored. 287 func InitialConnWindowSize(s int32) ServerOption { 288 return newFuncServerOption(func(o *serverOptions) { 289 o.initialConnWindowSize = s 290 }) 291 } 292 293 // KeepaliveParams returns a ServerOption that sets keepalive and max-age parameters for the server. 294 func KeepaliveParams(kp keepalive.ServerParameters) ServerOption { 295 if kp.Time > 0 && kp.Time < internal.KeepaliveMinServerPingTime { 296 logger.Warning("Adjusting keepalive ping interval to minimum period of 1s") 297 kp.Time = internal.KeepaliveMinServerPingTime 298 } 299 300 return newFuncServerOption(func(o *serverOptions) { 301 o.keepaliveParams = kp 302 }) 303 } 304 305 // KeepaliveEnforcementPolicy returns a ServerOption that sets keepalive enforcement policy for the server. 306 func KeepaliveEnforcementPolicy(kep keepalive.EnforcementPolicy) ServerOption { 307 return newFuncServerOption(func(o *serverOptions) { 308 o.keepalivePolicy = kep 309 }) 310 } 311 312 // CustomCodec returns a ServerOption that sets a codec for message marshaling and unmarshaling. 313 // 314 // This will override any lookups by content-subtype for Codecs registered with RegisterCodec. 315 // 316 // Deprecated: register codecs using encoding.RegisterCodec. The server will 317 // automatically use registered codecs based on the incoming requests' headers. 318 // See also 319 // https://github.com/grpc/grpc-go/blob/master/Documentation/encoding.md#using-a-codec. 320 // Will be supported throughout 1.x. 
321 func CustomCodec(codec Codec) ServerOption { 322 return newFuncServerOption(func(o *serverOptions) { 323 o.codec = newCodecV0Bridge(codec) 324 }) 325 } 326 327 // ForceServerCodec returns a ServerOption that sets a codec for message 328 // marshaling and unmarshaling. 329 // 330 // This will override any lookups by content-subtype for Codecs registered 331 // with RegisterCodec. 332 // 333 // See Content-Type on 334 // https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for 335 // more details. Also see the documentation on RegisterCodec and 336 // CallContentSubtype for more details on the interaction between encoding.Codec 337 // and content-subtype. 338 // 339 // This function is provided for advanced users; prefer to register codecs 340 // using encoding.RegisterCodec. 341 // The server will automatically use registered codecs based on the incoming 342 // requests' headers. See also 343 // https://github.com/grpc/grpc-go/blob/master/Documentation/encoding.md#using-a-codec. 344 // Will be supported throughout 1.x. 345 // 346 // # Experimental 347 // 348 // Notice: This API is EXPERIMENTAL and may be changed or removed in a 349 // later release. 350 func ForceServerCodec(codec encoding.Codec) ServerOption { 351 return newFuncServerOption(func(o *serverOptions) { 352 o.codec = newCodecV1Bridge(codec) 353 }) 354 } 355 356 // ForceServerCodecV2 is the equivalent of ForceServerCodec, but for the new 357 // CodecV2 interface. 358 // 359 // Will be supported throughout 1.x. 360 // 361 // # Experimental 362 // 363 // Notice: This API is EXPERIMENTAL and may be changed or removed in a 364 // later release. 365 func ForceServerCodecV2(codecV2 encoding.CodecV2) ServerOption { 366 return newFuncServerOption(func(o *serverOptions) { 367 o.codec = codecV2 368 }) 369 } 370 371 // RPCCompressor returns a ServerOption that sets a compressor for outbound 372 // messages. 
For backward compatibility, all outbound messages will be sent 373 // using this compressor, regardless of incoming message compression. By 374 // default, server messages will be sent using the same compressor with which 375 // request messages were sent. 376 // 377 // Deprecated: use encoding.RegisterCompressor instead. Will be supported 378 // throughout 1.x. 379 func RPCCompressor(cp Compressor) ServerOption { 380 return newFuncServerOption(func(o *serverOptions) { 381 o.cp = cp 382 }) 383 } 384 385 // RPCDecompressor returns a ServerOption that sets a decompressor for inbound 386 // messages. It has higher priority than decompressors registered via 387 // encoding.RegisterCompressor. 388 // 389 // Deprecated: use encoding.RegisterCompressor instead. Will be supported 390 // throughout 1.x. 391 func RPCDecompressor(dc Decompressor) ServerOption { 392 return newFuncServerOption(func(o *serverOptions) { 393 o.dc = dc 394 }) 395 } 396 397 // MaxMsgSize returns a ServerOption to set the max message size in bytes the server can receive. 398 // If this is not set, gRPC uses the default limit. 399 // 400 // Deprecated: use MaxRecvMsgSize instead. Will be supported throughout 1.x. 401 func MaxMsgSize(m int) ServerOption { 402 return MaxRecvMsgSize(m) 403 } 404 405 // MaxRecvMsgSize returns a ServerOption to set the max message size in bytes the server can receive. 406 // If this is not set, gRPC uses the default 4MB. 407 func MaxRecvMsgSize(m int) ServerOption { 408 return newFuncServerOption(func(o *serverOptions) { 409 o.maxReceiveMessageSize = m 410 }) 411 } 412 413 // MaxSendMsgSize returns a ServerOption to set the max message size in bytes the server can send. 414 // If this is not set, gRPC uses the default `math.MaxInt32`. 
func MaxSendMsgSize(m int) ServerOption {
	set := func(o *serverOptions) { o.maxSendMessageSize = m }
	return newFuncServerOption(set)
}

// MaxConcurrentStreams returns a ServerOption that limits the number of
// concurrent streams on each ServerTransport. A value of zero is mapped to
// math.MaxUint32.
func MaxConcurrentStreams(n uint32) ServerOption {
	limit := n
	if limit == 0 {
		limit = math.MaxUint32
	}
	set := func(o *serverOptions) { o.maxConcurrentStreams = limit }
	return newFuncServerOption(set)
}

// Creds returns a ServerOption that sets credentials for server connections.
func Creds(c credentials.TransportCredentials) ServerOption {
	set := func(o *serverOptions) { o.creds = c }
	return newFuncServerOption(set)
}

// UnaryInterceptor returns a ServerOption that sets the UnaryServerInterceptor
// for the server. Only one unary interceptor can be installed: applying a
// second one panics. To combine several interceptors, use
// ChainUnaryInterceptor or build the chain at the caller.
func UnaryInterceptor(i UnaryServerInterceptor) ServerOption {
	set := func(o *serverOptions) {
		if o.unaryInt == nil {
			o.unaryInt = i
			return
		}
		panic("The unary server interceptor was already set and may not be reset.")
	}
	return newFuncServerOption(set)
}

// ChainUnaryInterceptor returns a ServerOption that appends interceptors to the
// chain used for unary RPCs. The first interceptor is the outermost wrapper and
// the last is the innermost, closest to the real call. Every unary interceptor
// added through this option is chained.
func ChainUnaryInterceptor(interceptors ...UnaryServerInterceptor) ServerOption {
	set := func(o *serverOptions) {
		o.chainUnaryInts = append(o.chainUnaryInts, interceptors...)
	}
	return newFuncServerOption(set)
}

// StreamInterceptor returns a ServerOption that sets the StreamServerInterceptor for the
// server. Only one stream interceptor can be installed.
463 func StreamInterceptor(i StreamServerInterceptor) ServerOption { 464 return newFuncServerOption(func(o *serverOptions) { 465 if o.streamInt != nil { 466 panic("The stream server interceptor was already set and may not be reset.") 467 } 468 o.streamInt = i 469 }) 470 } 471 472 // ChainStreamInterceptor returns a ServerOption that specifies the chained interceptor 473 // for streaming RPCs. The first interceptor will be the outer most, 474 // while the last interceptor will be the inner most wrapper around the real call. 475 // All stream interceptors added by this method will be chained. 476 func ChainStreamInterceptor(interceptors ...StreamServerInterceptor) ServerOption { 477 return newFuncServerOption(func(o *serverOptions) { 478 o.chainStreamInts = append(o.chainStreamInts, interceptors...) 479 }) 480 } 481 482 // InTapHandle returns a ServerOption that sets the tap handle for all the server 483 // transport to be created. Only one can be installed. 484 // 485 // # Experimental 486 // 487 // Notice: This API is EXPERIMENTAL and may be changed or removed in a 488 // later release. 489 func InTapHandle(h tap.ServerInHandle) ServerOption { 490 return newFuncServerOption(func(o *serverOptions) { 491 if o.inTapHandle != nil { 492 panic("The tap handle was already set and may not be reset.") 493 } 494 o.inTapHandle = h 495 }) 496 } 497 498 // StatsHandler returns a ServerOption that sets the stats handler for the server. 499 func StatsHandler(h stats.Handler) ServerOption { 500 return newFuncServerOption(func(o *serverOptions) { 501 if h == nil { 502 logger.Error("ignoring nil parameter in grpc.StatsHandler ServerOption") 503 // Do not allow a nil stats handler, which would otherwise cause 504 // panics. 505 return 506 } 507 o.statsHandlers = append(o.statsHandlers, h) 508 }) 509 } 510 511 // binaryLogger returns a ServerOption that can set the binary logger for the 512 // server. 
func binaryLogger(bl binarylog.Logger) ServerOption {
	set := func(o *serverOptions) { o.binaryLogger = bl }
	return newFuncServerOption(set)
}

// UnknownServiceHandler returns a ServerOption that installs a custom handler
// for unknown services. Whenever a request arrives for an unregistered service
// or method, the provided bidi-streaming handler is invoked instead of
// returning the "unimplemented" gRPC error. The handler and the stream
// interceptor (if set) have full access to the ServerStream, including its
// Context.
func UnknownServiceHandler(streamHandler StreamHandler) ServerOption {
	set := func(o *serverOptions) {
		// Each application gets its own descriptor. Both directions are marked
		// as streaming because the handler may need to receive and send.
		o.unknownStreamDesc = &StreamDesc{
			StreamName:    "unknown_service_handler",
			Handler:       streamHandler,
			ClientStreams: true,
			ServerStreams: true,
		}
	}
	return newFuncServerOption(set)
}

// ConnectionTimeout returns a ServerOption that sets the timeout for
// connection establishment (up to and including HTTP/2 handshaking) for all
// new connections. If this is not set, the default is 120 seconds. A zero or
// negative value will result in an immediate timeout.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func ConnectionTimeout(d time.Duration) ServerOption {
	set := func(o *serverOptions) { o.connectionTimeout = d }
	return newFuncServerOption(set)
}

// MaxHeaderListSizeServerOption is a ServerOption that sets the max
// (uncompressed) size of header list that the server is prepared to accept.
554 type MaxHeaderListSizeServerOption struct { 555 MaxHeaderListSize uint32 556 } 557 558 func (o MaxHeaderListSizeServerOption) apply(so *serverOptions) { 559 so.maxHeaderListSize = &o.MaxHeaderListSize 560 } 561 562 // MaxHeaderListSize returns a ServerOption that sets the max (uncompressed) size 563 // of header list that the server is prepared to accept. 564 func MaxHeaderListSize(s uint32) ServerOption { 565 return MaxHeaderListSizeServerOption{ 566 MaxHeaderListSize: s, 567 } 568 } 569 570 // HeaderTableSize returns a ServerOption that sets the size of dynamic 571 // header table for stream. 572 // 573 // # Experimental 574 // 575 // Notice: This API is EXPERIMENTAL and may be changed or removed in a 576 // later release. 577 func HeaderTableSize(s uint32) ServerOption { 578 return newFuncServerOption(func(o *serverOptions) { 579 o.headerTableSize = &s 580 }) 581 } 582 583 // NumStreamWorkers returns a ServerOption that sets the number of worker 584 // goroutines that should be used to process incoming streams. Setting this to 585 // zero (default) will disable workers and spawn a new goroutine for each 586 // stream. 587 // 588 // # Experimental 589 // 590 // Notice: This API is EXPERIMENTAL and may be changed or removed in a 591 // later release. 592 func NumStreamWorkers(numServerWorkers uint32) ServerOption { 593 // TODO: If/when this API gets stabilized (i.e. stream workers become the 594 // only way streams are processed), change the behavior of the zero value to 595 // a sane default. Preliminary experiments suggest that a value equal to the 596 // number of CPUs available is most performant; requires thorough testing. 597 return newFuncServerOption(func(o *serverOptions) { 598 o.numServerWorkers = numServerWorkers 599 }) 600 } 601 602 // WaitForHandlers cause Stop to wait until all outstanding method handlers have 603 // exited before returning. 
If false, Stop will return as soon as all 604 // connections have closed, but method handlers may still be running. By 605 // default, Stop does not wait for method handlers to return. 606 // 607 // # Experimental 608 // 609 // Notice: This API is EXPERIMENTAL and may be changed or removed in a 610 // later release. 611 func WaitForHandlers(w bool) ServerOption { 612 return newFuncServerOption(func(o *serverOptions) { 613 o.waitForHandlers = w 614 }) 615 } 616 617 func bufferPool(bufferPool mem.BufferPool) ServerOption { 618 return newFuncServerOption(func(o *serverOptions) { 619 o.bufferPool = bufferPool 620 }) 621 } 622 623 // serverWorkerResetThreshold defines how often the stack must be reset. Every 624 // N requests, by spawning a new goroutine in its place, a worker can reset its 625 // stack so that large stacks don't live in memory forever. 2^16 should allow 626 // each goroutine stack to live for at least a few seconds in a typical 627 // workload (assuming a QPS of a few thousand requests/sec). 628 const serverWorkerResetThreshold = 1 << 16 629 630 // serverWorker blocks on a *transport.ServerStream channel forever and waits 631 // for data to be fed by serveStreams. This allows multiple requests to be 632 // processed by the same goroutine, removing the need for expensive stack 633 // re-allocations (see the runtime.morestack problem [1]). 634 // 635 // [1] https://github.com/golang/go/issues/18138 636 func (s *Server) serverWorker() { 637 for completed := 0; completed < serverWorkerResetThreshold; completed++ { 638 f, ok := <-s.serverWorkerChannel 639 if !ok { 640 return 641 } 642 f() 643 } 644 go s.serverWorker() 645 } 646 647 // initServerWorkers creates worker goroutines and a channel to process incoming 648 // connections to reduce the time spent overall on runtime.morestack. 
func (s *Server) initServerWorkers() {
	s.serverWorkerChannel = make(chan func())
	// The channel may be closed from more than one shutdown path; OnceFunc
	// makes the close idempotent.
	s.serverWorkerChannelClose = sync.OnceFunc(func() {
		close(s.serverWorkerChannel)
	})
	for remaining := s.opts.numServerWorkers; remaining > 0; remaining-- {
		go s.serverWorker()
	}
}

// NewServer creates a gRPC server which has no service registered and has not
// started to accept requests yet. The global server options are applied
// first, then the given options, on top of defaultServerOptions.
func NewServer(opt ...ServerOption) *Server {
	cfg := defaultServerOptions
	for _, g := range globalServerOptions {
		g.apply(&cfg)
	}
	for _, o := range opt {
		o.apply(&cfg)
	}

	s := &Server{
		lis:      make(map[net.Listener]bool),
		opts:     cfg,
		conns:    make(map[string]map[transport.ServerTransport]bool),
		services: make(map[string]*serviceInfo),
		quit:     grpcsync.NewEvent(),
		done:     grpcsync.NewEvent(),
		channelz: channelz.RegisterServer(""),
	}
	chainUnaryServerInterceptors(s)
	chainStreamServerInterceptors(s)
	s.cv = sync.NewCond(&s.mu)

	if EnableTracing {
		// Caller(1) refers to NewServer's caller, so this must stay directly in
		// NewServer rather than move into a helper.
		_, file, line, _ := runtime.Caller(1)
		site := fmt.Sprintf("%s:%d", file, line)
		s.events = newTraceEventLog("grpc.Server", site)
	}

	if s.opts.numServerWorkers > 0 {
		s.initServerWorkers()
	}

	channelz.Info(logger, s.channelz, "Server created")
	return s
}

// printf records an event in s's event log, unless s has been stopped
// (s.events is nil).
// REQUIRES s.mu is held.
func (s *Server) printf(format string, a ...any) {
	if s.events == nil {
		return
	}
	s.events.Printf(format, a...)
}

// errorf records an error in s's event log, unless s has been stopped
// (s.events is nil).
// REQUIRES s.mu is held.
func (s *Server) errorf(format string, a ...any) {
	if s.events == nil {
		return
	}
	s.events.Errorf(format, a...)
}

// ServiceRegistrar wraps a single method that supports service registration.
// It lets users pass concrete types other than grpc.Server to the service
// registration methods exported by the IDL generated code.
type ServiceRegistrar interface {
	// RegisterService registers a service and its implementation to the
	// concrete type implementing this interface. It may not be called
	// once the server has started serving.
	// desc describes the service and its methods and handlers. impl is the
	// service implementation which is passed to the method handlers.
	RegisterService(desc *ServiceDesc, impl any)
}

// RegisterService registers a service and its implementation to the gRPC
// server. It is called from the IDL generated code and must be called before
// invoking Serve. If ss is non-nil (for legacy code), its dynamic type must
// implement the interface that sd.HandlerType points to; otherwise the
// mismatch is reported through logger.Fatalf.
func (s *Server) RegisterService(sd *ServiceDesc, ss any) {
	if ss != nil {
		want := reflect.TypeOf(sd.HandlerType).Elem()
		if got := reflect.TypeOf(ss); !got.Implements(want) {
			logger.Fatalf("grpc: Server.RegisterService found the handler of type %v that does not satisfy %v", got, want)
		}
	}
	s.register(sd, ss)
}

// register records sd under its service name while holding s.mu. It reports
// registration after Serve, or a duplicate service name, through
// logger.Fatalf. The method/stream maps point into sd's own slices rather
// than copies.
func (s *Server) register(sd *ServiceDesc, ss any) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.printf("RegisterService(%q)", sd.ServiceName)
	if s.serve {
		logger.Fatalf("grpc: Server.RegisterService after Server.Serve for %q", sd.ServiceName)
	}
	if _, dup := s.services[sd.ServiceName]; dup {
		logger.Fatalf("grpc: Server.RegisterService found duplicate service registration for %q", sd.ServiceName)
	}

	methods := make(map[string]*MethodDesc)
	for i := range sd.Methods {
		md := &sd.Methods[i]
		methods[md.MethodName] = md
	}
	streams := make(map[string]*StreamDesc)
	for i := range sd.Streams {
		sdesc := &sd.Streams[i]
		streams[sdesc.StreamName] = sdesc
	}

	s.services[sd.ServiceName] = &serviceInfo{
		serviceImpl: ss,
		methods:     methods,
		streams:     streams,
		mdata:       sd.Metadata,
	}
}

// MethodInfo contains the information of an RPC including its method name and type.
type MethodInfo struct {
	// Name is the method name only, without the service name or package name.
	Name string
	// IsClientStream indicates whether the RPC is a client streaming RPC.
	IsClientStream bool
	// IsServerStream indicates whether the RPC is a server streaming RPC.
	IsServerStream bool
}

// ServiceInfo contains unary RPC method info, streaming RPC method info and metadata for a service.
type ServiceInfo struct {
	Methods []MethodInfo
	// Metadata is the metadata specified in ServiceDesc when registering service.
	Metadata any
}

// GetServiceInfo returns a map from service names to ServiceInfo.
// Service names include the package names, in the form of <package>.<service>.
// Unary methods are reported with both streaming flags false. The order of
// Methods within each ServiceInfo is unspecified.
//
// NOTE: this reads s.services without holding s.mu, so calling it
// concurrently with RegisterService would be a data race.
func (s *Server) GetServiceInfo() map[string]ServiceInfo {
	ret := make(map[string]ServiceInfo)
	for name, info := range s.services {
		all := make([]MethodInfo, 0, len(info.methods)+len(info.streams))
		for m := range info.methods {
			all = append(all, MethodInfo{Name: m})
		}
		for m, d := range info.streams {
			all = append(all, MethodInfo{
				Name:           m,
				IsClientStream: d.ClientStreams,
				IsServerStream: d.ServerStreams,
			})
		}
		ret[name] = ServiceInfo{Methods: all, Metadata: info.mdata}
	}
	return ret
}

// ErrServerStopped indicates that the operation is now illegal because of
// the server being stopped.
812 var ErrServerStopped = errors.New("grpc: the server has been stopped") 813 814 type listenSocket struct { 815 net.Listener 816 channelz *channelz.Socket 817 } 818 819 func (l *listenSocket) Close() error { 820 err := l.Listener.Close() 821 channelz.RemoveEntry(l.channelz.ID) 822 channelz.Info(logger, l.channelz, "ListenSocket deleted") 823 return err 824 } 825 826 // Serve accepts incoming connections on the listener lis, creating a new 827 // ServerTransport and service goroutine for each. The service goroutines 828 // read gRPC requests and then call the registered handlers to reply to them. 829 // Serve returns when lis.Accept fails with fatal errors. lis will be closed when 830 // this method returns. 831 // Serve will return a non-nil error unless Stop or GracefulStop is called. 832 // 833 // Note: All supported releases of Go (as of December 2023) override the OS 834 // defaults for TCP keepalive time and interval to 15s. To enable TCP keepalive 835 // with OS defaults for keepalive time and interval, callers need to do the 836 // following two things: 837 // - pass a net.Listener created by calling the Listen method on a 838 // net.ListenConfig with the `KeepAlive` field set to a negative value. This 839 // will result in the Go standard library not overriding OS defaults for TCP 840 // keepalive interval and time. But this will also result in the Go standard 841 // library not enabling TCP keepalives by default. 842 // - override the Accept method on the passed in net.Listener and set the 843 // SO_KEEPALIVE socket option to enable TCP keepalives, with OS defaults. 844 func (s *Server) Serve(lis net.Listener) error { 845 s.mu.Lock() 846 s.printf("serving") 847 s.serve = true 848 if s.lis == nil { 849 // Serve called after Stop or GracefulStop. 
850 s.mu.Unlock() 851 lis.Close() 852 return ErrServerStopped 853 } 854 855 s.serveWG.Add(1) 856 defer func() { 857 s.serveWG.Done() 858 if s.quit.HasFired() { 859 // Stop or GracefulStop called; block until done and return nil. 860 <-s.done.Done() 861 } 862 }() 863 864 ls := &listenSocket{ 865 Listener: lis, 866 channelz: channelz.RegisterSocket(&channelz.Socket{ 867 SocketType: channelz.SocketTypeListen, 868 Parent: s.channelz, 869 RefName: lis.Addr().String(), 870 LocalAddr: lis.Addr(), 871 SocketOptions: channelz.GetSocketOption(lis)}, 872 ), 873 } 874 s.lis[ls] = true 875 876 defer func() { 877 s.mu.Lock() 878 if s.lis != nil && s.lis[ls] { 879 ls.Close() 880 delete(s.lis, ls) 881 } 882 s.mu.Unlock() 883 }() 884 885 s.mu.Unlock() 886 channelz.Info(logger, ls.channelz, "ListenSocket created") 887 888 var tempDelay time.Duration // how long to sleep on accept failure 889 for { 890 rawConn, err := lis.Accept() 891 if err != nil { 892 if ne, ok := err.(interface { 893 Temporary() bool 894 }); ok && ne.Temporary() { 895 if tempDelay == 0 { 896 tempDelay = 5 * time.Millisecond 897 } else { 898 tempDelay *= 2 899 } 900 if max := 1 * time.Second; tempDelay > max { 901 tempDelay = max 902 } 903 s.mu.Lock() 904 s.printf("Accept error: %v; retrying in %v", err, tempDelay) 905 s.mu.Unlock() 906 timer := time.NewTimer(tempDelay) 907 select { 908 case <-timer.C: 909 case <-s.quit.Done(): 910 timer.Stop() 911 return nil 912 } 913 continue 914 } 915 s.mu.Lock() 916 s.printf("done serving; Accept = %v", err) 917 s.mu.Unlock() 918 919 if s.quit.HasFired() { 920 return nil 921 } 922 return err 923 } 924 tempDelay = 0 925 // Start a new goroutine to deal with rawConn so we don't stall this Accept 926 // loop goroutine. 927 // 928 // Make sure we account for the goroutine so GracefulStop doesn't nil out 929 // s.conns before this conn can be added. 
930 s.serveWG.Add(1) 931 go func() { 932 s.handleRawConn(lis.Addr().String(), rawConn) 933 s.serveWG.Done() 934 }() 935 } 936 } 937 938 // handleRawConn forks a goroutine to handle a just-accepted connection that 939 // has not had any I/O performed on it yet. 940 func (s *Server) handleRawConn(lisAddr string, rawConn net.Conn) { 941 if s.quit.HasFired() { 942 rawConn.Close() 943 return 944 } 945 rawConn.SetDeadline(time.Now().Add(s.opts.connectionTimeout)) 946 947 // Finish handshaking (HTTP2) 948 st := s.newHTTP2Transport(rawConn) 949 rawConn.SetDeadline(time.Time{}) 950 if st == nil { 951 return 952 } 953 954 if cc, ok := rawConn.(interface { 955 PassServerTransport(transport.ServerTransport) 956 }); ok { 957 cc.PassServerTransport(st) 958 } 959 960 if !s.addConn(lisAddr, st) { 961 return 962 } 963 go func() { 964 s.serveStreams(context.Background(), st, rawConn) 965 s.removeConn(lisAddr, st) 966 }() 967 } 968 969 // newHTTP2Transport sets up a http/2 transport (using the 970 // gRPC http2 server transport in transport/http2_server.go). 
971 func (s *Server) newHTTP2Transport(c net.Conn) transport.ServerTransport { 972 config := &transport.ServerConfig{ 973 MaxStreams: s.opts.maxConcurrentStreams, 974 ConnectionTimeout: s.opts.connectionTimeout, 975 Credentials: s.opts.creds, 976 InTapHandle: s.opts.inTapHandle, 977 StatsHandlers: s.opts.statsHandlers, 978 KeepaliveParams: s.opts.keepaliveParams, 979 KeepalivePolicy: s.opts.keepalivePolicy, 980 InitialWindowSize: s.opts.initialWindowSize, 981 InitialConnWindowSize: s.opts.initialConnWindowSize, 982 WriteBufferSize: s.opts.writeBufferSize, 983 ReadBufferSize: s.opts.readBufferSize, 984 SharedWriteBuffer: s.opts.sharedWriteBuffer, 985 ChannelzParent: s.channelz, 986 MaxHeaderListSize: s.opts.maxHeaderListSize, 987 HeaderTableSize: s.opts.headerTableSize, 988 BufferPool: s.opts.bufferPool, 989 } 990 st, err := transport.NewServerTransport(c, config) 991 if err != nil { 992 s.mu.Lock() 993 s.errorf("NewServerTransport(%q) failed: %v", c.RemoteAddr(), err) 994 s.mu.Unlock() 995 // ErrConnDispatched means that the connection was dispatched away from 996 // gRPC; those connections should be left open. 997 if err != credentials.ErrConnDispatched { 998 // Don't log on ErrConnDispatched and io.EOF to prevent log spam. 
999 if err != io.EOF { 1000 channelz.Info(logger, s.channelz, "grpc: Server.Serve failed to create ServerTransport: ", err) 1001 } 1002 c.Close() 1003 } 1004 return nil 1005 } 1006 1007 return st 1008 } 1009 1010 func (s *Server) serveStreams(ctx context.Context, st transport.ServerTransport, rawConn net.Conn) { 1011 ctx = transport.SetConnection(ctx, rawConn) 1012 ctx = peer.NewContext(ctx, st.Peer()) 1013 for _, sh := range s.opts.statsHandlers { 1014 ctx = sh.TagConn(ctx, &stats.ConnTagInfo{ 1015 RemoteAddr: st.Peer().Addr, 1016 LocalAddr: st.Peer().LocalAddr, 1017 }) 1018 sh.HandleConn(ctx, &stats.ConnBegin{}) 1019 } 1020 1021 defer func() { 1022 st.Close(errors.New("finished serving streams for the server transport")) 1023 for _, sh := range s.opts.statsHandlers { 1024 sh.HandleConn(ctx, &stats.ConnEnd{}) 1025 } 1026 }() 1027 1028 streamQuota := newHandlerQuota(s.opts.maxConcurrentStreams) 1029 st.HandleStreams(ctx, func(stream *transport.ServerStream) { 1030 s.handlersWG.Add(1) 1031 streamQuota.acquire() 1032 f := func() { 1033 defer streamQuota.release() 1034 defer s.handlersWG.Done() 1035 s.handleStream(st, stream) 1036 } 1037 1038 if s.opts.numServerWorkers > 0 { 1039 select { 1040 case s.serverWorkerChannel <- f: 1041 return 1042 default: 1043 // If all stream workers are busy, fallback to the default code path. 1044 } 1045 } 1046 go f() 1047 }) 1048 } 1049 1050 var _ http.Handler = (*Server)(nil) 1051 1052 // ServeHTTP implements the Go standard library's http.Handler 1053 // interface by responding to the gRPC request r, by looking up 1054 // the requested gRPC method in the gRPC server s. 1055 // 1056 // The provided HTTP request must have arrived on an HTTP/2 1057 // connection. When using the Go standard library's server, 1058 // practically this means that the Request must also have arrived 1059 // over TLS. 
1060 // 1061 // To share one port (such as 443 for https) between gRPC and an 1062 // existing http.Handler, use a root http.Handler such as: 1063 // 1064 // if r.ProtoMajor == 2 && strings.HasPrefix( 1065 // r.Header.Get("Content-Type"), "application/grpc") { 1066 // grpcServer.ServeHTTP(w, r) 1067 // } else { 1068 // yourMux.ServeHTTP(w, r) 1069 // } 1070 // 1071 // Note that ServeHTTP uses Go's HTTP/2 server implementation which is totally 1072 // separate from grpc-go's HTTP/2 server. Performance and features may vary 1073 // between the two paths. ServeHTTP does not support some gRPC features 1074 // available through grpc-go's HTTP/2 server. 1075 // 1076 // # Experimental 1077 // 1078 // Notice: This API is EXPERIMENTAL and may be changed or removed in a 1079 // later release. 1080 func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { 1081 st, err := transport.NewServerHandlerTransport(w, r, s.opts.statsHandlers, s.opts.bufferPool) 1082 if err != nil { 1083 // Errors returned from transport.NewServerHandlerTransport have 1084 // already been written to w. 1085 return 1086 } 1087 if !s.addConn(listenerAddressForServeHTTP, st) { 1088 return 1089 } 1090 defer s.removeConn(listenerAddressForServeHTTP, st) 1091 s.serveStreams(r.Context(), st, nil) 1092 } 1093 1094 func (s *Server) addConn(addr string, st transport.ServerTransport) bool { 1095 s.mu.Lock() 1096 defer s.mu.Unlock() 1097 if s.conns == nil { 1098 st.Close(errors.New("Server.addConn called when server has already been stopped")) 1099 return false 1100 } 1101 if s.drain { 1102 // Transport added after we drained our existing conns: drain it 1103 // immediately. 1104 st.Drain("") 1105 } 1106 1107 if s.conns[addr] == nil { 1108 // Create a map entry if this is the first connection on this listener. 
1109 s.conns[addr] = make(map[transport.ServerTransport]bool) 1110 } 1111 s.conns[addr][st] = true 1112 return true 1113 } 1114 1115 func (s *Server) removeConn(addr string, st transport.ServerTransport) { 1116 s.mu.Lock() 1117 defer s.mu.Unlock() 1118 1119 conns := s.conns[addr] 1120 if conns != nil { 1121 delete(conns, st) 1122 if len(conns) == 0 { 1123 // If the last connection for this address is being removed, also 1124 // remove the map entry corresponding to the address. This is used 1125 // in GracefulStop() when waiting for all connections to be closed. 1126 delete(s.conns, addr) 1127 } 1128 s.cv.Broadcast() 1129 } 1130 } 1131 1132 func (s *Server) incrCallsStarted() { 1133 s.channelz.ServerMetrics.CallsStarted.Add(1) 1134 s.channelz.ServerMetrics.LastCallStartedTimestamp.Store(time.Now().UnixNano()) 1135 } 1136 1137 func (s *Server) incrCallsSucceeded() { 1138 s.channelz.ServerMetrics.CallsSucceeded.Add(1) 1139 } 1140 1141 func (s *Server) incrCallsFailed() { 1142 s.channelz.ServerMetrics.CallsFailed.Add(1) 1143 } 1144 1145 func (s *Server) sendResponse(ctx context.Context, stream *transport.ServerStream, msg any, cp Compressor, opts *transport.WriteOptions, comp encoding.Compressor) error { 1146 data, err := encode(s.getCodec(stream.ContentSubtype()), msg) 1147 if err != nil { 1148 channelz.Error(logger, s.channelz, "grpc: server failed to encode response: ", err) 1149 return err 1150 } 1151 1152 compData, pf, err := compress(data, cp, comp, s.opts.bufferPool) 1153 if err != nil { 1154 data.Free() 1155 channelz.Error(logger, s.channelz, "grpc: server failed to compress response: ", err) 1156 return err 1157 } 1158 1159 hdr, payload := msgHeader(data, compData, pf) 1160 1161 defer func() { 1162 compData.Free() 1163 data.Free() 1164 // payload does not need to be freed here, it is either data or compData, both of 1165 // which are already freed. 
1166 }() 1167 1168 dataLen := data.Len() 1169 payloadLen := payload.Len() 1170 // TODO(dfawley): should we be checking len(data) instead? 1171 if payloadLen > s.opts.maxSendMessageSize { 1172 return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", payloadLen, s.opts.maxSendMessageSize) 1173 } 1174 err = stream.Write(hdr, payload, opts) 1175 if err == nil { 1176 if len(s.opts.statsHandlers) != 0 { 1177 for _, sh := range s.opts.statsHandlers { 1178 sh.HandleRPC(ctx, outPayload(false, msg, dataLen, payloadLen, time.Now())) 1179 } 1180 } 1181 } 1182 return err 1183 } 1184 1185 // chainUnaryServerInterceptors chains all unary server interceptors into one. 1186 func chainUnaryServerInterceptors(s *Server) { 1187 // Prepend opts.unaryInt to the chaining interceptors if it exists, since unaryInt will 1188 // be executed before any other chained interceptors. 1189 interceptors := s.opts.chainUnaryInts 1190 if s.opts.unaryInt != nil { 1191 interceptors = append([]UnaryServerInterceptor{s.opts.unaryInt}, s.opts.chainUnaryInts...) 
1192 } 1193 1194 var chainedInt UnaryServerInterceptor 1195 if len(interceptors) == 0 { 1196 chainedInt = nil 1197 } else if len(interceptors) == 1 { 1198 chainedInt = interceptors[0] 1199 } else { 1200 chainedInt = chainUnaryInterceptors(interceptors) 1201 } 1202 1203 s.opts.unaryInt = chainedInt 1204 } 1205 1206 func chainUnaryInterceptors(interceptors []UnaryServerInterceptor) UnaryServerInterceptor { 1207 return func(ctx context.Context, req any, info *UnaryServerInfo, handler UnaryHandler) (any, error) { 1208 return interceptors[0](ctx, req, info, getChainUnaryHandler(interceptors, 0, info, handler)) 1209 } 1210 } 1211 1212 func getChainUnaryHandler(interceptors []UnaryServerInterceptor, curr int, info *UnaryServerInfo, finalHandler UnaryHandler) UnaryHandler { 1213 if curr == len(interceptors)-1 { 1214 return finalHandler 1215 } 1216 return func(ctx context.Context, req any) (any, error) { 1217 return interceptors[curr+1](ctx, req, info, getChainUnaryHandler(interceptors, curr+1, info, finalHandler)) 1218 } 1219 } 1220 1221 func (s *Server) processUnaryRPC(ctx context.Context, stream *transport.ServerStream, info *serviceInfo, md *MethodDesc, trInfo *traceInfo) (err error) { 1222 shs := s.opts.statsHandlers 1223 if len(shs) != 0 || trInfo != nil || channelz.IsOn() { 1224 if channelz.IsOn() { 1225 s.incrCallsStarted() 1226 } 1227 var statsBegin *stats.Begin 1228 for _, sh := range shs { 1229 beginTime := time.Now() 1230 statsBegin = &stats.Begin{ 1231 BeginTime: beginTime, 1232 IsClientStream: false, 1233 IsServerStream: false, 1234 } 1235 sh.HandleRPC(ctx, statsBegin) 1236 } 1237 if trInfo != nil { 1238 trInfo.tr.LazyLog(&trInfo.firstLine, false) 1239 } 1240 // The deferred error handling for tracing, stats handler and channelz are 1241 // combined into one function to reduce stack usage -- a defer takes ~56-64 1242 // bytes on the stack, so overflowing the stack will require a stack 1243 // re-allocation, which is expensive. 
1244 // 1245 // To maintain behavior similar to separate deferred statements, statements 1246 // should be executed in the reverse order. That is, tracing first, stats 1247 // handler second, and channelz last. Note that panics *within* defers will 1248 // lead to different behavior, but that's an acceptable compromise; that 1249 // would be undefined behavior territory anyway. 1250 defer func() { 1251 if trInfo != nil { 1252 if err != nil && err != io.EOF { 1253 trInfo.tr.LazyLog(&fmtStringer{"%v", []any{err}}, true) 1254 trInfo.tr.SetError() 1255 } 1256 trInfo.tr.Finish() 1257 } 1258 1259 for _, sh := range shs { 1260 end := &stats.End{ 1261 BeginTime: statsBegin.BeginTime, 1262 EndTime: time.Now(), 1263 } 1264 if err != nil && err != io.EOF { 1265 end.Error = toRPCErr(err) 1266 } 1267 sh.HandleRPC(ctx, end) 1268 } 1269 1270 if channelz.IsOn() { 1271 if err != nil && err != io.EOF { 1272 s.incrCallsFailed() 1273 } else { 1274 s.incrCallsSucceeded() 1275 } 1276 } 1277 }() 1278 } 1279 var binlogs []binarylog.MethodLogger 1280 if ml := binarylog.GetMethodLogger(stream.Method()); ml != nil { 1281 binlogs = append(binlogs, ml) 1282 } 1283 if s.opts.binaryLogger != nil { 1284 if ml := s.opts.binaryLogger.GetMethodLogger(stream.Method()); ml != nil { 1285 binlogs = append(binlogs, ml) 1286 } 1287 } 1288 if len(binlogs) != 0 { 1289 md, _ := metadata.FromIncomingContext(ctx) 1290 logEntry := &binarylog.ClientHeader{ 1291 Header: md, 1292 MethodName: stream.Method(), 1293 PeerAddr: nil, 1294 } 1295 if deadline, ok := ctx.Deadline(); ok { 1296 logEntry.Timeout = time.Until(deadline) 1297 if logEntry.Timeout < 0 { 1298 logEntry.Timeout = 0 1299 } 1300 } 1301 if a := md[":authority"]; len(a) > 0 { 1302 logEntry.Authority = a[0] 1303 } 1304 if peer, ok := peer.FromContext(ctx); ok { 1305 logEntry.PeerAddr = peer.Addr 1306 } 1307 for _, binlog := range binlogs { 1308 binlog.Log(ctx, logEntry) 1309 } 1310 } 1311 1312 // comp and cp are used for compression. 
decomp and dc are used for 1313 // decompression. If comp and decomp are both set, they are the same; 1314 // however they are kept separate to ensure that at most one of the 1315 // compressor/decompressor variable pairs are set for use later. 1316 var comp, decomp encoding.Compressor 1317 var cp Compressor 1318 var dc Decompressor 1319 var sendCompressorName string 1320 1321 // If dc is set and matches the stream's compression, use it. Otherwise, try 1322 // to find a matching registered compressor for decomp. 1323 if rc := stream.RecvCompress(); s.opts.dc != nil && s.opts.dc.Type() == rc { 1324 dc = s.opts.dc 1325 } else if rc != "" && rc != encoding.Identity { 1326 decomp = encoding.GetCompressor(rc) 1327 if decomp == nil { 1328 st := status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", rc) 1329 stream.WriteStatus(st) 1330 return st.Err() 1331 } 1332 } 1333 1334 // If cp is set, use it. Otherwise, attempt to compress the response using 1335 // the incoming message compression method. 1336 // 1337 // NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686. 1338 if s.opts.cp != nil { 1339 cp = s.opts.cp 1340 sendCompressorName = cp.Type() 1341 } else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity { 1342 // Legacy compressor not specified; attempt to respond with same encoding. 
1343 comp = encoding.GetCompressor(rc) 1344 if comp != nil { 1345 sendCompressorName = comp.Name() 1346 } 1347 } 1348 1349 if sendCompressorName != "" { 1350 if err := stream.SetSendCompress(sendCompressorName); err != nil { 1351 return status.Errorf(codes.Internal, "grpc: failed to set send compressor: %v", err) 1352 } 1353 } 1354 1355 var payInfo *payloadInfo 1356 if len(shs) != 0 || len(binlogs) != 0 { 1357 payInfo = &payloadInfo{} 1358 defer payInfo.free() 1359 } 1360 1361 d, err := recvAndDecompress(&parser{r: stream, bufferPool: s.opts.bufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp, true) 1362 if err != nil { 1363 if e := stream.WriteStatus(status.Convert(err)); e != nil { 1364 channelz.Warningf(logger, s.channelz, "grpc: Server.processUnaryRPC failed to write status: %v", e) 1365 } 1366 return err 1367 } 1368 freed := false 1369 dataFree := func() { 1370 if !freed { 1371 d.Free() 1372 freed = true 1373 } 1374 } 1375 defer dataFree() 1376 df := func(v any) error { 1377 defer dataFree() 1378 if err := s.getCodec(stream.ContentSubtype()).Unmarshal(d, v); err != nil { 1379 return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err) 1380 } 1381 1382 for _, sh := range shs { 1383 sh.HandleRPC(ctx, &stats.InPayload{ 1384 RecvTime: time.Now(), 1385 Payload: v, 1386 Length: d.Len(), 1387 WireLength: payInfo.compressedLength + headerLen, 1388 CompressedLength: payInfo.compressedLength, 1389 }) 1390 } 1391 if len(binlogs) != 0 { 1392 cm := &binarylog.ClientMessage{ 1393 Message: d.Materialize(), 1394 } 1395 for _, binlog := range binlogs { 1396 binlog.Log(ctx, cm) 1397 } 1398 } 1399 if trInfo != nil { 1400 trInfo.tr.LazyLog(&payload{sent: false, msg: v}, true) 1401 } 1402 return nil 1403 } 1404 ctx = NewContextWithServerTransportStream(ctx, stream) 1405 reply, appErr := md.Handler(info.serviceImpl, ctx, df, s.opts.unaryInt) 1406 if appErr != nil { 1407 appStatus, ok := status.FromError(appErr) 1408 if !ok { 1409 // Convert 
non-status application error to a status error with code 1410 // Unknown, but handle context errors specifically. 1411 appStatus = status.FromContextError(appErr) 1412 appErr = appStatus.Err() 1413 } 1414 if trInfo != nil { 1415 trInfo.tr.LazyLog(stringer(appStatus.Message()), true) 1416 trInfo.tr.SetError() 1417 } 1418 if e := stream.WriteStatus(appStatus); e != nil { 1419 channelz.Warningf(logger, s.channelz, "grpc: Server.processUnaryRPC failed to write status: %v", e) 1420 } 1421 if len(binlogs) != 0 { 1422 if h, _ := stream.Header(); h.Len() > 0 { 1423 // Only log serverHeader if there was header. Otherwise it can 1424 // be trailer only. 1425 sh := &binarylog.ServerHeader{ 1426 Header: h, 1427 } 1428 for _, binlog := range binlogs { 1429 binlog.Log(ctx, sh) 1430 } 1431 } 1432 st := &binarylog.ServerTrailer{ 1433 Trailer: stream.Trailer(), 1434 Err: appErr, 1435 } 1436 for _, binlog := range binlogs { 1437 binlog.Log(ctx, st) 1438 } 1439 } 1440 return appErr 1441 } 1442 if trInfo != nil { 1443 trInfo.tr.LazyLog(stringer("OK"), false) 1444 } 1445 opts := &transport.WriteOptions{Last: true} 1446 1447 // Server handler could have set new compressor by calling SetSendCompressor. 1448 // In case it is set, we need to use it for compressing outbound message. 1449 if stream.SendCompress() != sendCompressorName { 1450 comp = encoding.GetCompressor(stream.SendCompress()) 1451 } 1452 if err := s.sendResponse(ctx, stream, reply, cp, opts, comp); err != nil { 1453 if err == io.EOF { 1454 // The entire stream is done (for unary RPC only). 1455 return err 1456 } 1457 if sts, ok := status.FromError(err); ok { 1458 if e := stream.WriteStatus(sts); e != nil { 1459 channelz.Warningf(logger, s.channelz, "grpc: Server.processUnaryRPC failed to write status: %v", e) 1460 } 1461 } else { 1462 switch st := err.(type) { 1463 case transport.ConnectionError: 1464 // Nothing to do here. 
1465 default: 1466 panic(fmt.Sprintf("grpc: Unexpected error (%T) from sendResponse: %v", st, st)) 1467 } 1468 } 1469 if len(binlogs) != 0 { 1470 h, _ := stream.Header() 1471 sh := &binarylog.ServerHeader{ 1472 Header: h, 1473 } 1474 st := &binarylog.ServerTrailer{ 1475 Trailer: stream.Trailer(), 1476 Err: appErr, 1477 } 1478 for _, binlog := range binlogs { 1479 binlog.Log(ctx, sh) 1480 binlog.Log(ctx, st) 1481 } 1482 } 1483 return err 1484 } 1485 if len(binlogs) != 0 { 1486 h, _ := stream.Header() 1487 sh := &binarylog.ServerHeader{ 1488 Header: h, 1489 } 1490 sm := &binarylog.ServerMessage{ 1491 Message: reply, 1492 } 1493 for _, binlog := range binlogs { 1494 binlog.Log(ctx, sh) 1495 binlog.Log(ctx, sm) 1496 } 1497 } 1498 if trInfo != nil { 1499 trInfo.tr.LazyLog(&payload{sent: true, msg: reply}, true) 1500 } 1501 // TODO: Should we be logging if writing status failed here, like above? 1502 // Should the logging be in WriteStatus? Should we ignore the WriteStatus 1503 // error or allow the stats handler to see it? 1504 if len(binlogs) != 0 { 1505 st := &binarylog.ServerTrailer{ 1506 Trailer: stream.Trailer(), 1507 Err: appErr, 1508 } 1509 for _, binlog := range binlogs { 1510 binlog.Log(ctx, st) 1511 } 1512 } 1513 return stream.WriteStatus(statusOK) 1514 } 1515 1516 // chainStreamServerInterceptors chains all stream server interceptors into one. 1517 func chainStreamServerInterceptors(s *Server) { 1518 // Prepend opts.streamInt to the chaining interceptors if it exists, since streamInt will 1519 // be executed before any other chained interceptors. 1520 interceptors := s.opts.chainStreamInts 1521 if s.opts.streamInt != nil { 1522 interceptors = append([]StreamServerInterceptor{s.opts.streamInt}, s.opts.chainStreamInts...) 
1523 } 1524 1525 var chainedInt StreamServerInterceptor 1526 if len(interceptors) == 0 { 1527 chainedInt = nil 1528 } else if len(interceptors) == 1 { 1529 chainedInt = interceptors[0] 1530 } else { 1531 chainedInt = chainStreamInterceptors(interceptors) 1532 } 1533 1534 s.opts.streamInt = chainedInt 1535 } 1536 1537 func chainStreamInterceptors(interceptors []StreamServerInterceptor) StreamServerInterceptor { 1538 return func(srv any, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error { 1539 return interceptors[0](srv, ss, info, getChainStreamHandler(interceptors, 0, info, handler)) 1540 } 1541 } 1542 1543 func getChainStreamHandler(interceptors []StreamServerInterceptor, curr int, info *StreamServerInfo, finalHandler StreamHandler) StreamHandler { 1544 if curr == len(interceptors)-1 { 1545 return finalHandler 1546 } 1547 return func(srv any, stream ServerStream) error { 1548 return interceptors[curr+1](srv, stream, info, getChainStreamHandler(interceptors, curr+1, info, finalHandler)) 1549 } 1550 } 1551 1552 func (s *Server) processStreamingRPC(ctx context.Context, stream *transport.ServerStream, info *serviceInfo, sd *StreamDesc, trInfo *traceInfo) (err error) { 1553 if channelz.IsOn() { 1554 s.incrCallsStarted() 1555 } 1556 shs := s.opts.statsHandlers 1557 var statsBegin *stats.Begin 1558 if len(shs) != 0 { 1559 beginTime := time.Now() 1560 statsBegin = &stats.Begin{ 1561 BeginTime: beginTime, 1562 IsClientStream: sd.ClientStreams, 1563 IsServerStream: sd.ServerStreams, 1564 } 1565 for _, sh := range shs { 1566 sh.HandleRPC(ctx, statsBegin) 1567 } 1568 } 1569 ctx = NewContextWithServerTransportStream(ctx, stream) 1570 ss := &serverStream{ 1571 ctx: ctx, 1572 s: stream, 1573 p: &parser{r: stream, bufferPool: s.opts.bufferPool}, 1574 codec: s.getCodec(stream.ContentSubtype()), 1575 maxReceiveMessageSize: s.opts.maxReceiveMessageSize, 1576 maxSendMessageSize: s.opts.maxSendMessageSize, 1577 trInfo: trInfo, 1578 statsHandler: shs, 1579 } 1580 
1581 if len(shs) != 0 || trInfo != nil || channelz.IsOn() { 1582 // See comment in processUnaryRPC on defers. 1583 defer func() { 1584 if trInfo != nil { 1585 ss.mu.Lock() 1586 if err != nil && err != io.EOF { 1587 ss.trInfo.tr.LazyLog(&fmtStringer{"%v", []any{err}}, true) 1588 ss.trInfo.tr.SetError() 1589 } 1590 ss.trInfo.tr.Finish() 1591 ss.trInfo.tr = nil 1592 ss.mu.Unlock() 1593 } 1594 1595 if len(shs) != 0 { 1596 end := &stats.End{ 1597 BeginTime: statsBegin.BeginTime, 1598 EndTime: time.Now(), 1599 } 1600 if err != nil && err != io.EOF { 1601 end.Error = toRPCErr(err) 1602 } 1603 for _, sh := range shs { 1604 sh.HandleRPC(ctx, end) 1605 } 1606 } 1607 1608 if channelz.IsOn() { 1609 if err != nil && err != io.EOF { 1610 s.incrCallsFailed() 1611 } else { 1612 s.incrCallsSucceeded() 1613 } 1614 } 1615 }() 1616 } 1617 1618 if ml := binarylog.GetMethodLogger(stream.Method()); ml != nil { 1619 ss.binlogs = append(ss.binlogs, ml) 1620 } 1621 if s.opts.binaryLogger != nil { 1622 if ml := s.opts.binaryLogger.GetMethodLogger(stream.Method()); ml != nil { 1623 ss.binlogs = append(ss.binlogs, ml) 1624 } 1625 } 1626 if len(ss.binlogs) != 0 { 1627 md, _ := metadata.FromIncomingContext(ctx) 1628 logEntry := &binarylog.ClientHeader{ 1629 Header: md, 1630 MethodName: stream.Method(), 1631 PeerAddr: nil, 1632 } 1633 if deadline, ok := ctx.Deadline(); ok { 1634 logEntry.Timeout = time.Until(deadline) 1635 if logEntry.Timeout < 0 { 1636 logEntry.Timeout = 0 1637 } 1638 } 1639 if a := md[":authority"]; len(a) > 0 { 1640 logEntry.Authority = a[0] 1641 } 1642 if peer, ok := peer.FromContext(ss.Context()); ok { 1643 logEntry.PeerAddr = peer.Addr 1644 } 1645 for _, binlog := range ss.binlogs { 1646 binlog.Log(ctx, logEntry) 1647 } 1648 } 1649 1650 // If dc is set and matches the stream's compression, use it. Otherwise, try 1651 // to find a matching registered compressor for decomp. 
1652 if rc := stream.RecvCompress(); s.opts.dc != nil && s.opts.dc.Type() == rc { 1653 ss.decompressorV0 = s.opts.dc 1654 } else if rc != "" && rc != encoding.Identity { 1655 ss.decompressorV1 = encoding.GetCompressor(rc) 1656 if ss.decompressorV1 == nil { 1657 st := status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", rc) 1658 ss.s.WriteStatus(st) 1659 return st.Err() 1660 } 1661 } 1662 1663 // If cp is set, use it. Otherwise, attempt to compress the response using 1664 // the incoming message compression method. 1665 // 1666 // NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686. 1667 if s.opts.cp != nil { 1668 ss.compressorV0 = s.opts.cp 1669 ss.sendCompressorName = s.opts.cp.Type() 1670 } else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity { 1671 // Legacy compressor not specified; attempt to respond with same encoding. 1672 ss.compressorV1 = encoding.GetCompressor(rc) 1673 if ss.compressorV1 != nil { 1674 ss.sendCompressorName = rc 1675 } 1676 } 1677 1678 if ss.sendCompressorName != "" { 1679 if err := stream.SetSendCompress(ss.sendCompressorName); err != nil { 1680 return status.Errorf(codes.Internal, "grpc: failed to set send compressor: %v", err) 1681 } 1682 } 1683 1684 ss.ctx = newContextWithRPCInfo(ss.ctx, false, ss.codec, ss.compressorV0, ss.compressorV1) 1685 1686 if trInfo != nil { 1687 trInfo.tr.LazyLog(&trInfo.firstLine, false) 1688 } 1689 var appErr error 1690 var server any 1691 if info != nil { 1692 server = info.serviceImpl 1693 } 1694 if s.opts.streamInt == nil { 1695 appErr = sd.Handler(server, ss) 1696 } else { 1697 info := &StreamServerInfo{ 1698 FullMethod: stream.Method(), 1699 IsClientStream: sd.ClientStreams, 1700 IsServerStream: sd.ServerStreams, 1701 } 1702 appErr = s.opts.streamInt(server, ss, info, sd.Handler) 1703 } 1704 if appErr != nil { 1705 appStatus, ok := status.FromError(appErr) 1706 if !ok { 1707 // Convert non-status application 
error to a status error with code 1708 // Unknown, but handle context errors specifically. 1709 appStatus = status.FromContextError(appErr) 1710 appErr = appStatus.Err() 1711 } 1712 if trInfo != nil { 1713 ss.mu.Lock() 1714 ss.trInfo.tr.LazyLog(stringer(appStatus.Message()), true) 1715 ss.trInfo.tr.SetError() 1716 ss.mu.Unlock() 1717 } 1718 if len(ss.binlogs) != 0 { 1719 st := &binarylog.ServerTrailer{ 1720 Trailer: ss.s.Trailer(), 1721 Err: appErr, 1722 } 1723 for _, binlog := range ss.binlogs { 1724 binlog.Log(ctx, st) 1725 } 1726 } 1727 ss.s.WriteStatus(appStatus) 1728 // TODO: Should we log an error from WriteStatus here and below? 1729 return appErr 1730 } 1731 if trInfo != nil { 1732 ss.mu.Lock() 1733 ss.trInfo.tr.LazyLog(stringer("OK"), false) 1734 ss.mu.Unlock() 1735 } 1736 if len(ss.binlogs) != 0 { 1737 st := &binarylog.ServerTrailer{ 1738 Trailer: ss.s.Trailer(), 1739 Err: appErr, 1740 } 1741 for _, binlog := range ss.binlogs { 1742 binlog.Log(ctx, st) 1743 } 1744 } 1745 return ss.s.WriteStatus(statusOK) 1746 } 1747 1748 func (s *Server) handleStream(t transport.ServerTransport, stream *transport.ServerStream) { 1749 ctx := stream.Context() 1750 ctx = contextWithServer(ctx, s) 1751 var ti *traceInfo 1752 if EnableTracing { 1753 tr := newTrace("grpc.Recv."+methodFamily(stream.Method()), stream.Method()) 1754 ctx = newTraceContext(ctx, tr) 1755 ti = &traceInfo{ 1756 tr: tr, 1757 firstLine: firstLine{ 1758 client: false, 1759 remoteAddr: t.Peer().Addr, 1760 }, 1761 } 1762 if dl, ok := ctx.Deadline(); ok { 1763 ti.firstLine.deadline = time.Until(dl) 1764 } 1765 } 1766 1767 sm := stream.Method() 1768 if sm != "" && sm[0] == '/' { 1769 sm = sm[1:] 1770 } 1771 pos := strings.LastIndex(sm, "/") 1772 if pos == -1 { 1773 if ti != nil { 1774 ti.tr.LazyLog(&fmtStringer{"Malformed method name %q", []any{sm}}, true) 1775 ti.tr.SetError() 1776 } 1777 errDesc := fmt.Sprintf("malformed method name: %q", stream.Method()) 1778 if err := 
stream.WriteStatus(status.New(codes.Unimplemented, errDesc)); err != nil { 1779 if ti != nil { 1780 ti.tr.LazyLog(&fmtStringer{"%v", []any{err}}, true) 1781 ti.tr.SetError() 1782 } 1783 channelz.Warningf(logger, s.channelz, "grpc: Server.handleStream failed to write status: %v", err) 1784 } 1785 if ti != nil { 1786 ti.tr.Finish() 1787 } 1788 return 1789 } 1790 service := sm[:pos] 1791 method := sm[pos+1:] 1792 1793 // FromIncomingContext is expensive: skip if there are no statsHandlers 1794 if len(s.opts.statsHandlers) > 0 { 1795 md, _ := metadata.FromIncomingContext(ctx) 1796 for _, sh := range s.opts.statsHandlers { 1797 ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: stream.Method()}) 1798 sh.HandleRPC(ctx, &stats.InHeader{ 1799 FullMethod: stream.Method(), 1800 RemoteAddr: t.Peer().Addr, 1801 LocalAddr: t.Peer().LocalAddr, 1802 Compression: stream.RecvCompress(), 1803 WireLength: stream.HeaderWireLength(), 1804 Header: md, 1805 }) 1806 } 1807 } 1808 // To have calls in stream callouts work. Will delete once all stats handler 1809 // calls come from the gRPC layer. 1810 stream.SetContext(ctx) 1811 1812 srv, knownService := s.services[service] 1813 if knownService { 1814 if md, ok := srv.methods[method]; ok { 1815 s.processUnaryRPC(ctx, stream, srv, md, ti) 1816 return 1817 } 1818 if sd, ok := srv.streams[method]; ok { 1819 s.processStreamingRPC(ctx, stream, srv, sd, ti) 1820 return 1821 } 1822 } 1823 // Unknown service, or known server unknown method. 
1824 if unknownDesc := s.opts.unknownStreamDesc; unknownDesc != nil { 1825 s.processStreamingRPC(ctx, stream, nil, unknownDesc, ti) 1826 return 1827 } 1828 var errDesc string 1829 if !knownService { 1830 errDesc = fmt.Sprintf("unknown service %v", service) 1831 } else { 1832 errDesc = fmt.Sprintf("unknown method %v for service %v", method, service) 1833 } 1834 if ti != nil { 1835 ti.tr.LazyPrintf("%s", errDesc) 1836 ti.tr.SetError() 1837 } 1838 if err := stream.WriteStatus(status.New(codes.Unimplemented, errDesc)); err != nil { 1839 if ti != nil { 1840 ti.tr.LazyLog(&fmtStringer{"%v", []any{err}}, true) 1841 ti.tr.SetError() 1842 } 1843 channelz.Warningf(logger, s.channelz, "grpc: Server.handleStream failed to write status: %v", err) 1844 } 1845 if ti != nil { 1846 ti.tr.Finish() 1847 } 1848 } 1849 1850 // The key to save ServerTransportStream in the context. 1851 type streamKey struct{} 1852 1853 // NewContextWithServerTransportStream creates a new context from ctx and 1854 // attaches stream to it. 1855 // 1856 // # Experimental 1857 // 1858 // Notice: This API is EXPERIMENTAL and may be changed or removed in a 1859 // later release. 1860 func NewContextWithServerTransportStream(ctx context.Context, stream ServerTransportStream) context.Context { 1861 return context.WithValue(ctx, streamKey{}, stream) 1862 } 1863 1864 // ServerTransportStream is a minimal interface that a transport stream must 1865 // implement. This can be used to mock an actual transport stream for tests of 1866 // handler code that use, for example, grpc.SetHeader (which requires some 1867 // stream to be in context). 1868 // 1869 // See also NewContextWithServerTransportStream. 1870 // 1871 // # Experimental 1872 // 1873 // Notice: This type is EXPERIMENTAL and may be changed or removed in a 1874 // later release. 
1875 type ServerTransportStream interface { 1876 Method() string 1877 SetHeader(md metadata.MD) error 1878 SendHeader(md metadata.MD) error 1879 SetTrailer(md metadata.MD) error 1880 } 1881 1882 // ServerTransportStreamFromContext returns the ServerTransportStream saved in 1883 // ctx. Returns nil if the given context has no stream associated with it 1884 // (which implies it is not an RPC invocation context). 1885 // 1886 // # Experimental 1887 // 1888 // Notice: This API is EXPERIMENTAL and may be changed or removed in a 1889 // later release. 1890 func ServerTransportStreamFromContext(ctx context.Context) ServerTransportStream { 1891 s, _ := ctx.Value(streamKey{}).(ServerTransportStream) 1892 return s 1893 } 1894 1895 // Stop stops the gRPC server. It immediately closes all open 1896 // connections and listeners. 1897 // It cancels all active RPCs on the server side and the corresponding 1898 // pending RPCs on the client side will get notified by connection 1899 // errors. 1900 func (s *Server) Stop() { 1901 s.stop(false) 1902 } 1903 1904 // GracefulStop stops the gRPC server gracefully. It stops the server from 1905 // accepting new connections and RPCs and blocks until all the pending RPCs are 1906 // finished. 1907 func (s *Server) GracefulStop() { 1908 s.stop(true) 1909 } 1910 1911 func (s *Server) stop(graceful bool) { 1912 s.quit.Fire() 1913 defer s.done.Fire() 1914 1915 s.channelzRemoveOnce.Do(func() { channelz.RemoveEntry(s.channelz.ID) }) 1916 s.mu.Lock() 1917 s.closeListenersLocked() 1918 // Wait for serving threads to be ready to exit. Only then can we be sure no 1919 // new conns will be created. 
1920 s.mu.Unlock() 1921 s.serveWG.Wait() 1922 1923 s.mu.Lock() 1924 defer s.mu.Unlock() 1925 1926 if graceful { 1927 s.drainAllServerTransportsLocked() 1928 } else { 1929 s.closeServerTransportsLocked() 1930 } 1931 1932 for len(s.conns) != 0 { 1933 s.cv.Wait() 1934 } 1935 s.conns = nil 1936 1937 if s.opts.numServerWorkers > 0 { 1938 // Closing the channel (only once, via sync.OnceFunc) after all the 1939 // connections have been closed above ensures that there are no 1940 // goroutines executing the callback passed to st.HandleStreams (where 1941 // the channel is written to). 1942 s.serverWorkerChannelClose() 1943 } 1944 1945 if graceful || s.opts.waitForHandlers { 1946 s.handlersWG.Wait() 1947 } 1948 1949 if s.events != nil { 1950 s.events.Finish() 1951 s.events = nil 1952 } 1953 } 1954 1955 // s.mu must be held by the caller. 1956 func (s *Server) closeServerTransportsLocked() { 1957 for _, conns := range s.conns { 1958 for st := range conns { 1959 st.Close(errors.New("Server.Stop called")) 1960 } 1961 } 1962 } 1963 1964 // s.mu must be held by the caller. 1965 func (s *Server) drainAllServerTransportsLocked() { 1966 if !s.drain { 1967 for _, conns := range s.conns { 1968 for st := range conns { 1969 st.Drain("graceful_stop") 1970 } 1971 } 1972 s.drain = true 1973 } 1974 } 1975 1976 // s.mu must be held by the caller. 1977 func (s *Server) closeListenersLocked() { 1978 for lis := range s.lis { 1979 lis.Close() 1980 } 1981 s.lis = nil 1982 } 1983 1984 // contentSubtype must be lowercase 1985 // cannot return nil 1986 func (s *Server) getCodec(contentSubtype string) baseCodec { 1987 if s.opts.codec != nil { 1988 return s.opts.codec 1989 } 1990 if contentSubtype == "" { 1991 return getCodec(proto.Name) 1992 } 1993 codec := getCodec(contentSubtype) 1994 if codec == nil { 1995 logger.Warningf("Unsupported codec %q. Defaulting to %q for now. 
This will start to fail in future releases.", contentSubtype, proto.Name) 1996 return getCodec(proto.Name) 1997 } 1998 return codec 1999 } 2000 2001 type serverKey struct{} 2002 2003 // serverFromContext gets the Server from the context. 2004 func serverFromContext(ctx context.Context) *Server { 2005 s, _ := ctx.Value(serverKey{}).(*Server) 2006 return s 2007 } 2008 2009 // contextWithServer sets the Server in the context. 2010 func contextWithServer(ctx context.Context, server *Server) context.Context { 2011 return context.WithValue(ctx, serverKey{}, server) 2012 } 2013 2014 // isRegisteredMethod returns whether the passed in method is registered as a 2015 // method on the server. /service/method and service/method will match if the 2016 // service and method are registered on the server. 2017 func (s *Server) isRegisteredMethod(serviceMethod string) bool { 2018 if serviceMethod != "" && serviceMethod[0] == '/' { 2019 serviceMethod = serviceMethod[1:] 2020 } 2021 pos := strings.LastIndex(serviceMethod, "/") 2022 if pos == -1 { // Invalid method name syntax. 2023 return false 2024 } 2025 service := serviceMethod[:pos] 2026 method := serviceMethod[pos+1:] 2027 srv, knownService := s.services[service] 2028 if knownService { 2029 if _, ok := srv.methods[method]; ok { 2030 return true 2031 } 2032 if _, ok := srv.streams[method]; ok { 2033 return true 2034 } 2035 } 2036 return false 2037 } 2038 2039 // SetHeader sets the header metadata to be sent from the server to the client. 2040 // The context provided must be the context passed to the server's handler. 2041 // 2042 // Streaming RPCs should prefer the SetHeader method of the ServerStream. 2043 // 2044 // When called multiple times, all the provided metadata will be merged. All 2045 // the metadata will be sent out when one of the following happens: 2046 // 2047 // - grpc.SendHeader is called, or for streaming handlers, stream.SendHeader. 2048 // - The first response message is sent. 
For unary handlers, this occurs when 2049 // the handler returns; for streaming handlers, this can happen when stream's 2050 // SendMsg method is called. 2051 // - An RPC status is sent out (error or success). This occurs when the handler 2052 // returns. 2053 // 2054 // SetHeader will fail if called after any of the events above. 2055 // 2056 // The error returned is compatible with the status package. However, the 2057 // status code will often not match the RPC status as seen by the client 2058 // application, and therefore, should not be relied upon for this purpose. 2059 func SetHeader(ctx context.Context, md metadata.MD) error { 2060 if md.Len() == 0 { 2061 return nil 2062 } 2063 stream := ServerTransportStreamFromContext(ctx) 2064 if stream == nil { 2065 return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) 2066 } 2067 return stream.SetHeader(md) 2068 } 2069 2070 // SendHeader sends header metadata. It may be called at most once, and may not 2071 // be called after any event that causes headers to be sent (see SetHeader for 2072 // a complete list). The provided md and headers set by SetHeader() will be 2073 // sent. 2074 // 2075 // The error returned is compatible with the status package. However, the 2076 // status code will often not match the RPC status as seen by the client 2077 // application, and therefore, should not be relied upon for this purpose. 2078 func SendHeader(ctx context.Context, md metadata.MD) error { 2079 stream := ServerTransportStreamFromContext(ctx) 2080 if stream == nil { 2081 return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) 2082 } 2083 if err := stream.SendHeader(md); err != nil { 2084 return toRPCErr(err) 2085 } 2086 return nil 2087 } 2088 2089 // SetSendCompressor sets a compressor for outbound messages from the server. 
2090 // It must not be called after any event that causes headers to be sent 2091 // (see ServerStream.SetHeader for the complete list). Provided compressor is 2092 // used when below conditions are met: 2093 // 2094 // - compressor is registered via encoding.RegisterCompressor 2095 // - compressor name must exist in the client advertised compressor names 2096 // sent in grpc-accept-encoding header. Use ClientSupportedCompressors to 2097 // get client supported compressor names. 2098 // 2099 // The context provided must be the context passed to the server's handler. 2100 // It must be noted that compressor name encoding.Identity disables the 2101 // outbound compression. 2102 // By default, server messages will be sent using the same compressor with 2103 // which request messages were sent. 2104 // 2105 // It is not safe to call SetSendCompressor concurrently with SendHeader and 2106 // SendMsg. 2107 // 2108 // # Experimental 2109 // 2110 // Notice: This function is EXPERIMENTAL and may be changed or removed in a 2111 // later release. 2112 func SetSendCompressor(ctx context.Context, name string) error { 2113 stream, ok := ServerTransportStreamFromContext(ctx).(*transport.ServerStream) 2114 if !ok || stream == nil { 2115 return fmt.Errorf("failed to fetch the stream from the given context") 2116 } 2117 2118 if err := validateSendCompressor(name, stream.ClientAdvertisedCompressors()); err != nil { 2119 return fmt.Errorf("unable to set send compressor: %w", err) 2120 } 2121 2122 return stream.SetSendCompress(name) 2123 } 2124 2125 // ClientSupportedCompressors returns compressor names advertised by the client 2126 // via grpc-accept-encoding header. 2127 // 2128 // The context provided must be the context passed to the server's handler. 2129 // 2130 // # Experimental 2131 // 2132 // Notice: This function is EXPERIMENTAL and may be changed or removed in a 2133 // later release. 
2134 func ClientSupportedCompressors(ctx context.Context) ([]string, error) { 2135 stream, ok := ServerTransportStreamFromContext(ctx).(*transport.ServerStream) 2136 if !ok || stream == nil { 2137 return nil, fmt.Errorf("failed to fetch the stream from the given context %v", ctx) 2138 } 2139 2140 return stream.ClientAdvertisedCompressors(), nil 2141 } 2142 2143 // SetTrailer sets the trailer metadata that will be sent when an RPC returns. 2144 // When called more than once, all the provided metadata will be merged. 2145 // 2146 // The error returned is compatible with the status package. However, the 2147 // status code will often not match the RPC status as seen by the client 2148 // application, and therefore, should not be relied upon for this purpose. 2149 func SetTrailer(ctx context.Context, md metadata.MD) error { 2150 if md.Len() == 0 { 2151 return nil 2152 } 2153 stream := ServerTransportStreamFromContext(ctx) 2154 if stream == nil { 2155 return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) 2156 } 2157 return stream.SetTrailer(md) 2158 } 2159 2160 // Method returns the method string for the server context. The returned 2161 // string is in the format of "/service/method". 2162 func Method(ctx context.Context) (string, bool) { 2163 s := ServerTransportStreamFromContext(ctx) 2164 if s == nil { 2165 return "", false 2166 } 2167 return s.Method(), true 2168 } 2169 2170 // validateSendCompressor returns an error when given compressor name cannot be 2171 // handled by the server or the client based on the advertised compressors. 
2172 func validateSendCompressor(name string, clientCompressors []string) error { 2173 if name == encoding.Identity { 2174 return nil 2175 } 2176 2177 if !grpcutil.IsCompressorNameRegistered(name) { 2178 return fmt.Errorf("compressor not registered %q", name) 2179 } 2180 2181 for _, c := range clientCompressors { 2182 if c == name { 2183 return nil // found match 2184 } 2185 } 2186 return fmt.Errorf("client does not support compressor %q", name) 2187 } 2188 2189 // atomicSemaphore implements a blocking, counting semaphore. acquire should be 2190 // called synchronously; release may be called asynchronously. 2191 type atomicSemaphore struct { 2192 n atomic.Int64 2193 wait chan struct{} 2194 } 2195 2196 func (q *atomicSemaphore) acquire() { 2197 if q.n.Add(-1) < 0 { 2198 // We ran out of quota. Block until a release happens. 2199 <-q.wait 2200 } 2201 } 2202 2203 func (q *atomicSemaphore) release() { 2204 // N.B. the "<= 0" check below should allow for this to work with multiple 2205 // concurrent calls to acquire, but also note that with synchronous calls to 2206 // acquire, as our system does, n will never be less than -1. There are 2207 // fairness issues (queuing) to consider if this was to be generalized. 2208 if q.n.Add(1) <= 0 { 2209 // An acquire was waiting on us. Unblock it. 2210 q.wait <- struct{}{} 2211 } 2212 } 2213 2214 func newHandlerQuota(n uint32) *atomicSemaphore { 2215 a := &atomicSemaphore{wait: make(chan struct{}, 1)} 2216 a.n.Store(int64(n)) 2217 return a 2218 }