gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/test/packetimpact/runner/main.go (about)

     1  // Copyright 2021 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  //go:build linux && go1.10
    16  // +build linux,go1.10
    17  
    18  // The runner binary is used as the test runner for PacketImpact tests.
    19  package main
    20  
    21  import (
    22  	"context"
    23  	"encoding/json"
    24  	"errors"
    25  	"flag"
    26  	"fmt"
    27  	"io"
    28  	"log"
    29  	"os"
    30  	"os/exec"
    31  	"path/filepath"
    32  	"runtime"
    33  	"strings"
    34  	"syscall"
    35  
    36  	"github.com/google/gopacket"
    37  	"github.com/google/gopacket/layers"
    38  	"github.com/google/gopacket/pcapgo"
    39  	"github.com/vishvananda/netlink"
    40  	"golang.org/x/sync/errgroup"
    41  	"golang.org/x/sys/unix"
    42  	"gvisor.dev/gvisor/test/packetimpact/dut"
    43  	"gvisor.dev/gvisor/test/packetimpact/internal/testing"
    44  	netdevs "gvisor.dev/gvisor/test/packetimpact/netdevs/netlink"
    45  	"gvisor.dev/gvisor/test/packetimpact/testbench"
    46  )
    47  
    48  type dutArgList []string
    49  
    50  // String implements flag.Value.
    51  func (l *dutArgList) String() string {
    52  	return strings.Join(*l, " ")
    53  }
    54  
    55  // Set implements flag.Value.
    56  func (l *dutArgList) Set(value string) error {
    57  	*l = append(*l, value)
    58  	return nil
    59  }
    60  
    61  func main() {
    62  	const procSelfExe = "/proc/self/exe"
    63  	if os.Args[0] != procSelfExe {
    64  		// For the first time, re-execute in a new user name space and a new
    65  		// network namespace.
    66  		cmd := exec.Command(procSelfExe, os.Args[1:]...)
    67  		cmd.SysProcAttr = &unix.SysProcAttr{
    68  			Cloneflags: unix.CLONE_NEWUSER | unix.CLONE_NEWNET,
    69  			Pdeathsig:  unix.SIGTERM,
    70  			UidMappings: []syscall.SysProcIDMap{
    71  				{
    72  					ContainerID: 0,
    73  					HostID:      os.Getuid(),
    74  					Size:        1,
    75  				},
    76  			},
    77  			GidMappings: []syscall.SysProcIDMap{
    78  				{
    79  					ContainerID: 0,
    80  					HostID:      os.Getgid(),
    81  					Size:        1,
    82  				},
    83  			},
    84  		}
    85  		cmd.Stdout = os.Stdout
    86  		cmd.Stderr = os.Stderr
    87  		if err := cmd.Run(); err != nil {
    88  			if exitStatus, ok := err.(*exec.ExitError); ok {
    89  				os.Exit(exitStatus.ExitCode())
    90  			} else {
    91  				log.Fatalf("unknown failure: %s", err)
    92  			}
    93  		}
    94  		return
    95  	}
    96  
    97  	var (
    98  		dutBinary     string
    99  		testBinary    string
   100  		expectFailure bool
   101  		numDUTs       int
   102  		variant       string
   103  		dutArgs       dutArgList
   104  	)
   105  	fs := flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
   106  	fs.StringVar(&dutBinary, "dut_binary", "", "path to the DUT binary")
   107  	fs.StringVar(&testBinary, "testbench_binary", "", "path to the test binary")
   108  	fs.BoolVar(&expectFailure, "expect_failure", false, "whether the test is expected to fail")
   109  	fs.IntVar(&numDUTs, "num_duts", 1, "number of DUTs to create")
   110  	fs.StringVar(&variant, "variant", "", "test variant could be native, gvisor or fuchsia")
   111  	fs.Var(&dutArgs, "dut_arg", "argument to the DUT binary")
   112  	if err := fs.Parse(os.Args[1:]); err != nil {
   113  		log.Fatal(err)
   114  	}
   115  
   116  	g, ctx := errgroup.WithContext(context.Background())
   117  
   118  	// Create all the DUTs.
   119  	infoCh := make(chan testbench.DUTInfo, numDUTs)
   120  	var duts []*dutProcess
   121  	for i := 0; i < numDUTs; i++ {
   122  		d, err := newDUT(ctx, i, dutBinary, dutArgs)
   123  		if err != nil {
   124  			log.Fatal(err)
   125  		}
   126  		duts = append(duts, d)
   127  		g.Go(func() error {
   128  			info, waitFn, err := d.bootstrap(ctx)
   129  			if err != nil {
   130  				return err
   131  			}
   132  			infoCh <- info
   133  			return waitFn()
   134  		})
   135  	}
   136  
   137  	// Wait for all the DUTs to bootstrap.
   138  	var infos []testbench.DUTInfo
   139  	for i := 0; i < numDUTs; i++ {
   140  		select {
   141  		case <-ctx.Done():
   142  			log.Fatalf("failed to bootstrap dut: %s", g.Wait())
   143  		case info := <-infoCh:
   144  			infos = append(infos, info)
   145  		}
   146  	}
   147  
   148  	dutJSON, err := json.Marshal(&infos)
   149  	if err != nil {
   150  		log.Fatalf("failed to marshal json: %s", err)
   151  	}
   152  
   153  	for _, d := range duts {
   154  		// When the Linux kernel receives a SYN-ACK for a SYN it didn't send, it
   155  		// will respond with an RST. In most packetimpact tests, the SYN is sent
   156  		// by the raw socket, the kernel knows nothing about the connection, this
   157  		// behavior will break lots of TCP related packetimpact tests. To prevent
   158  		// this, we can install the following iptables rules. The raw socket that
   159  		// packetimpact tests use will still be able to see everything.
   160  		for _, iptables := range []string{"/sbin/iptables-nft", "/sbin/ip6tables-nft"} {
   161  			cmd := exec.Command(iptables, "-A", "INPUT", "-i", d.peerIface(), "--proto", "tcp", "-j", "DROP")
   162  			if output, err := cmd.CombinedOutput(); err != nil {
   163  				log.Fatalf("failed to set iptables: %s, output: %s", err, string(output))
   164  			}
   165  		}
   166  		// Start packet capture.
   167  		g.Go(func() error {
   168  			return d.writePcap(ctx, filepath.Base(testBinary))
   169  		})
   170  	}
   171  
   172  	// Start the test itself.
   173  	testResult := make(chan error, 1)
   174  	go func() {
   175  		testArgs := []string{"--dut_infos_json", string(dutJSON)}
   176  		if variant == "native" {
   177  			testArgs = append(testArgs, "-native")
   178  		}
   179  		test := exec.CommandContext(ctx, testBinary, testArgs...)
   180  		test.SysProcAttr = &unix.SysProcAttr{
   181  			Pdeathsig: unix.SIGTERM,
   182  		}
   183  		test.Stderr = os.Stderr
   184  		test.Stdout = os.Stdout
   185  		testResult <- test.Run()
   186  	}()
   187  
   188  	select {
   189  	case <-ctx.Done():
   190  		log.Fatalf("background tasks exited early: %s", g.Wait())
   191  	case err := <-testResult:
   192  		switch {
   193  		case err != nil == expectFailure:
   194  			// Expected.
   195  		case expectFailure:
   196  			log.Fatalf("the test is expected to fail, but it succeeded")
   197  		case err != nil:
   198  			var exitStatus *exec.ExitError
   199  			if errors.As(err, &exitStatus) {
   200  				os.Exit(exitStatus.ExitCode())
   201  			}
   202  			log.Fatalf("unknown error when executing test: %s", err)
   203  		}
   204  	}
   205  }
   206  
   207  type dutProcess struct {
   208  	cmd       *exec.Cmd
   209  	id        int
   210  	completeR *os.File
   211  	dutNetNS  netNS
   212  }
   213  
   214  func newDUT(ctx context.Context, id int, dutBinary string, dutArgs dutArgList) (*dutProcess, error) {
   215  	cmd := exec.CommandContext(ctx, dutBinary, append([]string{
   216  		"--" + dut.CtrlIface, dutSide.ifaceName(ctrlLink, id),
   217  		"--" + dut.TestIface, dutSide.ifaceName(testLink, id),
   218  	}, dutArgs...)...)
   219  
   220  	// Create the pipe for completion signal
   221  	completeR, completeW, err := os.Pipe()
   222  	if err != nil {
   223  		return nil, fmt.Errorf("failed to create pipe for completion signal: %w", err)
   224  	}
   225  
   226  	// Create a new network namespace for the DUT.
   227  	dutNetNS, err := newNetNS()
   228  	if err != nil {
   229  		return nil, fmt.Errorf("failed to create a new namespace for DUT: %w", err)
   230  	}
   231  
   232  	// Pass these two file descriptors to the DUT.
   233  	cmd.ExtraFiles = append(cmd.ExtraFiles, completeW)
   234  
   235  	// Deliver SIGTERM to the child when the runner exits.
   236  	cmd.SysProcAttr = &unix.SysProcAttr{
   237  		Pdeathsig: unix.SIGTERM,
   238  	}
   239  
   240  	// Stream outputs from the DUT binary.
   241  	cmd.Stdout = os.Stdout
   242  	cmd.Stderr = os.Stderr
   243  
   244  	// Now create the veth pairs to connect the DUT and us.
   245  	for _, typ := range []linkType{ctrlLink, testLink} {
   246  		dutSideIfaceName := dutSide.ifaceName(typ, id)
   247  		tbSideIfaceName := tbSide.ifaceName(typ, id)
   248  		dutVeth := netlink.Veth{
   249  			LinkAttrs: netlink.LinkAttrs{
   250  				Name: dutSideIfaceName,
   251  			},
   252  			PeerName: tbSideIfaceName,
   253  		}
   254  		tbVeth := netlink.Veth{
   255  			LinkAttrs: netlink.LinkAttrs{
   256  				Name: tbSideIfaceName,
   257  			},
   258  			PeerName: dutSideIfaceName,
   259  		}
   260  		if err := netlink.LinkAdd(&dutVeth); err != nil {
   261  			return nil, fmt.Errorf("failed to add a %s veth pair for dut-%d: %w", typ, id, err)
   262  		}
   263  
   264  		tbIPv4 := typ.ipv4(uint8(id), 1)
   265  		dutIPv4 := typ.ipv4(uint8(id), 2)
   266  
   267  		// Move the DUT end into the created namespace.
   268  		if err := netlink.LinkSetNsFd(&dutVeth, int(dutNetNS)); err != nil {
   269  			return nil, fmt.Errorf("failed to move %s veth end to dut-%d: %w", typ, id, err)
   270  		}
   271  
   272  		for _, conf := range []struct {
   273  			ns   netNS
   274  			addr *netlink.Addr
   275  			veth *netlink.Veth
   276  		}{
   277  			{ns: currentNetNS, addr: tbIPv4, veth: &tbVeth},
   278  			{ns: dutNetNS, addr: dutIPv4, veth: &dutVeth},
   279  		} {
   280  			if err := conf.ns.Do(func() error {
   281  				// Disable the DAD so that the generated IPv6 address can be used immediately.
   282  				if err := disableDad(conf.veth.Name); err != nil {
   283  					return fmt.Errorf("failed to disable DAD on %s: %w", conf.veth.Name, err)
   284  				}
   285  				// Manually add the IPv4 address.
   286  				if err := netlink.AddrAdd(conf.veth, conf.addr); err != nil {
   287  					return fmt.Errorf("failed to add addr %s to %s: %w", conf.addr, conf.veth.Name, err)
   288  				}
   289  				// Bring the link up.
   290  				if err := netlink.LinkSetUp(conf.veth); err != nil {
   291  					return fmt.Errorf("failed to set %s up: %w", conf.veth.Name, err)
   292  				}
   293  				return nil
   294  			}); err != nil {
   295  				return nil, err
   296  			}
   297  		}
   298  	}
   299  
   300  	// Bring the loopback interface up in both namespaces.
   301  	for _, ns := range []netNS{currentNetNS, dutNetNS} {
   302  		if err := ns.Do(func() error {
   303  			return netlink.LinkSetUp(&netlink.Device{
   304  				LinkAttrs: netlink.LinkAttrs{
   305  					Name: "lo",
   306  				},
   307  			})
   308  		}); err != nil {
   309  			return nil, fmt.Errorf("failed to bring loopback up: %w", err)
   310  		}
   311  	}
   312  
   313  	return &dutProcess{cmd: cmd, id: id, completeR: completeR, dutNetNS: dutNetNS}, nil
   314  }
   315  
   316  func (d *dutProcess) bootstrap(ctx context.Context) (testbench.DUTInfo, func() error, error) {
   317  	if err := d.dutNetNS.Do(func() error {
   318  		return d.cmd.Start()
   319  	}); err != nil {
   320  		return testbench.DUTInfo{}, nil, fmt.Errorf("failed to start DUT %d: %w", d.id, err)
   321  	}
   322  	for _, file := range d.cmd.ExtraFiles {
   323  		if err := file.Close(); err != nil {
   324  			return testbench.DUTInfo{}, nil, fmt.Errorf("close(%d) = %w", file.Fd(), err)
   325  		}
   326  	}
   327  
   328  	bytes, err := io.ReadAll(d.completeR)
   329  	if err != nil {
   330  		return testbench.DUTInfo{}, nil, fmt.Errorf("failed to read from %s complete pipe: %w", d.name(), err)
   331  	}
   332  	if err := d.completeR.Close(); err != nil {
   333  		return testbench.DUTInfo{}, nil, fmt.Errorf("failed to close the read end of completion pipe: %w", err)
   334  	}
   335  	var dutInfo testbench.DUTInfo
   336  	if err := json.Unmarshal(bytes, &dutInfo); err != nil {
   337  		return testbench.DUTInfo{}, nil, fmt.Errorf("invalid response from %s: %w, received: %s", d.name(), err, string(bytes))
   338  	}
   339  	testIface, testIPv4, testIPv6, err := netdevs.IfaceInfo(d.peerIface())
   340  	if err != nil {
   341  		return testbench.DUTInfo{}, nil, fmt.Errorf("failed to gather information about the testbench: %w", err)
   342  	}
   343  	dutInfo.Net.LocalMAC = testIface.Attrs().HardwareAddr
   344  	dutInfo.Net.LocalIPv4 = testIPv4.IP.To4()
   345  	dutInfo.Net.LocalIPv6 = testIPv6.IP
   346  	dutInfo.Net.LocalDevID = uint32(testIface.Attrs().Index)
   347  	dutInfo.Net.LocalDevName = testIface.Attrs().Name
   348  	return dutInfo, d.cmd.Wait, nil
   349  }
   350  
   351  func (d *dutProcess) name() string {
   352  	return fmt.Sprintf("dut-%d", d.id)
   353  }
   354  
   355  func (d *dutProcess) peerIface() string {
   356  	return tbSide.ifaceName(testLink, d.id)
   357  }
   358  
   359  // writePcap creates the packet capture while the test is running.
   360  func (d *dutProcess) writePcap(ctx context.Context, testName string) error {
   361  	iface := d.peerIface()
   362  	// Create the pcap file.
   363  	fileName, err := testing.UndeclaredOutput(fmt.Sprintf("%s_%s.pcap", testName, iface))
   364  	if err != nil {
   365  		return err
   366  	}
   367  	pcap, err := os.Create(fileName)
   368  	if err != nil {
   369  		return fmt.Errorf("open(%s) = %w", fileName, err)
   370  	}
   371  	defer func() {
   372  		if err := pcap.Close(); err != nil {
   373  			panic(fmt.Sprintf("close(%s) = %s", pcap.Name(), err))
   374  		}
   375  	}()
   376  
   377  	// Start the packet capture.
   378  	pcapw := pcapgo.NewWriter(pcap)
   379  	if err := pcapw.WriteFileHeader(1600, layers.LinkTypeEthernet); err != nil {
   380  		return fmt.Errorf("WriteFileHeader: %w", err)
   381  	}
   382  	handle, err := pcapgo.NewEthernetHandle(iface)
   383  	if err != nil {
   384  		return fmt.Errorf("pcapgo.NewEthernetHandle(%s): %w", iface, err)
   385  	}
   386  	source := gopacket.NewPacketSource(handle, layers.LayerTypeEthernet)
   387  	for {
   388  		select {
   389  		case packet := <-source.Packets():
   390  			if err := pcapw.WritePacket(packet.Metadata().CaptureInfo, packet.Data()); err != nil {
   391  				return fmt.Errorf("pcapw.WritePacket(): %w", err)
   392  			}
   393  		case <-ctx.Done():
   394  			return ctx.Err()
   395  		}
   396  	}
   397  }
   398  
   399  // disableDad disables DAD on the iface when assigning IPv6 addrs.
   400  func disableDad(iface string) error {
   401  	// DAD operation and mode on a given interface will be selected according to
   402  	// the maximum value of conf/{all,interface}/accept_dad. So we set it to 0 on
   403  	// both `iface` and `all`.
   404  	for _, name := range []string{iface, "all"} {
   405  		path := fmt.Sprintf("/proc/sys/net/ipv6/conf/%s/accept_dad", name)
   406  		if err := os.WriteFile(path, []byte("0"), 0); err != nil {
   407  			return err
   408  		}
   409  	}
   410  	return nil
   411  }
   412  
   413  // netNS is a network namespace.
   414  type netNS int
   415  
   416  const (
   417  	currentNetNS netNS = -1
   418  )
   419  
   420  // newNetNS creates a new network namespace.
   421  func newNetNS() (netNS, error) {
   422  	ns := currentNetNS
   423  	err := withSavedNetNS(func() error {
   424  		// Create the namespace via unshare(2).
   425  		if err := unix.Unshare(unix.CLONE_NEWNET); err != nil {
   426  			return err
   427  		}
   428  		// Return the created namespace.
   429  		fd, err := openNetNSFD()
   430  		if err != nil {
   431  			return err
   432  		}
   433  		ns = netNS(fd)
   434  		return nil
   435  	})
   436  	return ns, err
   437  }
   438  
   439  // Do calls the function in the given network namespace.
   440  func (ns netNS) Do(f func() error) error {
   441  	if ns == currentNetNS {
   442  		// Simply call the function if we are already in the namespace.
   443  		return f()
   444  	}
   445  	return withSavedNetNS(func() error {
   446  		// Switch to the target namespace.
   447  		if err := unix.Setns(int(ns), unix.CLONE_NEWNET); err != nil {
   448  			return err
   449  		}
   450  		return f()
   451  	})
   452  }
   453  
   454  // linkType describes if the link is for ctrl or test.
   455  type linkType string
   456  
   457  const (
   458  	testLink linkType = "test"
   459  	ctrlLink linkType = "ctrl"
   460  )
   461  
   462  // ipv4 creates an IPv4 address for the given network and host number.
   463  func (l linkType) ipv4(network uint8, host uint8) *netlink.Addr {
   464  	const (
   465  		testNetworkNumber uint8 = 172
   466  		ctrlNetworkNumber uint8 = 192
   467  	)
   468  	var leadingByte uint8
   469  	switch l {
   470  	case testLink:
   471  		leadingByte = testNetworkNumber
   472  	case ctrlLink:
   473  		leadingByte = ctrlNetworkNumber
   474  	default:
   475  		panic(fmt.Sprintf("unknown link type: %s", l))
   476  	}
   477  	addr, err := netlink.ParseAddr(fmt.Sprintf("%d.0.%d.%d/24", leadingByte, network, host))
   478  	if err != nil {
   479  		panic(fmt.Sprintf("failed to parse ip net: %s", err))
   480  	}
   481  	return addr
   482  }
   483  
   484  // side describes which side of the link (tb/dut).
   485  type side string
   486  
   487  const (
   488  	dutSide side = "dut"
   489  	tbSide  side = "tb"
   490  )
   491  
   492  func (s side) ifaceName(typ linkType, id int) string {
   493  	return fmt.Sprintf("%s-%d-%s", s, id, typ)
   494  }
   495  
   496  // withSavedNetNS saves the current namespace and restores it after calling f.
   497  func withSavedNetNS(f func() error) error {
   498  	runtime.LockOSThread()
   499  	defer runtime.UnlockOSThread()
   500  	// Save the current namespace.
   501  	saved, err := openNetNSFD()
   502  	if err != nil {
   503  		return err
   504  	}
   505  	defer func() {
   506  		// Restore the namespace when we return from f.
   507  		if err := unix.Setns(saved, unix.CLONE_NEWNET); err != nil {
   508  			panic(fmt.Sprintf("setns(%d, CLONE_NEWNET) = %s", saved, err))
   509  		}
   510  		if err := unix.Close(saved); err != nil {
   511  			panic(fmt.Sprintf("close(%d) = %s", saved, err))
   512  		}
   513  	}()
   514  	return f()
   515  }
   516  
   517  func openNetNSFD() (int, error) {
   518  	nsPath := fmt.Sprintf("/proc/self/task/%d/ns/net", unix.Gettid())
   519  	return unix.Open(nsPath, unix.O_RDONLY|unix.O_CLOEXEC, 0)
   520  }