github.com/nya3jp/tast@v0.0.0-20230601000426-85c8e4d83a9b/src/go.chromium.org/tast/core/internal/rpc/server.go (about) 1 // Copyright 2019 The ChromiumOS Authors 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 package rpc 6 7 import ( 8 "context" 9 "encoding/json" 10 "fmt" 11 "io" 12 "io/ioutil" 13 "net" 14 "os" 15 "os/signal" 16 "strconv" 17 "sync" 18 19 "golang.org/x/sys/unix" 20 "google.golang.org/grpc" 21 "google.golang.org/grpc/codes" 22 "google.golang.org/grpc/metadata" 23 "google.golang.org/grpc/reflection" 24 "google.golang.org/grpc/status" 25 26 "go.chromium.org/tast/core/errors" 27 "go.chromium.org/tast/core/internal/logging" 28 "go.chromium.org/tast/core/internal/protocol" 29 "go.chromium.org/tast/core/internal/testcontext" 30 "go.chromium.org/tast/core/internal/testing" 31 "go.chromium.org/tast/core/internal/timing" 32 ) 33 34 // MaxMessageSize is used to tell Tast's GRPC servers and clients the maximum size of messages. 35 const MaxMessageSize = 1024 * 1024 * 8 36 37 // RunServer runs a gRPC server on r/w channels. 38 // register is called back to register core services. svcs is a list of 39 // user-defined gRPC services to be registered if the client requests them in 40 // HandshakeRequest. 41 // RunServer blocks until the client connection is closed or it encounters an 42 // error. 43 func RunServer(r io.Reader, w io.Writer, svcs []*testing.Service, register func(srv *grpc.Server, req *protocol.HandshakeRequest) error) error { 44 // In case w is stdout or stderr, writing data to it after it is closed 45 // causes SIGPIPE to be delivered to the process, which by default 46 // terminates the process without running deferred cleanup calls. 47 // To avoid the issue, ignore SIGPIPE while running the gRPC server. 48 // See https://golang.org/pkg/os/signal/#hdr-SIGPIPE for more details. 49 signal.Ignore(unix.SIGPIPE) 50 defer signal.Reset(unix.SIGPIPE) 51 52 var req protocol.HandshakeRequest 53 if err := receiveRawMessage(r, &req); err != nil { 54 return err 55 } 56 57 // Make sure to return only after all active method calls finish. 58 // Otherwise the process can exit before running deferred function 59 // calls on service goroutines. 60 var calls sync.WaitGroup 61 defer calls.Wait() 62 63 // Start a remote logging server. It is used to forward logs from 64 // user-defined gRPC services via side channels. 65 ls := newRemoteLoggingServer() 66 67 // Setup logger to encapsulate the underlying logging mechanism. 68 logger := newFuncLogger(ls.Log) 69 srv := grpc.NewServer(serverOpts(ls, logger, &calls)...) 70 71 // Register core services. 72 regErr := registerCoreServices(srv, ls, &req, register) 73 74 // Create a server-scoped context. 75 ctx, cancel := context.WithCancel(context.Background()) 76 defer cancel() 77 78 // Register user-defined gRPC services if requested. 79 if req.GetNeedUserServices() { 80 registerUserServices(ctx, srv, logger, &req, svcs, false) 81 } 82 83 if regErr != nil { 84 err := errors.Wrap(regErr, "gRPC server initialization failed") 85 res := &protocol.HandshakeResponse{ 86 Error: &protocol.HandshakeError{ 87 Reason: fmt.Sprintf("gRPC server initialization failed: %v", err), 88 }, 89 } 90 sendRawMessage(w, res) 91 return err 92 } 93 94 if err := sendRawMessage(w, &protocol.HandshakeResponse{}); err != nil { 95 return err 96 } 97 98 // From now on, catch SIGINT/SIGTERM to stop the server gracefully. 99 sigCh := make(chan os.Signal, 1) 100 defer close(sigCh) 101 signal.Notify(sigCh, unix.SIGINT, unix.SIGTERM) 102 defer signal.Stop(sigCh) 103 sigErrCh := make(chan error, 1) 104 go func() { 105 if sig, ok := <-sigCh; ok { 106 sigErrCh <- errors.Errorf("caught signal %d (%s)", sig, sig) 107 srv.Stop() 108 } 109 }() 110 111 if err := srv.Serve(NewPipeListener(r, w)); err != nil && err != io.EOF { 112 // Replace the error if we saw a signal. 113 select { 114 case err := <-sigErrCh: 115 return err 116 default: 117 } 118 return err 119 } 120 return nil 121 } 122 123 // tcpServerResponse contains the return value for RunTCPServer. 124 type tcpServerResponse struct { 125 // Port represents the TCP port number the gRPC server is listening on. 126 Port int `json:"port"` 127 } 128 129 // RunTCPServer runs a gRPC server listening on the specified port thought TCP 130 // Port contains the TCP port number where gRPC server listens to 131 // HandshakeRequest contains parameters needed to initialize a gRPC server. 132 // stdin is the linux standard input. 133 // stdout is the linux standard output. 134 // stderr is the linux standard error. 135 // svcs is the candidate list of user-defined gRPC services and they will be 136 // registered if GuaranteeCompatibility is set. 137 func RunTCPServer(port int, handshakeReq *protocol.HandshakeRequest, stdin io.Reader, stdout, stderr io.Writer, 138 svcs []*testing.Service, register func(srv *grpc.Server, req *protocol.HandshakeRequest) error) error { 139 // Make sure to return only after all active method calls finish. 140 // Otherwise the process can exit before running deferred function 141 // calls on service goroutines. 142 var calls sync.WaitGroup 143 defer calls.Wait() 144 145 // consoleLogger channels logs to stderr 146 consoleLogFunc := func(msg string) { 147 fmt.Fprintln(stderr, msg) 148 } 149 150 // Setup logger to encapsulate the underlying logging mechanism. 151 logger := newFuncLogger(consoleLogFunc) 152 srv := grpc.NewServer(serverOpts(nil, logger, &calls)...) 153 154 // Register core services. 155 if err := registerCoreServices(srv, nil, handshakeReq, register); err != nil { 156 return errors.Wrap(err, "gRPC server initialization failed") 157 } 158 159 // Create a server-scoped context. 160 ctx, cancel := context.WithCancel(context.Background()) 161 defer cancel() 162 163 // Register user-defined gRPC services intended for public use. 164 registerUserServices(ctx, srv, logger, handshakeReq, svcs, true) 165 166 // From now on, catch SIGINT/SIGTERM to stop the server gracefully. 167 sigCh := make(chan os.Signal, 1) 168 defer close(sigCh) 169 signal.Notify(sigCh, unix.SIGINT, unix.SIGTERM) 170 defer signal.Stop(sigCh) 171 sigErrCh := make(chan error, 1) 172 go func() { 173 if sig, ok := <-sigCh; ok { 174 sigErrCh <- errors.Errorf("caught signal %d (%s)", sig, sig) 175 srv.Stop() 176 } 177 }() 178 179 // start gRPC server listening on the tcp port 180 listener, err := net.Listen("tcp4", fmt.Sprintf(":%d", port)) 181 if err != nil { 182 return errors.Wrap(err, "server failed to listen") 183 } 184 185 // Return information regarding the server, paving the way for dynamic port assignment. 186 assignedPort := listener.Addr().(*net.TCPAddr).Port 187 response := &tcpServerResponse{Port: assignedPort} 188 responseBytes, err := json.Marshal(response) 189 if err != nil { 190 return errors.Wrapf(err, "failed to marshal json response: %v", response) 191 } 192 fmt.Fprintln(stdout, string(responseBytes)) 193 194 if err := srv.Serve(listener); err != nil && err != io.EOF { 195 // Replace the error if we saw a signal. 196 select { 197 case err := <-sigErrCh: 198 return err 199 default: 200 } 201 return err 202 } 203 return nil 204 } 205 206 // serverStreamWithContext wraps grpc.ServerStream with overriding Context. 207 type serverStreamWithContext struct { 208 grpc.ServerStream 209 ctx context.Context 210 } 211 212 // Context overrides grpc.ServerStream.Context. 213 func (s *serverStreamWithContext) Context() context.Context { 214 return s.ctx 215 } 216 217 var _ grpc.ServerStream = (*serverStreamWithContext)(nil) 218 219 // serverOpts returns gRPC server-side interceptors to manipulate context. 220 func serverOpts(ls *remoteLoggingServer, logger logging.Logger, calls *sync.WaitGroup) []grpc.ServerOption { 221 // hook is called on every gRPC method call. 222 // It returns a Context to be passed to a gRPC method, a function to be 223 // called on the end of the gRPC method call to compute trailers, and 224 // possibly an error. 225 hook := func(ctx context.Context, method string) (context.Context, func() metadata.MD, error) { 226 // Forward all uncaptured logs to logger 227 ctx = logging.AttachLogger(ctx, logger) 228 229 var outDir string 230 var tl *timing.Log 231 if isUserMethod(method) { 232 md, ok := metadata.FromIncomingContext(ctx) 233 if !ok { 234 return nil, nil, errors.New("metadata not available") 235 } 236 237 var err error 238 outDir, err = ioutil.TempDir("", "rpc-outdir.") 239 if err != nil { 240 return nil, nil, err 241 } 242 243 // Make the directory world-writable so that tests can create files as other users, 244 // and set the sticky bit to prevent users from deleting other users' files. 245 if err := os.Chmod(outDir, 0777|os.ModeSticky); err != nil { 246 return nil, nil, err 247 } 248 249 ctx = testcontext.WithCurrentEntity(ctx, incomingCurrentContext(md, outDir)) 250 tl = timing.NewLog() 251 ctx = timing.NewContext(ctx, tl) 252 } 253 254 trailer := func() metadata.MD { 255 md := make(metadata.MD) 256 257 if isUserMethod(method) { 258 b, err := json.Marshal(tl) 259 if err != nil { 260 logging.Info(ctx, "Failed to marshal timing JSON: ", err) 261 } else { 262 md[metadataTiming] = []string{string(b)} 263 } 264 265 // Send metadataOutDir only if some files were saved in order to avoid extra round-trips. 266 if fis, err := ioutil.ReadDir(outDir); err != nil { 267 logging.Info(ctx, "gRPC output directory is corrupted: ", err) 268 } else if len(fis) == 0 { 269 if err := os.RemoveAll(outDir); err != nil { 270 logging.Info(ctx, "Failed to remove gRPC output directory: ", err) 271 } 272 } else { 273 md[metadataOutDir] = []string{outDir} 274 } 275 } 276 277 if !isLoggingMethod(method) && ls != nil { 278 md[metadataLogLastSeq] = []string{strconv.FormatUint(ls.LastSeq(), 10)} 279 } 280 return md 281 } 282 return ctx, trailer, nil 283 } 284 285 return []grpc.ServerOption{ 286 grpc.UnaryInterceptor(func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (res interface{}, err error) { 287 defer func() { 288 if r := recover(); r != nil { 289 err = status.Error(codes.Internal, fmt.Sprintf("panic: %v", r)) 290 } 291 }() 292 calls.Add(1) 293 defer calls.Done() 294 ctx, trailer, err := hook(ctx, info.FullMethod) 295 if err != nil { 296 return nil, err 297 } 298 defer func() { 299 grpc.SetTrailer(ctx, trailer()) 300 }() 301 return handler(ctx, req) 302 }), 303 grpc.StreamInterceptor(func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) { 304 defer func() { 305 if r := recover(); r != nil { 306 err = status.Error(codes.Internal, fmt.Sprintf("panic: %v", r)) 307 } 308 }() 309 calls.Add(1) 310 defer calls.Done() 311 ctx, trailer, err := hook(stream.Context(), info.FullMethod) 312 if err != nil { 313 return err 314 } 315 stream = &serverStreamWithContext{stream, ctx} 316 defer func() { 317 stream.SetTrailer(trailer()) 318 }() 319 return handler(srv, stream) 320 }), 321 grpc.MaxRecvMsgSize(MaxMessageSize), 322 grpc.MaxSendMsgSize(MaxMessageSize), 323 } 324 } 325 326 // registerCoreServices registers core Tast services. 327 // srv is the gRPC server instance 328 // ls is the remote logging server that forwards logs through side channel 329 // HandshakeRequest contains parameters needed to initialize a gRPC server. 330 // svcs is the candidate list of user-defined gRPC services to be registered 331 // register offers a callback hook for additional service registration 332 func registerCoreServices(srv *grpc.Server, ls *remoteLoggingServer, 333 handshakeReq *protocol.HandshakeRequest, register func(srv *grpc.Server, req *protocol.HandshakeRequest) error) error { 334 reflection.Register(srv) 335 if ls != nil { 336 protocol.RegisterLoggingServer(srv, ls) 337 } 338 protocol.RegisterFileTransferServer(srv, newFileTransferServer()) 339 return register(srv, handshakeReq) 340 } 341 342 // newFuncLogger provides setup for logging functionalities 343 func newFuncLogger(logFunc func(msg string)) logging.Logger { 344 // logger provides logging functionality to services while encapsulates the actual 345 // logging destination and mechanism. 346 return logging.NewSinkLogger(logging.LevelInfo, false, logging.NewFuncSink(logFunc)) 347 } 348 349 // registerUserServices registers user defined gRPC services to the gRPC Server 350 // srv is the gRPC server instance 351 // logger provides logging functionality to services 352 // HandshakeRequest contains parameters needed to initialize a gRPC server. 353 // svcs is the candidate list of user-defined gRPC services to be registered 354 // guaranteeCompatibilityOnly determines if the service registration is restricted 355 // only to services with GuaranteeCompatibility set 356 func registerUserServices(ctx context.Context, srv *grpc.Server, logger logging.Logger, 357 handshakeReq *protocol.HandshakeRequest, svcs []*testing.Service, guaranteeCompatibilityOnly bool) error { 358 ctx = logging.AttachLogger(ctx, logger) 359 360 vars := handshakeReq.GetBundleInitParams().GetVars() 361 for _, svc := range svcs { 362 if !guaranteeCompatibilityOnly || svc.GuaranteeCompatibility { 363 svc.Register(srv, testing.NewServiceState(ctx, testing.NewServiceRoot(svc, vars))) 364 } 365 } 366 return nil 367 } 368 369 // startServing kicks off the gRPC server listening through the listener 370 func startServing(srv *grpc.Server, listener net.Listener) error { 371 // From now on, catch SIGINT/SIGTERM to stop the server gracefully. 372 sigCh := make(chan os.Signal, 1) 373 defer close(sigCh) 374 signal.Notify(sigCh, unix.SIGINT, unix.SIGTERM) 375 defer signal.Stop(sigCh) 376 sigErrCh := make(chan error, 1) 377 go func() { 378 if sig, ok := <-sigCh; ok { 379 sigErrCh <- errors.Errorf("caught signal %d (%s)", sig, sig) 380 srv.Stop() 381 } 382 }() 383 384 if err := srv.Serve(listener); err != nil { 385 // Replace the error if we saw a signal. 386 select { 387 case err := <-sigErrCh: 388 return err 389 default: 390 } 391 return err 392 } 393 return nil 394 }