github.com/nya3jp/tast@v0.0.0-20230601000426-85c8e4d83a9b/src/go.chromium.org/tast/core/internal/rpc/server.go (about)

     1  // Copyright 2019 The ChromiumOS Authors
     2  // Use of this source code is governed by a BSD-style license that can be
     3  // found in the LICENSE file.
     4  
     5  package rpc
     6  
     7  import (
     8  	"context"
     9  	"encoding/json"
    10  	"fmt"
    11  	"io"
    12  	"io/ioutil"
    13  	"net"
    14  	"os"
    15  	"os/signal"
    16  	"strconv"
    17  	"sync"
    18  
    19  	"golang.org/x/sys/unix"
    20  	"google.golang.org/grpc"
    21  	"google.golang.org/grpc/codes"
    22  	"google.golang.org/grpc/metadata"
    23  	"google.golang.org/grpc/reflection"
    24  	"google.golang.org/grpc/status"
    25  
    26  	"go.chromium.org/tast/core/errors"
    27  	"go.chromium.org/tast/core/internal/logging"
    28  	"go.chromium.org/tast/core/internal/protocol"
    29  	"go.chromium.org/tast/core/internal/testcontext"
    30  	"go.chromium.org/tast/core/internal/testing"
    31  	"go.chromium.org/tast/core/internal/timing"
    32  )
    33  
    34  // MaxMessageSize is used to tell Tast's GRPC servers and clients the maximum size of messages.
    35  const MaxMessageSize = 1024 * 1024 * 8
    36  
    37  // RunServer runs a gRPC server on r/w channels.
    38  // register is called back to register core services. svcs is a list of
    39  // user-defined gRPC services to be registered if the client requests them in
    40  // HandshakeRequest.
    41  // RunServer blocks until the client connection is closed or it encounters an
    42  // error.
    43  func RunServer(r io.Reader, w io.Writer, svcs []*testing.Service, register func(srv *grpc.Server, req *protocol.HandshakeRequest) error) error {
    44  	// In case w is stdout or stderr, writing data to it after it is closed
    45  	// causes SIGPIPE to be delivered to the process, which by default
    46  	// terminates the process without running deferred cleanup calls.
    47  	// To avoid the issue, ignore SIGPIPE while running the gRPC server.
    48  	// See https://golang.org/pkg/os/signal/#hdr-SIGPIPE for more details.
    49  	signal.Ignore(unix.SIGPIPE)
    50  	defer signal.Reset(unix.SIGPIPE)
    51  
    52  	var req protocol.HandshakeRequest
    53  	if err := receiveRawMessage(r, &req); err != nil {
    54  		return err
    55  	}
    56  
    57  	// Make sure to return only after all active method calls finish.
    58  	// Otherwise the process can exit before running deferred function
    59  	// calls on service goroutines.
    60  	var calls sync.WaitGroup
    61  	defer calls.Wait()
    62  
    63  	// Start a remote logging server. It is used to forward logs from
    64  	// user-defined gRPC services via side channels.
    65  	ls := newRemoteLoggingServer()
    66  
    67  	// Setup logger to encapsulate the underlying logging mechanism.
    68  	logger := newFuncLogger(ls.Log)
    69  	srv := grpc.NewServer(serverOpts(ls, logger, &calls)...)
    70  
    71  	// Register core services.
    72  	regErr := registerCoreServices(srv, ls, &req, register)
    73  
    74  	// Create a server-scoped context.
    75  	ctx, cancel := context.WithCancel(context.Background())
    76  	defer cancel()
    77  
    78  	// Register user-defined gRPC services if requested.
    79  	if req.GetNeedUserServices() {
    80  		registerUserServices(ctx, srv, logger, &req, svcs, false)
    81  	}
    82  
    83  	if regErr != nil {
    84  		err := errors.Wrap(regErr, "gRPC server initialization failed")
    85  		res := &protocol.HandshakeResponse{
    86  			Error: &protocol.HandshakeError{
    87  				Reason: fmt.Sprintf("gRPC server initialization failed: %v", err),
    88  			},
    89  		}
    90  		sendRawMessage(w, res)
    91  		return err
    92  	}
    93  
    94  	if err := sendRawMessage(w, &protocol.HandshakeResponse{}); err != nil {
    95  		return err
    96  	}
    97  
    98  	// From now on, catch SIGINT/SIGTERM to stop the server gracefully.
    99  	sigCh := make(chan os.Signal, 1)
   100  	defer close(sigCh)
   101  	signal.Notify(sigCh, unix.SIGINT, unix.SIGTERM)
   102  	defer signal.Stop(sigCh)
   103  	sigErrCh := make(chan error, 1)
   104  	go func() {
   105  		if sig, ok := <-sigCh; ok {
   106  			sigErrCh <- errors.Errorf("caught signal %d (%s)", sig, sig)
   107  			srv.Stop()
   108  		}
   109  	}()
   110  
   111  	if err := srv.Serve(NewPipeListener(r, w)); err != nil && err != io.EOF {
   112  		// Replace the error if we saw a signal.
   113  		select {
   114  		case err := <-sigErrCh:
   115  			return err
   116  		default:
   117  		}
   118  		return err
   119  	}
   120  	return nil
   121  }
   122  
   123  // tcpServerResponse contains the return value for RunTCPServer.
   124  type tcpServerResponse struct {
   125  	// Port represents the TCP port number the gRPC server is listening on.
   126  	Port int `json:"port"`
   127  }
   128  
   129  // RunTCPServer runs a gRPC server listening on the specified port thought TCP
   130  // Port contains the TCP port number where gRPC server listens to
   131  // HandshakeRequest contains parameters needed to initialize a gRPC server.
   132  // stdin is the linux standard input.
   133  // stdout is the linux standard output.
   134  // stderr is the linux standard error.
   135  // svcs is the candidate list of user-defined gRPC services and they will be
   136  // registered if GuaranteeCompatibility is set.
   137  func RunTCPServer(port int, handshakeReq *protocol.HandshakeRequest, stdin io.Reader, stdout, stderr io.Writer,
   138  	svcs []*testing.Service, register func(srv *grpc.Server, req *protocol.HandshakeRequest) error) error {
   139  	// Make sure to return only after all active method calls finish.
   140  	// Otherwise the process can exit before running deferred function
   141  	// calls on service goroutines.
   142  	var calls sync.WaitGroup
   143  	defer calls.Wait()
   144  
   145  	// consoleLogger channels logs to stderr
   146  	consoleLogFunc := func(msg string) {
   147  		fmt.Fprintln(stderr, msg)
   148  	}
   149  
   150  	// Setup logger to encapsulate the underlying logging mechanism.
   151  	logger := newFuncLogger(consoleLogFunc)
   152  	srv := grpc.NewServer(serverOpts(nil, logger, &calls)...)
   153  
   154  	// Register core services.
   155  	if err := registerCoreServices(srv, nil, handshakeReq, register); err != nil {
   156  		return errors.Wrap(err, "gRPC server initialization failed")
   157  	}
   158  
   159  	// Create a server-scoped context.
   160  	ctx, cancel := context.WithCancel(context.Background())
   161  	defer cancel()
   162  
   163  	// Register user-defined gRPC services intended for public use.
   164  	registerUserServices(ctx, srv, logger, handshakeReq, svcs, true)
   165  
   166  	// From now on, catch SIGINT/SIGTERM to stop the server gracefully.
   167  	sigCh := make(chan os.Signal, 1)
   168  	defer close(sigCh)
   169  	signal.Notify(sigCh, unix.SIGINT, unix.SIGTERM)
   170  	defer signal.Stop(sigCh)
   171  	sigErrCh := make(chan error, 1)
   172  	go func() {
   173  		if sig, ok := <-sigCh; ok {
   174  			sigErrCh <- errors.Errorf("caught signal %d (%s)", sig, sig)
   175  			srv.Stop()
   176  		}
   177  	}()
   178  
   179  	// start gRPC server listening on the tcp port
   180  	listener, err := net.Listen("tcp4", fmt.Sprintf(":%d", port))
   181  	if err != nil {
   182  		return errors.Wrap(err, "server failed to listen")
   183  	}
   184  
   185  	// Return information regarding the server, paving the way for dynamic port assignment.
   186  	assignedPort := listener.Addr().(*net.TCPAddr).Port
   187  	response := &tcpServerResponse{Port: assignedPort}
   188  	responseBytes, err := json.Marshal(response)
   189  	if err != nil {
   190  		return errors.Wrapf(err, "failed to marshal json response: %v", response)
   191  	}
   192  	fmt.Fprintln(stdout, string(responseBytes))
   193  
   194  	if err := srv.Serve(listener); err != nil && err != io.EOF {
   195  		// Replace the error if we saw a signal.
   196  		select {
   197  		case err := <-sigErrCh:
   198  			return err
   199  		default:
   200  		}
   201  		return err
   202  	}
   203  	return nil
   204  }
   205  
   206  // serverStreamWithContext wraps grpc.ServerStream with overriding Context.
   207  type serverStreamWithContext struct {
   208  	grpc.ServerStream
   209  	ctx context.Context
   210  }
   211  
   212  // Context overrides grpc.ServerStream.Context.
   213  func (s *serverStreamWithContext) Context() context.Context {
   214  	return s.ctx
   215  }
   216  
   217  var _ grpc.ServerStream = (*serverStreamWithContext)(nil)
   218  
   219  // serverOpts returns gRPC server-side interceptors to manipulate context.
   220  func serverOpts(ls *remoteLoggingServer, logger logging.Logger, calls *sync.WaitGroup) []grpc.ServerOption {
   221  	// hook is called on every gRPC method call.
   222  	// It returns a Context to be passed to a gRPC method, a function to be
   223  	// called on the end of the gRPC method call to compute trailers, and
   224  	// possibly an error.
   225  	hook := func(ctx context.Context, method string) (context.Context, func() metadata.MD, error) {
   226  		// Forward all uncaptured logs to logger
   227  		ctx = logging.AttachLogger(ctx, logger)
   228  
   229  		var outDir string
   230  		var tl *timing.Log
   231  		if isUserMethod(method) {
   232  			md, ok := metadata.FromIncomingContext(ctx)
   233  			if !ok {
   234  				return nil, nil, errors.New("metadata not available")
   235  			}
   236  
   237  			var err error
   238  			outDir, err = ioutil.TempDir("", "rpc-outdir.")
   239  			if err != nil {
   240  				return nil, nil, err
   241  			}
   242  
   243  			// Make the directory world-writable so that tests can create files as other users,
   244  			// and set the sticky bit to prevent users from deleting other users' files.
   245  			if err := os.Chmod(outDir, 0777|os.ModeSticky); err != nil {
   246  				return nil, nil, err
   247  			}
   248  
   249  			ctx = testcontext.WithCurrentEntity(ctx, incomingCurrentContext(md, outDir))
   250  			tl = timing.NewLog()
   251  			ctx = timing.NewContext(ctx, tl)
   252  		}
   253  
   254  		trailer := func() metadata.MD {
   255  			md := make(metadata.MD)
   256  
   257  			if isUserMethod(method) {
   258  				b, err := json.Marshal(tl)
   259  				if err != nil {
   260  					logging.Info(ctx, "Failed to marshal timing JSON: ", err)
   261  				} else {
   262  					md[metadataTiming] = []string{string(b)}
   263  				}
   264  
   265  				// Send metadataOutDir only if some files were saved in order to avoid extra round-trips.
   266  				if fis, err := ioutil.ReadDir(outDir); err != nil {
   267  					logging.Info(ctx, "gRPC output directory is corrupted: ", err)
   268  				} else if len(fis) == 0 {
   269  					if err := os.RemoveAll(outDir); err != nil {
   270  						logging.Info(ctx, "Failed to remove gRPC output directory: ", err)
   271  					}
   272  				} else {
   273  					md[metadataOutDir] = []string{outDir}
   274  				}
   275  			}
   276  
   277  			if !isLoggingMethod(method) && ls != nil {
   278  				md[metadataLogLastSeq] = []string{strconv.FormatUint(ls.LastSeq(), 10)}
   279  			}
   280  			return md
   281  		}
   282  		return ctx, trailer, nil
   283  	}
   284  
   285  	return []grpc.ServerOption{
   286  		grpc.UnaryInterceptor(func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (res interface{}, err error) {
   287  			defer func() {
   288  				if r := recover(); r != nil {
   289  					err = status.Error(codes.Internal, fmt.Sprintf("panic: %v", r))
   290  				}
   291  			}()
   292  			calls.Add(1)
   293  			defer calls.Done()
   294  			ctx, trailer, err := hook(ctx, info.FullMethod)
   295  			if err != nil {
   296  				return nil, err
   297  			}
   298  			defer func() {
   299  				grpc.SetTrailer(ctx, trailer())
   300  			}()
   301  			return handler(ctx, req)
   302  		}),
   303  		grpc.StreamInterceptor(func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) {
   304  			defer func() {
   305  				if r := recover(); r != nil {
   306  					err = status.Error(codes.Internal, fmt.Sprintf("panic: %v", r))
   307  				}
   308  			}()
   309  			calls.Add(1)
   310  			defer calls.Done()
   311  			ctx, trailer, err := hook(stream.Context(), info.FullMethod)
   312  			if err != nil {
   313  				return err
   314  			}
   315  			stream = &serverStreamWithContext{stream, ctx}
   316  			defer func() {
   317  				stream.SetTrailer(trailer())
   318  			}()
   319  			return handler(srv, stream)
   320  		}),
   321  		grpc.MaxRecvMsgSize(MaxMessageSize),
   322  		grpc.MaxSendMsgSize(MaxMessageSize),
   323  	}
   324  }
   325  
   326  // registerCoreServices registers core Tast services.
   327  // srv is the gRPC server instance
   328  // ls is the remote logging server that forwards logs through side channel
   329  // HandshakeRequest contains parameters needed to initialize a gRPC server.
   330  // svcs is the candidate list of user-defined gRPC services to be registered
   331  // register offers a callback hook for additional service registration
   332  func registerCoreServices(srv *grpc.Server, ls *remoteLoggingServer,
   333  	handshakeReq *protocol.HandshakeRequest, register func(srv *grpc.Server, req *protocol.HandshakeRequest) error) error {
   334  	reflection.Register(srv)
   335  	if ls != nil {
   336  		protocol.RegisterLoggingServer(srv, ls)
   337  	}
   338  	protocol.RegisterFileTransferServer(srv, newFileTransferServer())
   339  	return register(srv, handshakeReq)
   340  }
   341  
   342  // newFuncLogger provides setup for logging functionalities
   343  func newFuncLogger(logFunc func(msg string)) logging.Logger {
   344  	// logger provides logging functionality to services while encapsulates the actual
   345  	// logging destination and mechanism.
   346  	return logging.NewSinkLogger(logging.LevelInfo, false, logging.NewFuncSink(logFunc))
   347  }
   348  
   349  // registerUserServices registers user defined gRPC services to the gRPC Server
   350  // srv is the gRPC server instance
   351  // logger provides logging functionality to services
   352  // HandshakeRequest contains parameters needed to initialize a gRPC server.
   353  // svcs is the candidate list of user-defined gRPC services to be registered
   354  // guaranteeCompatibilityOnly determines if the service registration is restricted
   355  // only to services with GuaranteeCompatibility set
   356  func registerUserServices(ctx context.Context, srv *grpc.Server, logger logging.Logger,
   357  	handshakeReq *protocol.HandshakeRequest, svcs []*testing.Service, guaranteeCompatibilityOnly bool) error {
   358  	ctx = logging.AttachLogger(ctx, logger)
   359  
   360  	vars := handshakeReq.GetBundleInitParams().GetVars()
   361  	for _, svc := range svcs {
   362  		if !guaranteeCompatibilityOnly || svc.GuaranteeCompatibility {
   363  			svc.Register(srv, testing.NewServiceState(ctx, testing.NewServiceRoot(svc, vars)))
   364  		}
   365  	}
   366  	return nil
   367  }
   368  
   369  // startServing kicks off the gRPC server listening through the listener
   370  func startServing(srv *grpc.Server, listener net.Listener) error {
   371  	// From now on, catch SIGINT/SIGTERM to stop the server gracefully.
   372  	sigCh := make(chan os.Signal, 1)
   373  	defer close(sigCh)
   374  	signal.Notify(sigCh, unix.SIGINT, unix.SIGTERM)
   375  	defer signal.Stop(sigCh)
   376  	sigErrCh := make(chan error, 1)
   377  	go func() {
   378  		if sig, ok := <-sigCh; ok {
   379  			sigErrCh <- errors.Errorf("caught signal %d (%s)", sig, sig)
   380  			srv.Stop()
   381  		}
   382  	}()
   383  
   384  	if err := srv.Serve(listener); err != nil {
   385  		// Replace the error if we saw a signal.
   386  		select {
   387  		case err := <-sigErrCh:
   388  			return err
   389  		default:
   390  		}
   391  		return err
   392  	}
   393  	return nil
   394  }