github.com/nya3jp/tast@v0.0.0-20230601000426-85c8e4d83a9b/src/go.chromium.org/tast/core/internal/rpc/client.go (about)

     1  // Copyright 2019 The ChromiumOS Authors
     2  // Use of this source code is governed by a BSD-style license that can be
     3  // found in the LICENSE file.
     4  
     5  package rpc
     6  
     7  import (
     8  	"context"
     9  	"encoding/json"
    10  	"fmt"
    11  	"io"
    12  	"os"
    13  	"os/exec"
    14  	"strconv"
    15  	"strings"
    16  	"sync/atomic"
    17  
    18  	"github.com/shirou/gopsutil/v3/process"
    19  	"golang.org/x/sys/unix"
    20  	"google.golang.org/grpc"
    21  	"google.golang.org/grpc/codes"
    22  	"google.golang.org/grpc/metadata"
    23  	"google.golang.org/grpc/status"
    24  
    25  	"go.chromium.org/tast/core/errors"
    26  	"go.chromium.org/tast/core/internal/protocol"
    27  	"go.chromium.org/tast/core/internal/testcontext"
    28  	"go.chromium.org/tast/core/internal/timing"
    29  	"go.chromium.org/tast/core/ssh"
    30  	"go.chromium.org/tast/core/testing"
    31  )
    32  
// SSHClient is a Tast gRPC client over an SSH connection.
// It bundles the gRPC client with the remote command whose stdin/stdout
// pipes carry the RPC traffic (see DialSSH).
type SSHClient struct {
	cl  *GenericClient // gRPC client built on top of cmd's pipes
	cmd *ssh.Cmd       // remote command started by DialSSH
}
    38  
    39  // Conn returns a gRPC connection.
    40  func (c *SSHClient) Conn() *grpc.ClientConn {
    41  	return c.cl.Conn()
    42  }
    43  
    44  // Close closes this client.
    45  func (c *SSHClient) Close(opts ...ssh.RunOption) error {
    46  	closeErr := c.cl.Close()
    47  	c.cmd.Abort()
    48  	// Ignore errors from Wait since Abort above causes it to return context.Canceled.
    49  	c.cmd.Wait(opts...)
    50  	return closeErr
    51  }
    52  
    53  // DialSSH establishes a gRPC connection to an executable on a remote machine.
    54  // proxy if true indicates that HTTP proxy environment variables should be forwarded.
    55  //
    56  // The context passed in must remain valid for as long as the gRPC connection.
    57  // I.e. Don't use the context from within a testing.Poll function.
    58  func DialSSH(ctx context.Context, conn *ssh.Conn, path string, req *protocol.HandshakeRequest, proxy bool) (*SSHClient, error) {
    59  	args := []string{path, "-rpc"}
    60  	if proxy {
    61  		var envArgs []string
    62  		// Proxy-related variables can be either uppercase or lowercase.
    63  		// See https://golang.org/pkg/net/http/#ProxyFromEnvironment.
    64  		for _, name := range []string{
    65  			"HTTP_PROXY", "HTTPS_PROXY", "NO_PROXY",
    66  			"http_proxy", "https_proxy", "no_proxy",
    67  		} {
    68  			if val := os.Getenv(name); val != "" {
    69  				envArgs = append(envArgs, fmt.Sprintf("%s=%s", name, val))
    70  			}
    71  		}
    72  		args = append(append([]string{"env"}, envArgs...), args...)
    73  	}
    74  	testing.ContextLog(ctx, "Running rpc server: ", args)
    75  	cmd := conn.CommandContext(ctx, args[0], args[1:]...)
    76  	stdin, err := cmd.StdinPipe()
    77  	if err != nil {
    78  		return nil, err
    79  	}
    80  	stdout, err := cmd.StdoutPipe()
    81  	if err != nil {
    82  		return nil, err
    83  	}
    84  	if err := cmd.Start(); err != nil {
    85  		return nil, errors.Wrap(err, "failed to connect to RPC service on DUT")
    86  	}
    87  
    88  	c, err := NewClient(ctx, stdout, stdin, req)
    89  	if err != nil {
    90  		cmd.Abort()
    91  		cmd.Wait()
    92  		return nil, err
    93  	}
    94  	return &SSHClient{
    95  		cl:  c,
    96  		cmd: cmd,
    97  	}, nil
    98  }
    99  
