github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/test/testutil/testutil.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package testutil contains utility functions for runsc tests.
    16  package testutil
    17  
    18  import (
    19  	"bufio"
    20  	"context"
    21  	"debug/elf"
    22  	"encoding/base32"
    23  	"encoding/json"
    24  	"flag"
    25  	"fmt"
    26  	"io"
    27  	"io/ioutil"
    28  	"log"
    29  	"math"
    30  	"math/rand"
    31  	"net/http"
    32  	"os"
    33  	"os/exec"
    34  	"os/signal"
    35  	"path"
    36  	"path/filepath"
    37  	"strconv"
    38  	"strings"
    39  	"testing"
    40  	"time"
    41  
    42  	"github.com/cenkalti/backoff"
    43  	specs "github.com/opencontainers/runtime-spec/specs-go"
    44  	"golang.org/x/sys/unix"
    45  	"github.com/SagerNet/gvisor/pkg/sentry/watchdog"
    46  	"github.com/SagerNet/gvisor/pkg/sync"
    47  	"github.com/SagerNet/gvisor/runsc/config"
    48  	"github.com/SagerNet/gvisor/runsc/specutils"
    49  )
    50  
    51  var (
    52  	checkpoint           = flag.Bool("checkpoint", true, "control checkpoint/restore support")
    53  	partition            = flag.Int("partition", 1, "partition number, this is 1-indexed")
    54  	totalPartitions      = flag.Int("total_partitions", 1, "total number of partitions")
    55  	isRunningWithHostNet = flag.Bool("hostnet", false, "whether test is running with hostnet")
    56  )
    57  
    58  // IsCheckpointSupported returns the relevant command line flag.
    59  func IsCheckpointSupported() bool {
    60  	return *checkpoint
    61  }
    62  
    63  // IsRunningWithHostNet returns the relevant command line flag.
    64  func IsRunningWithHostNet() bool {
    65  	return *isRunningWithHostNet
    66  }
    67  
    68  // ImageByName mangles the image name used locally. This depends on the image
    69  // build infrastructure in images/ and tools/vm.
    70  func ImageByName(name string) string {
    71  	return fmt.Sprintf("github.com/SagerNet/images/%s", name)
    72  }
    73  
    74  // ConfigureExePath configures the executable for runsc in the test environment.
    75  func ConfigureExePath() error {
    76  	path, err := FindFile("runsc/runsc")
    77  	if err != nil {
    78  		return err
    79  	}
    80  	specutils.ExePath = path
    81  	return nil
    82  }
    83  
    84  // TmpDir returns the absolute path to a writable directory that can be used as
    85  // scratch by the test.
    86  func TmpDir() string {
    87  	if dir, ok := os.LookupEnv("TEST_TMPDIR"); ok {
    88  		return dir
    89  	}
    90  	return "/tmp"
    91  }
    92  
    93  // Logger is a simple logging wrapper.
    94  //
    95  // This is designed to be implemented by *testing.T.
    96  type Logger interface {
    97  	Name() string
    98  	Logf(fmt string, args ...interface{})
    99  }
   100  
   101  // DefaultLogger logs using the log package.
   102  type DefaultLogger string
   103  
   104  // Name implements Logger.Name.
   105  func (d DefaultLogger) Name() string {
   106  	return string(d)
   107  }
   108  
   109  // Logf implements Logger.Logf.
   110  func (d DefaultLogger) Logf(fmt string, args ...interface{}) {
   111  	log.Printf(fmt, args...)
   112  }
   113  
   114  // multiLogger logs to multiple Loggers.
   115  type multiLogger []Logger
   116  
   117  // Name implements Logger.Name.
   118  func (m multiLogger) Name() string {
   119  	names := make([]string, len(m))
   120  	for i, l := range m {
   121  		names[i] = l.Name()
   122  	}
   123  	return strings.Join(names, "+")
   124  }
   125  
   126  // Logf implements Logger.Logf.
   127  func (m multiLogger) Logf(fmt string, args ...interface{}) {
   128  	for _, l := range m {
   129  		l.Logf(fmt, args...)
   130  	}
   131  }
   132  
   133  // NewMultiLogger returns a new Logger that logs on multiple Loggers.
   134  func NewMultiLogger(loggers ...Logger) Logger {
   135  	return multiLogger(loggers)
   136  }
   137  
   138  // Cmd is a simple wrapper.
   139  type Cmd struct {
   140  	logger Logger
   141  	*exec.Cmd
   142  }
   143  
   144  // CombinedOutput returns the output and logs.
   145  func (c *Cmd) CombinedOutput() ([]byte, error) {
   146  	out, err := c.Cmd.CombinedOutput()
   147  	if len(out) > 0 {
   148  		c.logger.Logf("output: %s", string(out))
   149  	}
   150  	if err != nil {
   151  		c.logger.Logf("error: %v", err)
   152  	}
   153  	return out, err
   154  }
   155  
   156  // Command is a simple wrapper around exec.Command, that logs.
   157  func Command(logger Logger, args ...string) *Cmd {
   158  	logger.Logf("command: %s", strings.Join(args, " "))
   159  	return &Cmd{
   160  		logger: logger,
   161  		Cmd:    exec.Command(args[0], args[1:]...),
   162  	}
   163  }
   164  
   165  // TestConfig returns the default configuration to use in tests. Note that
   166  // 'RootDir' must be set by caller if required.
   167  func TestConfig(t *testing.T) *config.Config {
   168  	logDir := os.TempDir()
   169  	if dir, ok := os.LookupEnv("TEST_UNDECLARED_OUTPUTS_DIR"); ok {
   170  		logDir = dir + "/"
   171  	}
   172  
   173  	// Only register flags if config is being used. Otherwise anyone that uses
   174  	// testutil will get flags registered and they may conflict.
   175  	config.RegisterFlags()
   176  
   177  	conf, err := config.NewFromFlags()
   178  	if err != nil {
   179  		panic(err)
   180  	}
   181  	// Change test defaults.
   182  	conf.Debug = true
   183  	conf.DebugLog = path.Join(logDir, "runsc.log."+t.Name()+".%TIMESTAMP%.%COMMAND%")
   184  	conf.LogPackets = true
   185  	conf.Network = config.NetworkNone
   186  	conf.Strace = true
   187  	conf.TestOnlyAllowRunAsCurrentUserWithoutChroot = true
   188  	conf.WatchdogAction = watchdog.Panic
   189  	return conf
   190  }
   191  
   192  // NewSpecWithArgs creates a simple spec with the given args suitable for use
   193  // in tests.
   194  func NewSpecWithArgs(args ...string) *specs.Spec {
   195  	return &specs.Spec{
   196  		// The host filesystem root is the container root.
   197  		Root: &specs.Root{
   198  			Path:     "/",
   199  			Readonly: true,
   200  		},
   201  		Process: &specs.Process{
   202  			Args: args,
   203  			Env: []string{
   204  				"PATH=" + os.Getenv("PATH"),
   205  			},
   206  			Capabilities: specutils.AllCapabilities(),
   207  		},
   208  		Mounts: []specs.Mount{
   209  			// Hide the host /etc to avoid any side-effects.
   210  			// For example, bash reads /etc/passwd and if it is
   211  			// very big, tests can fail by timeout.
   212  			{
   213  				Type:        "tmpfs",
   214  				Destination: "/etc",
   215  			},
   216  			// Root is readonly, but many tests want to write to tmpdir.
   217  			// This creates a writable mount inside the root. Also, when tmpdir points
   218  			// to "/tmp", it makes the the actual /tmp to be mounted and not a tmpfs
   219  			// inside the sentry.
   220  			{
   221  				Type:        "bind",
   222  				Destination: TmpDir(),
   223  				Source:      TmpDir(),
   224  			},
   225  		},
   226  		Hostname: "runsc-test-hostname",
   227  	}
   228  }
   229  
   230  // SetupRootDir creates a root directory for containers.
   231  func SetupRootDir() (string, func(), error) {
   232  	rootDir, err := ioutil.TempDir(TmpDir(), "containers")
   233  	if err != nil {
   234  		return "", nil, fmt.Errorf("error creating root dir: %v", err)
   235  	}
   236  	return rootDir, func() { os.RemoveAll(rootDir) }, nil
   237  }
   238  
   239  // SetupContainer creates a bundle and root dir for the container, generates a
   240  // test config, and writes the spec to config.json in the bundle dir.
   241  func SetupContainer(spec *specs.Spec, conf *config.Config) (rootDir, bundleDir string, cleanup func(), err error) {
   242  	rootDir, rootCleanup, err := SetupRootDir()
   243  	if err != nil {
   244  		return "", "", nil, err
   245  	}
   246  	conf.RootDir = rootDir
   247  	bundleDir, bundleCleanup, err := SetupBundleDir(spec)
   248  	if err != nil {
   249  		rootCleanup()
   250  		return "", "", nil, err
   251  	}
   252  	return rootDir, bundleDir, func() {
   253  		bundleCleanup()
   254  		rootCleanup()
   255  	}, err
   256  }
   257  
   258  // SetupBundleDir creates a bundle dir and writes the spec to config.json.
   259  func SetupBundleDir(spec *specs.Spec) (string, func(), error) {
   260  	bundleDir, err := ioutil.TempDir(TmpDir(), "bundle")
   261  	if err != nil {
   262  		return "", nil, fmt.Errorf("error creating bundle dir: %v", err)
   263  	}
   264  	cleanup := func() { os.RemoveAll(bundleDir) }
   265  	if err := writeSpec(bundleDir, spec); err != nil {
   266  		cleanup()
   267  		return "", nil, fmt.Errorf("error writing spec: %v", err)
   268  	}
   269  	return bundleDir, cleanup, nil
   270  }
   271  
   272  // writeSpec writes the spec to disk in the given directory.
   273  func writeSpec(dir string, spec *specs.Spec) error {
   274  	b, err := json.Marshal(spec)
   275  	if err != nil {
   276  		return err
   277  	}
   278  	return ioutil.WriteFile(filepath.Join(dir, "config.json"), b, 0755)
   279  }
   280  
   281  // idRandomSrc is a pseudo random generator used to in RandomID.
   282  var idRandomSrc = rand.New(rand.NewSource(time.Now().UnixNano()))
   283  
   284  // idRandomSrcMtx is the mutex protecting idRandomSrc.Read from being used
   285  // concurrently in differnt goroutines.
   286  var idRandomSrcMtx sync.Mutex
   287  
   288  // RandomID returns 20 random bytes following the given prefix.
   289  func RandomID(prefix string) string {
   290  	// Read 20 random bytes.
   291  	b := make([]byte, 20)
   292  	// Rand.Read is not safe for concurrent use. Packetimpact tests can be run in
   293  	// parallel now, so we have to protect the Read with a mutex. Otherwise we'll
   294  	// run into name conflicts.
   295  	// https://golang.org/pkg/math/rand/#Rand.Read
   296  	idRandomSrcMtx.Lock()
   297  	// "[Read] always returns len(p) and a nil error." --godoc
   298  	if _, err := idRandomSrc.Read(b); err != nil {
   299  		idRandomSrcMtx.Unlock()
   300  		panic("rand.Read failed: " + err.Error())
   301  	}
   302  	idRandomSrcMtx.Unlock()
   303  	if prefix != "" {
   304  		prefix = prefix + "-"
   305  	}
   306  	return fmt.Sprintf("%s%s", prefix, base32.StdEncoding.EncodeToString(b))
   307  }
   308  
   309  // RandomContainerID generates a random container id for each test.
   310  //
   311  // The container id is used to create an abstract unix domain socket, which
   312  // must be unique. While the container forbids creating two containers with the
   313  // same name, sometimes between test runs the socket does not get cleaned up
   314  // quickly enough, causing container creation to fail.
   315  func RandomContainerID() string {
   316  	return RandomID("test-container")
   317  }
   318  
   319  // Copy copies file from src to dst.
   320  func Copy(src, dst string) error {
   321  	in, err := os.Open(src)
   322  	if err != nil {
   323  		return err
   324  	}
   325  	defer in.Close()
   326  
   327  	st, err := in.Stat()
   328  	if err != nil {
   329  		return err
   330  	}
   331  
   332  	out, err := os.OpenFile(dst, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, st.Mode().Perm())
   333  	if err != nil {
   334  		return err
   335  	}
   336  	defer out.Close()
   337  
   338  	// Mirror the local user's permissions across all users. This is
   339  	// because as we inject things into the container, the UID/GID will
   340  	// change. Also, the build system may generate artifacts with different
   341  	// modes. At the top-level (volume mapping) we have a big read-only
   342  	// knob that can be applied to prevent modifications.
   343  	//
   344  	// Note that this must be done via a separate Chmod call, otherwise the
   345  	// current process's umask will get in the way.
   346  	var mode os.FileMode
   347  	if st.Mode()&0100 != 0 {
   348  		mode |= 0111
   349  	}
   350  	if st.Mode()&0200 != 0 {
   351  		mode |= 0222
   352  	}
   353  	if st.Mode()&0400 != 0 {
   354  		mode |= 0444
   355  	}
   356  	if err := os.Chmod(dst, mode); err != nil {
   357  		return err
   358  	}
   359  
   360  	_, err = io.Copy(out, in)
   361  	return err
   362  }
   363  
   364  // Poll is a shorthand function to poll for something with given timeout.
   365  func Poll(cb func() error, timeout time.Duration) error {
   366  	ctx, cancel := context.WithTimeout(context.Background(), timeout)
   367  	defer cancel()
   368  	return PollContext(ctx, cb)
   369  }
   370  
   371  // PollContext is like Poll, but takes a context instead of a timeout.
   372  func PollContext(ctx context.Context, cb func() error) error {
   373  	b := backoff.WithContext(backoff.NewConstantBackOff(100*time.Millisecond), ctx)
   374  	return backoff.Retry(cb, b)
   375  }
   376  
   377  // WaitForHTTP tries GET requests on a port until the call succeeds or timeout.
   378  func WaitForHTTP(ip string, port int, timeout time.Duration) error {
   379  	cb := func() error {
   380  		c := &http.Client{
   381  			// Calculate timeout to be able to do minimum 5 attempts.
   382  			Timeout: timeout / 5,
   383  		}
   384  		url := fmt.Sprintf("http://%s:%d/", ip, port)
   385  		resp, err := c.Get(url)
   386  		if err != nil {
   387  			log.Printf("Waiting %s: %v", url, err)
   388  			return err
   389  		}
   390  		resp.Body.Close()
   391  		return nil
   392  	}
   393  	return Poll(cb, timeout)
   394  }
   395  
   396  // Reaper reaps child processes.
   397  type Reaper struct {
   398  	// mu protects ch, which will be nil if the reaper is not running.
   399  	mu sync.Mutex
   400  	ch chan os.Signal
   401  }
   402  
   403  // Start starts reaping child processes.
   404  func (r *Reaper) Start() {
   405  	r.mu.Lock()
   406  	defer r.mu.Unlock()
   407  
   408  	if r.ch != nil {
   409  		panic("reaper.Start called on a running reaper")
   410  	}
   411  
   412  	r.ch = make(chan os.Signal, 1)
   413  	signal.Notify(r.ch, unix.SIGCHLD)
   414  
   415  	go func() {
   416  		for {
   417  			r.mu.Lock()
   418  			ch := r.ch
   419  			r.mu.Unlock()
   420  			if ch == nil {
   421  				return
   422  			}
   423  
   424  			_, ok := <-ch
   425  			if !ok {
   426  				// Channel closed.
   427  				return
   428  			}
   429  			for {
   430  				cpid, _ := unix.Wait4(-1, nil, unix.WNOHANG, nil)
   431  				if cpid < 1 {
   432  					break
   433  				}
   434  			}
   435  		}
   436  	}()
   437  }
   438  
   439  // Stop stops reaping child processes.
   440  func (r *Reaper) Stop() {
   441  	r.mu.Lock()
   442  	defer r.mu.Unlock()
   443  
   444  	if r.ch == nil {
   445  		panic("reaper.Stop called on a stopped reaper")
   446  	}
   447  
   448  	signal.Stop(r.ch)
   449  	close(r.ch)
   450  	r.ch = nil
   451  }
   452  
   453  // StartReaper is a helper that starts a new Reaper and returns a function to
   454  // stop it.
   455  func StartReaper() func() {
   456  	r := &Reaper{}
   457  	r.Start()
   458  	return r.Stop
   459  }
   460  
   461  // WaitUntilRead reads from the given reader until the wanted string is found
   462  // or until timeout.
   463  func WaitUntilRead(r io.Reader, want string, timeout time.Duration) error {
   464  	sc := bufio.NewScanner(r)
   465  	// done must be accessed atomically. A value greater than 0 indicates
   466  	// that the read loop can exit.
   467  	doneCh := make(chan bool)
   468  	defer close(doneCh)
   469  	go func() {
   470  		for sc.Scan() {
   471  			t := sc.Text()
   472  			if strings.Contains(t, want) {
   473  				doneCh <- true
   474  				return
   475  			}
   476  			select {
   477  			case <-doneCh:
   478  				return
   479  			default:
   480  			}
   481  		}
   482  		doneCh <- false
   483  	}()
   484  
   485  	select {
   486  	case <-time.After(timeout):
   487  		return fmt.Errorf("timeout waiting to read %q", want)
   488  	case res := <-doneCh:
   489  		if !res {
   490  			return fmt.Errorf("reader closed while waiting to read %q", want)
   491  		}
   492  		return nil
   493  	}
   494  }
   495  
   496  // KillCommand kills the process running cmd unless it hasn't been started. It
   497  // returns an error if it cannot kill the process unless the reason is that the
   498  // process has already exited.
   499  //
   500  // KillCommand will also reap the process.
   501  func KillCommand(cmd *exec.Cmd) error {
   502  	if cmd.Process == nil {
   503  		return nil
   504  	}
   505  	if err := cmd.Process.Kill(); err != nil {
   506  		if !strings.Contains(err.Error(), "process already finished") {
   507  			return fmt.Errorf("failed to kill process %v: %v", cmd, err)
   508  		}
   509  	}
   510  	return cmd.Wait()
   511  }
   512  
   513  // WriteTmpFile writes text to a temporary file, closes the file, and returns
   514  // the name of the file. A cleanup function is also returned.
   515  func WriteTmpFile(pattern, text string) (string, func(), error) {
   516  	file, err := ioutil.TempFile(TmpDir(), pattern)
   517  	if err != nil {
   518  		return "", nil, err
   519  	}
   520  	defer file.Close()
   521  	if _, err := file.Write([]byte(text)); err != nil {
   522  		return "", nil, err
   523  	}
   524  	return file.Name(), func() { os.RemoveAll(file.Name()) }, nil
   525  }
   526  
   527  // IsStatic returns true iff the given file is a static binary.
   528  func IsStatic(filename string) (bool, error) {
   529  	f, err := elf.Open(filename)
   530  	if err != nil {
   531  		return false, err
   532  	}
   533  	for _, prog := range f.Progs {
   534  		if prog.Type == elf.PT_INTERP {
   535  			return false, nil // Has interpreter.
   536  		}
   537  	}
   538  	return true, nil
   539  }
   540  
   541  // TouchShardStatusFile indicates to Bazel that the test runner supports
   542  // sharding by creating or updating the last modified date of the file
   543  // specified by TEST_SHARD_STATUS_FILE.
   544  //
   545  // See https://docs.bazel.build/versions/master/test-encyclopedia.html#role-of-the-test-runner.
   546  func TouchShardStatusFile() error {
   547  	if statusFile, ok := os.LookupEnv("TEST_SHARD_STATUS_FILE"); ok {
   548  		cmd := exec.Command("touch", statusFile)
   549  		if b, err := cmd.CombinedOutput(); err != nil {
   550  			return fmt.Errorf("touch %q failed:\n output: %s\n error: %s", statusFile, string(b), err.Error())
   551  		}
   552  	}
   553  	return nil
   554  }
   555  
   556  // TestIndicesForShard returns indices for this test shard based on the
   557  // TEST_SHARD_INDEX and TEST_TOTAL_SHARDS environment vars, as well as
   558  // the passed partition flags.
   559  //
   560  // If either of the env vars are not present, then the function will return all
   561  // tests. If there are more shards than there are tests, then the returned list
   562  // may be empty.
   563  func TestIndicesForShard(numTests int) ([]int, error) {
   564  	var (
   565  		shardIndex = 0
   566  		shardTotal = 1
   567  	)
   568  
   569  	indexStr, indexOk := os.LookupEnv("TEST_SHARD_INDEX")
   570  	totalStr, totalOk := os.LookupEnv("TEST_TOTAL_SHARDS")
   571  	if indexOk && totalOk {
   572  		// Parse index and total to ints.
   573  		var err error
   574  		shardIndex, err = strconv.Atoi(indexStr)
   575  		if err != nil {
   576  			return nil, fmt.Errorf("invalid TEST_SHARD_INDEX %q: %v", indexStr, err)
   577  		}
   578  		shardTotal, err = strconv.Atoi(totalStr)
   579  		if err != nil {
   580  			return nil, fmt.Errorf("invalid TEST_TOTAL_SHARDS %q: %v", totalStr, err)
   581  		}
   582  	}
   583  
   584  	// Combine with the partitions.
   585  	partitionSize := shardTotal
   586  	shardTotal = (*totalPartitions) * shardTotal
   587  	shardIndex = partitionSize*(*partition-1) + shardIndex
   588  
   589  	// Calculate!
   590  	var indices []int
   591  	numBlocks := int(math.Ceil(float64(numTests) / float64(shardTotal)))
   592  	for i := 0; i < numBlocks; i++ {
   593  		pick := i*shardTotal + shardIndex
   594  		if pick < numTests {
   595  			indices = append(indices, pick)
   596  		}
   597  	}
   598  	return indices, nil
   599  }