github.com/rootless-containers/rootlesskit/v2@v2.3.4/pkg/port/testsuite/testsuite.go (about)

     1  package testsuite
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/json"
     7  	"fmt"
     8  	"io"
     9  	"net"
    10  	"os"
    11  	"os/exec"
    12  	"strconv"
    13  	"strings"
    14  	"sync"
    15  	"syscall"
    16  	"testing"
    17  	"time"
    18  
    19  	"github.com/rootless-containers/rootlesskit/v2/pkg/port"
    20  )
    21  
    22  const (
    23  	reexecKeyMode   = "rootlesskit-port-testsuite.mode"
    24  	reexecKeyOpaque = "rootlesskit-port-testsuite.opaque"
    25  	reexecKeyQuitFD = "rootlesskit-port-testsuite.quitfd"
    26  )
    27  
    28  func Main(m *testing.M, cf func() port.ChildDriver) {
    29  	switch mode := os.Getenv(reexecKeyMode); mode {
    30  	case "":
    31  		os.Exit(m.Run())
    32  	case "child":
    33  	default:
    34  		panic(fmt.Errorf("unknown mode: %q", mode))
    35  	}
    36  	var opaque map[string]string
    37  	if err := json.Unmarshal([]byte(os.Getenv(reexecKeyOpaque)), &opaque); err != nil {
    38  		panic(err)
    39  	}
    40  	quit := make(chan struct{})
    41  	errCh := make(chan error)
    42  	go func() {
    43  		d := cf()
    44  		dErr := d.RunChildDriver(opaque, quit, "")
    45  		errCh <- dErr
    46  	}()
    47  	quitFD, err := strconv.Atoi(os.Getenv(reexecKeyQuitFD))
    48  	if err != nil {
    49  		panic(err)
    50  	}
    51  	quitR := os.NewFile(uintptr(quitFD), "")
    52  	defer quitR.Close()
    53  	if _, err = io.ReadAll(quitR); err != nil {
    54  		panic(err)
    55  	}
    56  	quit <- struct{}{}
    57  	err = <-errCh
    58  	if err != nil {
    59  		panic(err)
    60  	}
    61  	// when race detector is enabled, it takes about 1s after leaving from Main()
    62  }
    63  
    64  func Run(t *testing.T, pf func() port.ParentDriver) {
    65  	RunTCP(t, pf)
    66  	RunTCP4(t, pf)
    67  	RunUDP(t, pf)
    68  	RunUDP4(t, pf)
    69  }
    70  
    71  func RunTCP(t *testing.T, pf func() port.ParentDriver) {
    72  	t.Run("TestTCP", func(t *testing.T) { TestProto(t, "tcp", pf()) })
    73  }
    74  
    75  func RunTCP4(t *testing.T, pf func() port.ParentDriver) {
    76  	t.Run("TestTCP4", func(t *testing.T) { TestProto(t, "tcp4", pf()) })
    77  }
    78  
    79  func RunUDP(t *testing.T, pf func() port.ParentDriver) {
    80  	t.Run("TestUDP", func(t *testing.T) { TestProto(t, "udp", pf()) })
    81  }
    82  
    83  func RunUDP4(t *testing.T, pf func() port.ParentDriver) {
    84  	t.Run("TestUDP4", func(t *testing.T) { TestProto(t, "udp4", pf()) })
    85  }
    86  
    87  func TestProto(t *testing.T, proto string, d port.ParentDriver) {
    88  	ensureDeps(t, "nsenter")
    89  	t.Logf("creating USER+NET namespace")
    90  	opaque := d.OpaqueForChild()
    91  	opaqueJSON, err := json.Marshal(opaque)
    92  	if err != nil {
    93  		t.Fatal(err)
    94  	}
    95  	pr, pw, err := os.Pipe()
    96  	if err != nil {
    97  		t.Fatal(err)
    98  	}
    99  	cmd := exec.Command("/proc/self/exe")
   100  	cmd.Stdout = os.Stderr
   101  	cmd.Stderr = os.Stderr
   102  	cmd.Env = append([]string{
   103  		reexecKeyMode + "=child",
   104  		reexecKeyOpaque + "=" + string(opaqueJSON),
   105  		reexecKeyQuitFD + "=3"}, os.Environ()...)
   106  	cmd.SysProcAttr = &syscall.SysProcAttr{
   107  		Pdeathsig:  syscall.SIGKILL,
   108  		Cloneflags: syscall.CLONE_NEWUSER | syscall.CLONE_NEWNET,
   109  		UidMappings: []syscall.SysProcIDMap{
   110  			{
   111  				ContainerID: 0,
   112  				HostID:      os.Geteuid(),
   113  				Size:        1,
   114  			},
   115  		},
   116  		GidMappings: []syscall.SysProcIDMap{
   117  			{
   118  				ContainerID: 0,
   119  				HostID:      os.Getegid(),
   120  				Size:        1,
   121  			},
   122  		},
   123  	}
   124  	cmd.ExtraFiles = []*os.File{pr}
   125  	if err := cmd.Start(); err != nil {
   126  		t.Fatal(err)
   127  	}
   128  	defer func() {
   129  		pw.Close()
   130  		cmd.Wait()
   131  	}()
   132  	childPID := cmd.Process.Pid
   133  	if out, err := nsenterExec(childPID, "ip", "link", "set", "lo", "up"); err != nil {
   134  		t.Fatalf("%v, out=%s", err, string(out))
   135  	}
   136  	testProtoWithPID(t, proto, d, childPID)
   137  }
   138  
   139  func testProtoWithPID(t *testing.T, proto string, d port.ParentDriver, childPID int) {
   140  	ensureDeps(t, "nsenter", "ip", "nc")
   141  	// [child]parent
   142  	pairs := map[int]int{
   143  		// FIXME: flaky
   144  		80:   (childPID + 80) % 60000,
   145  		8080: (childPID + 8080) % 60000,
   146  	}
   147  	if proto == "tcp" {
   148  		for _, parentPort := range pairs {
   149  			var d net.Dialer
   150  			d.Timeout = 50 * time.Millisecond
   151  			_, err := d.Dial(proto, fmt.Sprintf("127.0.0.1:%d", parentPort))
   152  			if err == nil {
   153  				t.Fatalf("port %d is already used?", parentPort)
   154  			}
   155  		}
   156  	}
   157  
   158  	t.Logf("namespace pid: %d", childPID)
   159  	initComplete := make(chan struct{})
   160  	quit := make(chan struct{})
   161  	driverErr := make(chan error)
   162  	go func() {
   163  		cctx := &port.ChildContext{
   164  			IP: nil, // we don't have tap device in this test suite
   165  		}
   166  		driverErr <- d.RunParentDriver(initComplete, quit, cctx)
   167  	}()
   168  	select {
   169  	case <-initComplete:
   170  	case err := <-driverErr:
   171  		t.Fatal(err)
   172  	}
   173  	var wg sync.WaitGroup
   174  	for c, p := range pairs {
   175  		childP, parentP := c, p
   176  		wg.Add(1)
   177  		go func() {
   178  			testProtoRoutine(t, proto, d, childPID, childP, parentP)
   179  			wg.Done()
   180  		}()
   181  	}
   182  	wg.Wait()
   183  	quit <- struct{}{}
   184  	err := <-driverErr
   185  	if err != nil {
   186  		t.Fatal(err)
   187  	}
   188  }
   189  
   190  func nsenterExec(pid int, cmdss ...string) ([]byte, error) {
   191  	cmd := exec.Command("nsenter",
   192  		append([]string{"-U", "--preserve-credential", "-n", "-t", strconv.Itoa(pid)},
   193  			cmdss...)...)
   194  	cmd.SysProcAttr = &syscall.SysProcAttr{
   195  		Pdeathsig: syscall.SIGKILL,
   196  	}
   197  	return cmd.CombinedOutput()
   198  }
   199  
   200  // FIXME: support IPv6
   201  func testProtoRoutine(t *testing.T, proto string, d port.ParentDriver, childPID, childP, parentP int) {
   202  	stdoutR, stdoutW := io.Pipe()
   203  	var ncFlags []string
   204  	switch proto {
   205  	case "tcp", "tcp4":
   206  		// NOP
   207  	case "udp", "udp4":
   208  		ncFlags = append(ncFlags, "-u")
   209  	default:
   210  		panic("invalid proto")
   211  	}
   212  	cmd := exec.Command("nsenter", append(
   213  		[]string{"-U", "--preserve-credential", "-n", "-t", strconv.Itoa(childPID),
   214  			"nc"}, append(ncFlags, []string{"-l", strconv.Itoa(childP)}...)...)...)
   215  	cmd.SysProcAttr = &syscall.SysProcAttr{
   216  		Pdeathsig: syscall.SIGKILL,
   217  	}
   218  	cmd.Stdout = stdoutW
   219  	cmd.Stderr = os.Stderr
   220  	if err := cmd.Start(); err != nil {
   221  		// NOTE: t.Fatal is not thread-safe while t.Error is (see godoc testing)
   222  		panic(err)
   223  	}
   224  	defer cmd.Process.Kill()
   225  	portStatus, err := d.AddPort(context.TODO(),
   226  		port.Spec{
   227  			Proto:      proto,
   228  			ParentIP:   "127.0.0.1",
   229  			ParentPort: parentP,
   230  			ChildPort:  childP,
   231  		})
   232  	if err != nil {
   233  		panic(err)
   234  	}
   235  	t.Logf("opened port: %+v", portStatus)
   236  	if proto == "udp" || proto == "udp4" {
   237  		// Dial does not return an error for UDP even if the port is not exposed yet
   238  		time.Sleep(1 * time.Second)
   239  	}
   240  	var conn net.Conn
   241  	for i := 0; i < 5; i++ {
   242  		var dialer net.Dialer
   243  		conn, err = dialer.Dial(proto, fmt.Sprintf("127.0.0.1:%d", parentP))
   244  		if i == 4 && err != nil {
   245  			panic(err)
   246  		}
   247  		if conn != nil && err == nil {
   248  			break
   249  		}
   250  		time.Sleep(time.Duration(i*5) * time.Millisecond)
   251  	}
   252  	wBytes := []byte(fmt.Sprintf("test-%s-%d-%d-%d", proto, childPID, childP, parentP))
   253  	if _, err := conn.Write(wBytes); err != nil {
   254  		panic(err)
   255  	}
   256  	switch proto {
   257  	case "tcp", "tcp4":
   258  		if err := conn.(*net.TCPConn).CloseWrite(); err != nil {
   259  			panic(err)
   260  		}
   261  	case "udp", "udp4":
   262  		if err := conn.(*net.UDPConn).Close(); err != nil {
   263  			panic(err)
   264  		}
   265  	}
   266  	rBytes := make([]byte, len(wBytes))
   267  	if _, err := stdoutR.Read(rBytes); err != nil {
   268  		panic(err)
   269  	}
   270  	if bytes.Compare(wBytes, rBytes) != 0 {
   271  		panic(fmt.Errorf("expected %q, got %q", string(wBytes), string(rBytes)))
   272  	}
   273  	if proto == "tcp" || proto == "tcp4" {
   274  		if err := conn.Close(); err != nil {
   275  			panic(err)
   276  		}
   277  		if err := cmd.Wait(); err != nil {
   278  			panic(err)
   279  		}
   280  	} else {
   281  		// nc -u does not exit automatically
   282  		syscall.Kill(cmd.Process.Pid, syscall.SIGKILL)
   283  	}
   284  	if err := d.RemovePort(context.TODO(), portStatus.ID); err != nil {
   285  		panic(err)
   286  	}
   287  	t.Logf("closed port ID %d", portStatus.ID)
   288  }
   289  
   290  func ensureDeps(t testing.TB, deps ...string) {
   291  	for _, dep := range deps {
   292  		if _, err := exec.LookPath(dep); err != nil {
   293  			t.Skipf("%q not found: %v", dep, err)
   294  		}
   295  	}
   296  }
   297  
   298  func TLogWriter(t testing.TB, s string) io.Writer {
   299  	return &tLogWriter{t: t, s: s}
   300  }
   301  
   302  type tLogWriter struct {
   303  	t testing.TB
   304  	s string
   305  }
   306  
   307  func (w *tLogWriter) Write(p []byte) (int, error) {
   308  	w.t.Logf("%s: %s", w.s, strings.TrimSuffix(string(p), "\n"))
   309  	return len(p), nil
   310  }