// ExecClient is a Tast gRPC client over a locally executed subprocess.
// The subprocess's stdin/stdout pipes carry the RPC traffic (see DialExec).
type ExecClient struct {
	cl         *GenericClient // gRPC client built on top of cmd's pipes
	cmd        *exec.Cmd      // subprocess started by DialExec
	newSession bool           // if true, Close also kills every process in the subprocess's session
}
   106  
   107  // Conn returns a gRPC connection.
   108  func (c *ExecClient) Conn() *grpc.ClientConn {
   109  	return c.cl.Conn()
   110  }
   111  
   112  // PID returns PID of the subprocess.
   113  func (c *ExecClient) PID() int {
   114  	return c.cmd.Process.Pid
   115  }
   116  
   117  // Close closes this client.
   118  func (c *ExecClient) Close() error {
   119  	var firstErr error
   120  	if err := c.cl.Close(); err != nil && firstErr == nil {
   121  		firstErr = err
   122  	}
   123  	if err := c.cmd.Process.Kill(); err != nil && firstErr == nil {
   124  		firstErr = err
   125  	}
   126  	if c.newSession {
   127  		killSession(c.cmd.Process.Pid)
   128  	}
   129  	c.cmd.Wait() // ignore error `signal: killed`
   130  	return firstErr
   131  }
   132  
   133  // DialExec establishes a gRPC connection to an executable on host.
   134  // If newSession is true, a new session is created for the subprocess and its
   135  // descendants so that all of them are killed on closing Client.
   136  func DialExec(ctx context.Context, path string, newSession bool, req *protocol.HandshakeRequest) (*ExecClient, error) {
   137  	cmd := exec.CommandContext(ctx, path, "-rpc")
   138  	stdin, err := cmd.StdinPipe()
   139  	if err != nil {
   140  		return nil, errors.Wrapf(err, "failed to run %s for RPC", path)
   141  	}
   142  	stdout, err := cmd.StdoutPipe()
   143  	if err != nil {
   144  		return nil, errors.Wrapf(err, "failed to run %s for RPC", path)
   145  	}
   146  	cmd.Stderr = os.Stderr // ease debug
   147  	if newSession {
   148  		cmd.SysProcAttr = &unix.SysProcAttr{Setsid: true}
   149  	}
   150  	if err := cmd.Start(); err != nil {
   151  		return nil, errors.Wrapf(err, "failed to run %s for RPC", path)
   152  	}
   153  	c, err := NewClient(ctx, stdout, stdin, req)
   154  	if err != nil {
   155  		cmd.Process.Kill()
   156  		if newSession {
   157  			killSession(cmd.Process.Pid)
   158  		}
   159  		cmd.Wait()
   160  		return nil, err
   161  	}
   162  	return &ExecClient{
   163  		cl:         c,
   164  		cmd:        cmd,
   165  		newSession: newSession,
   166  	}, nil
   167  }
   168  
// GenericClient is a Tast gRPC client.
// It owns both the gRPC connection and the remote logging client started on
// that connection; Close releases both.
type GenericClient struct {
	conn *grpc.ClientConn     // underlying gRPC connection
	log  *remoteLoggingClient // remote logging client started on conn by NewClient
}
   174  
   175  // Conn returns a gRPC connection.
   176  func (c *GenericClient) Conn() *grpc.ClientConn {
   177  	return c.conn
   178  }
   179  
   180  // Close closes this client.
   181  func (c *GenericClient) Close() error {
   182  	var firstErr error
   183  	if err := c.log.Close(); err != nil && firstErr == nil {
   184  		firstErr = err
   185  	}
   186  	if err := c.conn.Close(); err != nil && firstErr == nil {
   187  		firstErr = err
   188  	}
   189  	return firstErr
   190  }
   191  
   192  // NewClient establishes a gRPC connection to a test bundle executable using r
   193  // and w.
   194  // Callers are responsible for closing the underlying connection of r/w after
   195  // the client is closed.
   196  func NewClient(ctx context.Context, r io.Reader, w io.Writer, req *protocol.HandshakeRequest, opts ...grpc.DialOption) (_ *GenericClient, retErr error) {
   197  	if err := sendRawMessage(w, req); err != nil {
   198  		return nil, err
   199  	}
   200  	res := &protocol.HandshakeResponse{}
   201  	if err := receiveRawMessage(r, res); err != nil {
   202  		return nil, err
   203  	}
   204  	if res.Error != nil {
   205  		return nil, errors.Errorf("bundle returned error: %s", res.Error.GetReason())
   206  	}
   207  
   208  	lazyLog := newLazyRemoteLoggingClient()
   209  	conn, err := NewPipeClientConn(ctx, r, w, append(clientOpts(lazyLog), opts...)...)
   210  	if err != nil {
   211  		return nil, errors.Wrap(err, "failed to establish RPC connection")
   212  	}
   213  	defer func() {
   214  		if retErr != nil {
   215  			conn.Close()
   216  		}
   217  	}()
   218  
   219  	log, err := newRemoteLoggingClient(ctx, conn)
   220  	if err != nil {
   221  		return nil, errors.Wrap(err, "failed to start remote logging")
   222  	}
   223  
   224  	lazyLog.SetClient(log)
   225  
   226  	return &GenericClient{
   227  		conn: conn,
   228  		log:  log,
   229  	}, nil
   230  }
   231  
