github.com/nya3jp/tast@v0.0.0-20230601000426-85c8e4d83a9b/src/go.chromium.org/tast/core/internal/runner/runner.go (about)

     1  // Copyright 2017 The ChromiumOS Authors
     2  // Use of this source code is governed by a BSD-style license that can be
     3  // found in the LICENSE file.
     4  
     5  package runner
     6  
     7  import (
     8  	"context"
     9  	"io"
    10  	"io/ioutil"
    11  	"log"
    12  	"os"
    13  	"path/filepath"
    14  	"sort"
    15  	"strings"
    16  	"time"
    17  
    18  	"github.com/shirou/gopsutil/v3/process"
    19  	"golang.org/x/sys/unix"
    20  
    21  	"go.chromium.org/tast/core/errors"
    22  	"go.chromium.org/tast/core/internal/command"
    23  	"go.chromium.org/tast/core/internal/logging"
    24  	"go.chromium.org/tast/core/internal/protocol"
    25  	"go.chromium.org/tast/core/internal/testing"
    26  )
    27  
    28  const (
    29  	statusSuccess    = 0 // runner was successful
    30  	_                = 1 // deprecated
    31  	statusBadArgs    = 2 // bad arguments were passed to the runner
    32  	_                = 3 // deprecated
    33  	_                = 4 // deprecated
    34  	_                = 5 // deprecated
    35  	statusTestFailed = 6 // one or more tests failed during manual run
    36  	_                = 7 // deprecated
    37  	_                = 8 // deprecated
    38  )
    39  
    40  // Run reads command-line flags from clArgs and performs the requested action.
    41  // clArgs should typically be os.Args[1:]. The caller should exit with the
    42  // returned status code.
    43  func Run(clArgs []string, stdin io.Reader, stdout, stderr io.Writer, scfg *StaticConfig) int {
    44  	ctx := context.Background()
    45  
    46  	if scfg.EnableSyslog {
    47  		if l, err := logging.NewSyslogLogger(); err == nil {
    48  			defer l.Close()
    49  			ctx = logging.AttachLogger(ctx, l)
    50  		}
    51  	}
    52  	logging.Debug(ctx, "Tast local_runner starts")
    53  	defer logging.Debug(ctx, "Tast local_runner ends")
    54  
    55  	// TODO(b/189332919): Remove this hack and write stack traces to stderr
    56  	// once we finish migrating to gRPC-based protocol. This hack is needed
    57  	// because JSON-based protocol is designed to write messages to stderr
    58  	// in case of errors and thus Tast CLI consumes stderr.
    59  	if os.Getenv("TAST_B189332919_STACK_TRACE_FD") == "3" {
    60  		command.InstallSignalHandler(os.NewFile(3, ""), func(os.Signal) {})
    61  	}
    62  
    63  	args, err := parseArgs(clArgs, stderr, scfg)
    64  	if err != nil {
    65  		return command.WriteError(stderr, err)
    66  	}
    67  
    68  	switch args.Mode {
    69  	case modeDeprecatedDirectRun:
    70  		if err := deprecatedDirectRun(ctx, &args.DeprecatedDirectRunConfig, scfg, stdout); err != nil {
    71  			return command.WriteError(stderr, err)
    72  		}
    73  		return statusSuccess
    74  	case modeRPC:
    75  		if err := runRPCServer(scfg, stdin, stdout); err != nil {
    76  			return command.WriteError(stderr, err)
    77  		}
    78  		return statusSuccess
    79  	default:
    80  		return command.WriteError(stderr, command.NewStatusErrorf(statusBadArgs, "invalid mode %v", args.Mode))
    81  	}
    82  }
    83  
    84  func deprecatedDirectRun(ctx context.Context, drcfg *DeprecatedDirectRunConfig, scfg *StaticConfig, stdout io.Writer) error {
    85  	lg := log.New(stdout, "", log.LstdFlags)
    86  
    87  	matcher, err := testing.NewMatcher(drcfg.Patterns)
    88  	if err != nil {
    89  		return err
    90  	}
    91  
    92  	compat, err := startCompatServer(ctx, scfg, &protocol.HandshakeRequest{
    93  		RunnerInitParams: &protocol.RunnerInitParams{
    94  			BundleGlob: drcfg.BundleGlob,
    95  		},
    96  		BundleInitParams: &protocol.BundleInitParams{},
    97  	})
    98  	if err != nil {
    99  		return err
   100  	}
   101  	defer compat.Close()
   102  
   103  	cl := compat.Client()
   104  
   105  	// Enumerate tests to run.
   106  	res, err := cl.ListEntities(ctx, &protocol.ListEntitiesRequest{Features: drcfg.RunConfig(nil).GetFeatures()})
   107  	if err != nil {
   108  		return errors.Wrap(err, "failed to enumerate entities in bundles")
   109  	}
   110  
   111  	var testNames []string
   112  	for _, r := range res.Entities {
   113  		e := r.GetEntity()
   114  		if e.GetType() != protocol.EntityType_TEST {
   115  			continue
   116  		}
   117  		if matcher.Match(e.GetName(), e.GetAttributes()) {
   118  			testNames = append(testNames, e.GetName())
   119  		}
   120  	}
   121  	sort.Strings(testNames)
   122  
   123  	// We expect to not match any tests if both local and remote tests are being run but the
   124  	// user specified a pattern that matched only local or only remote tests rather than tests
   125  	// of both types. Don't bother creating an out dir in that case.
   126  	if len(testNames) == 0 {
   127  		return errors.New("no tests matched")
   128  	}
   129  
   130  	rcfg := drcfg.RunConfig(testNames)
   131  
   132  	created, err := setUpBaseOutDir(rcfg)
   133  	if err != nil {
   134  		return errors.Wrap(err, "failed to set up base out dir")
   135  	}
   136  	// If the runner was executed manually and an out dir wasn't specified, clean up the temp dir that was created.
   137  	if created {
   138  		defer os.RemoveAll(rcfg.GetDirs().GetOutDir())
   139  	}
   140  
   141  	// Call RunTests method and send the initial request.
   142  	srv, err := cl.RunTests(ctx)
   143  	if err != nil {
   144  		return errors.Wrap(err, "RunTests: failed to call")
   145  	}
   146  
   147  	initReq := &protocol.RunTestsRequest{Type: &protocol.RunTestsRequest_RunTestsInit{RunTestsInit: &protocol.RunTestsInit{RunConfig: rcfg}}}
   148  	if err := srv.Send(initReq); err != nil {
   149  		return errors.Wrap(err, "RunTests: failed to send initial request")
   150  	}
   151  
   152  	numTests := 0
   153  	testFailed := false              // true if error seen for current test
   154  	var failedTests []string         // names of tests with errors
   155  	var startTime, endTime time.Time // start of first test and end of last test
   156  
   157  	// Keep reading responses and convert them to control messages.
   158  	for {
   159  		res, err := srv.Recv()
   160  		if err == io.EOF {
   161  			lg.Printf("Ran %d test(s) in %v", numTests, endTime.Sub(startTime).Round(time.Millisecond))
   162  			if len(failedTests) > 0 {
   163  				lg.Printf("%d failed:", len(failedTests))
   164  				for _, t := range failedTests {
   165  					lg.Print("  " + t)
   166  				}
   167  				return command.NewStatusErrorf(statusTestFailed, "test(s) failed")
   168  			}
   169  			return nil
   170  		}
   171  		if err != nil {
   172  			return err
   173  		}
   174  
   175  		switch res := res.GetType().(type) {
   176  		case *protocol.RunTestsResponse_RunLog:
   177  			lg.Print(res.RunLog.GetText())
   178  		case *protocol.RunTestsResponse_EntityStart:
   179  			lg.Print("Running ", res.EntityStart.GetEntity().GetName())
   180  			testFailed = false
   181  			if numTests == 0 {
   182  				startTime = res.EntityStart.GetTime().AsTime()
   183  			}
   184  		case *protocol.RunTestsResponse_EntityLog:
   185  			lg.Print(res.EntityLog.GetText())
   186  		case *protocol.RunTestsResponse_EntityError:
   187  			e := res.EntityError.GetError()
   188  			lg.Printf("Error: [%s:%d] %v", filepath.Base(e.GetLocation().GetFile()), e.GetLocation().GetLine(), e.GetReason())
   189  			testFailed = true
   190  		case *protocol.RunTestsResponse_EntityEnd:
   191  			reasons := res.EntityEnd.GetSkip().GetReasons()
   192  			if len(reasons) > 0 {
   193  				lg.Printf("Skipped %s for missing deps: %s", res.EntityEnd.GetEntityName(), strings.Join(reasons, ", "))
   194  			} else {
   195  				lg.Print("Finished ", res.EntityEnd.GetEntityName())
   196  			}
   197  			lg.Print(strings.Repeat("-", 80))
   198  			if testFailed {
   199  				failedTests = append(failedTests, res.EntityEnd.GetEntityName())
   200  			}
   201  			numTests++
   202  			endTime = res.EntityEnd.GetTime().AsTime()
   203  		}
   204  	}
   205  }
   206  
   207  // setUpBaseOutDir creates and assigns a temporary directory if rcfg.Dirs.OutDir is empty.
   208  // It also ensures that the dir is accessible to all users. The returned boolean created
   209  // indicates whether a new directory was created.
   210  func setUpBaseOutDir(rcfg *protocol.RunConfig) (created bool, err error) {
   211  	defer func() {
   212  		if err != nil && created {
   213  			os.RemoveAll(rcfg.GetDirs().GetOutDir())
   214  			created = false
   215  		}
   216  	}()
   217  
   218  	if rcfg.GetDirs().GetOutDir() == "" {
   219  		if rcfg.GetDirs().OutDir, err = ioutil.TempDir("", "tast_out."); err != nil {
   220  			return false, err
   221  		}
   222  		created = true
   223  	} else {
   224  		if _, err := os.Stat(rcfg.GetDirs().GetOutDir()); os.IsNotExist(err) {
   225  			if err := os.MkdirAll(rcfg.GetDirs().GetOutDir(), 0755); err != nil {
   226  				return false, err
   227  			}
   228  			created = true
   229  		} else if err != nil {
   230  			return false, err
   231  		}
   232  	}
   233  
   234  	// Make the directory traversable in case a test wants to write a file as another user.
   235  	// (Note that we can't guarantee that all the parent directories are also accessible, though.)
   236  	if err := os.Chmod(rcfg.GetDirs().GetOutDir(), 0755); err != nil {
   237  		return created, err
   238  	}
   239  	return created, nil
   240  }
   241  
   242  // killStaleRunners sends sig to the process groups of any other processes sharing
   243  // the current process's executable. Status messages and errors are logged using lf.
   244  func killStaleRunners(ctx context.Context, sig unix.Signal) {
   245  	ourPID := os.Getpid()
   246  	ourExe, err := os.Executable()
   247  	if err != nil {
   248  		logging.Info(ctx, "Failed to look up current executable: ", err)
   249  		return
   250  	}
   251  
   252  	procs, err := process.Processes()
   253  	if err != nil {
   254  		logging.Info(ctx, "Failed to list processes while looking for stale runners: ", err)
   255  		return
   256  	}
   257  	for _, proc := range procs {
   258  		if int(proc.Pid) == ourPID {
   259  			continue
   260  		}
   261  		if exe, err := proc.Exe(); err != nil || exe != ourExe {
   262  			continue
   263  		}
   264  		logging.Infof(ctx, "Sending signal %d to stale %v process group %d", sig, ourExe, proc.Pid)
   265  		if err := unix.Kill(int(-proc.Pid), sig); err != nil {
   266  			logging.Infof(ctx, "Failed killing process group %d: %v", proc.Pid, err)
   267  		}
   268  	}
   269  }