// alwaysAllowedServices lists gRPC services that user RPC calls may target
// even when they are not declared in ServiceDeps (see clientOpts).
var alwaysAllowedServices = []string{
	"tast.cros.baserpc.FaillogService",
	"tast.cros.baserpc.FileSystem",
}
   236  
   237  // clientOpts returns gRPC client-side interceptors to manipulate context and
   238  // make sure all clients use the same GRPC send/recv message size.
   239  func clientOpts(lazyLog *lazyRemoteLoggingClient) []grpc.DialOption {
   240  	// hook is called on every gRPC method call.
   241  	// It returns a Context to be passed to a gRPC invocation, a function to be
   242  	// called on the end of the gRPC method call to process trailers, and
   243  	// possibly an error.
   244  	hook := func(ctx context.Context, cc *grpc.ClientConn, method string) (context.Context, func(metadata.MD) error, error) {
   245  		if isUserMethod(method) {
   246  			// Reject an outgoing RPC call if its service is not declared in ServiceDeps.
   247  			svcs, ok := testcontext.ServiceDeps(ctx)
   248  			if !ok {
   249  				return nil, nil, status.Errorf(codes.FailedPrecondition, "refusing to call %s because ServiceDeps is unavailable (using a wrong context?)", method)
   250  			}
   251  			svcs = append(svcs, alwaysAllowedServices...)
   252  			matched := false
   253  			for _, svc := range svcs {
   254  				if strings.HasPrefix(method, fmt.Sprintf("/%s/", svc)) {
   255  					matched = true
   256  					break
   257  				}
   258  			}
   259  			if !matched {
   260  				return nil, nil, status.Errorf(codes.FailedPrecondition, "refusing to call %s because it is not declared in ServiceDeps", method)
   261  			}
   262  		}
   263  
   264  		after := func(trailer metadata.MD) error {
   265  			var firstErr error
   266  			if isUserMethod(method) {
   267  				if err := processTimingTrailer(ctx, trailer.Get(metadataTiming)); err != nil && firstErr == nil {
   268  					firstErr = err
   269  				}
   270  				if err := processOutDirTrailer(ctx, cc, trailer.Get(metadataOutDir)); err != nil && firstErr == nil {
   271  					firstErr = err
   272  				}
   273  			}
   274  			if !isLoggingMethod(method) {
   275  				if err := processLoggingTrailer(ctx, lazyLog, trailer.Get(metadataLogLastSeq)); err != nil && firstErr == nil {
   276  					firstErr = err
   277  				}
   278  			}
   279  			return firstErr
   280  		}
   281  		return metadata.NewOutgoingContext(ctx, outgoingMetadata(ctx)), after, nil
   282  	}
   283  
   284  	return []grpc.DialOption{
   285  		grpc.WithUnaryInterceptor(func(ctx context.Context, method string, req, reply interface{},
   286  			cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
   287  			ctx, after, err := hook(ctx, cc, method)
   288  			if err != nil {
   289  				return err
   290  			}
   291  
   292  			var trailer metadata.MD
   293  			opts = append([]grpc.CallOption{grpc.Trailer(&trailer)}, opts...)
   294  			retErr := invoker(ctx, method, req, reply, cc, opts...)
   295  			if err := after(trailer); err != nil && retErr == nil {
   296  				retErr = err
   297  			}
   298  			return retErr
   299  		}),
   300  		grpc.WithStreamInterceptor(func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn,
   301  			method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
   302  			ctx, after, err := hook(ctx, cc, method)
   303  			if err != nil {
   304  				return nil, err
   305  			}
   306  			stream, err := streamer(ctx, desc, cc, method, opts...)
   307  			return &clientStreamWithAfter{ClientStream: stream, after: after}, err
   308  		}),
   309  		grpc.WithDefaultCallOptions(
   310  			grpc.MaxCallRecvMsgSize(MaxMessageSize),
   311  			grpc.MaxCallSendMsgSize(MaxMessageSize),
   312  		),
   313  	}
   314  }
   315  
   316  func processTimingTrailer(ctx context.Context, values []string) error {
   317  	if len(values) == 0 {
   318  		return nil
   319  	}
   320  	if len(values) >= 2 {
   321  		return errors.Errorf("gRPC trailer %s contains %d values", metadataTiming, len(values))
   322  	}
   323  
   324  	var tl timing.Log
   325  	if err := json.Unmarshal([]byte(values[0]), &tl); err != nil {
   326  		return errors.Wrapf(err, "failed to parse gRPC trailer %s", metadataTiming)
   327  	}
   328  	if _, stg, ok := timing.FromContext(ctx); ok {
   329  		if err := stg.Import(&tl); err != nil {
   330  			return errors.Wrap(err, "failed to import gRPC timing log")
   331  		}
   332  	}
   333  	return nil
   334  }
   335  
   336  func processOutDirTrailer(ctx context.Context, cc *grpc.ClientConn, values []string) error {
   337  	if len(values) == 0 {
   338  		return nil
   339  	}
   340  	if len(values) >= 2 {
   341  		return errors.Errorf("gRPC trailer %s contains %d values", metadataOutDir, len(values))
   342  	}
   343  
   344  	src := values[0]
   345  	dst, ok := testcontext.OutDir(ctx)
   346  	if !ok {
   347  		return errors.New("output directory not associated to the context")
   348  	}
   349  
   350  	if err := pullDirectory(ctx, protocol.NewFileTransferClient(cc), src, dst); err != nil {
   351  		return errors.Wrap(err, "failed to pull output files from gRPC service")
   352  	}
   353  	return nil
   354  }
   355  
   356  func processLoggingTrailer(ctx context.Context, lazyLog *lazyRemoteLoggingClient, values []string) error {
   357  	if len(values) == 0 {
   358  		return nil
   359  	}
   360  	if len(values) >= 2 {
   361  		return errors.Errorf("gRPC trailer %s contains %d values", metadataLogLastSeq, len(values))
   362  	}
   363  
   364  	seq, err := strconv.ParseUint(values[0], 10, 64)
   365  	if err != nil {
   366  		return errors.Wrapf(err, "failed to parse gRPC trailer %s", metadataLogLastSeq)
   367  	}
   368  
   369  	if err := lazyLog.Wait(ctx, seq); err != nil {
   370  		return errors.Wrap(err, "failed to wait for pending logs")
   371  	}
   372  	return nil
   373  }
   374  
// clientStreamWithAfter wraps grpc.ClientStream with a function to be called
// on the end of the streaming call.
type clientStreamWithAfter struct {
	grpc.ClientStream
	after func(trailer metadata.MD) error // called once with the stream's trailer when RecvMsg first fails
	done  bool                            // set once after has been invoked
}
   382  
   383  func (s *clientStreamWithAfter) RecvMsg(m interface{}) error {
   384  	retErr := s.ClientStream.RecvMsg(m)
   385  	if retErr == nil {
   386  		return nil
   387  	}
   388  
   389  	if s.done {
   390  		return retErr
   391  	}
   392  	s.done = true
   393  
   394  	if err := s.after(s.Trailer()); err != nil && retErr == io.EOF {
   395  		retErr = err
   396  	}
   397  	return retErr
   398  }
   399  
// lazyRemoteLoggingClient wraps remoteLoggingClient for lazy initialization.
// We have to install logging hooks on starting a gRPC connection, but
// remoteLoggingClient can be started only after a gRPC connection is ready.
// lazyRemoteLoggingClient allows logging hooks to access remoteLoggingClient
// after it becomes available.
// lazyRemoteLoggingClient is goroutine-safe.
type lazyRemoteLoggingClient struct {
	client atomic.Value // holds a *remoteLoggingClient once SetClient has been called
}
   409  
   410  func newLazyRemoteLoggingClient() *lazyRemoteLoggingClient {
   411  	return &lazyRemoteLoggingClient{}
   412  }
   413  
   414  func (l *lazyRemoteLoggingClient) SetClient(client *remoteLoggingClient) {
   415  	l.client.Store(client)
   416  }
   417  
   418  func (l *lazyRemoteLoggingClient) Wait(ctx context.Context, seq uint64) error {
   419  	client, ok := l.client.Load().(*remoteLoggingClient)
   420  	if !ok {
   421  		return nil
   422  	}
   423  	return client.Wait(ctx, seq)
   424  }
   425  
   426  // killSession makes a best-effort attempt to kill all processes in session sid.
   427  // It makes several passes over the list of running processes, sending sig to any
   428  // that are part of the session. After it doesn't find any new processes, it returns.
   429  // Note that this is racy: it's possible (but hopefully unlikely) that continually-forking
   430  // processes could spawn children that don't get killed.
   431  func killSession(sid int) {
   432  	const maxPasses = 3
   433  	for i := 0; i < maxPasses; i++ {
   434  		pids, err := process.Pids()
   435  		if err != nil {
   436  			return
   437  		}
   438  		n := 0
   439  		for _, pid := range pids {
   440  			pid := int(pid)
   441  			if s, err := unix.Getsid(pid); err == nil && s == sid {
   442  				unix.Kill(pid, unix.SIGKILL)
   443  				n++
   444  			}
   445  		}
   446  		// If we didn't find any processes in the session, we're done.
   447  		if n == 0 {
   448  			return
   449  		}
   450  	}
   451  }