github.com/rootless-containers/rootlesskit/v2@v2.3.4/pkg/port/testsuite/testsuite.go (about) 1 package testsuite 2 3 import ( 4 "bytes" 5 "context" 6 "encoding/json" 7 "fmt" 8 "io" 9 "net" 10 "os" 11 "os/exec" 12 "strconv" 13 "strings" 14 "sync" 15 "syscall" 16 "testing" 17 "time" 18 19 "github.com/rootless-containers/rootlesskit/v2/pkg/port" 20 ) 21 22 const ( 23 reexecKeyMode = "rootlesskit-port-testsuite.mode" 24 reexecKeyOpaque = "rootlesskit-port-testsuite.opaque" 25 reexecKeyQuitFD = "rootlesskit-port-testsuite.quitfd" 26 ) 27 28 func Main(m *testing.M, cf func() port.ChildDriver) { 29 switch mode := os.Getenv(reexecKeyMode); mode { 30 case "": 31 os.Exit(m.Run()) 32 case "child": 33 default: 34 panic(fmt.Errorf("unknown mode: %q", mode)) 35 } 36 var opaque map[string]string 37 if err := json.Unmarshal([]byte(os.Getenv(reexecKeyOpaque)), &opaque); err != nil { 38 panic(err) 39 } 40 quit := make(chan struct{}) 41 errCh := make(chan error) 42 go func() { 43 d := cf() 44 dErr := d.RunChildDriver(opaque, quit, "") 45 errCh <- dErr 46 }() 47 quitFD, err := strconv.Atoi(os.Getenv(reexecKeyQuitFD)) 48 if err != nil { 49 panic(err) 50 } 51 quitR := os.NewFile(uintptr(quitFD), "") 52 defer quitR.Close() 53 if _, err = io.ReadAll(quitR); err != nil { 54 panic(err) 55 } 56 quit <- struct{}{} 57 err = <-errCh 58 if err != nil { 59 panic(err) 60 } 61 // when race detector is enabled, it takes about 1s after leaving from Main() 62 } 63 64 func Run(t *testing.T, pf func() port.ParentDriver) { 65 RunTCP(t, pf) 66 RunTCP4(t, pf) 67 RunUDP(t, pf) 68 RunUDP4(t, pf) 69 } 70 71 func RunTCP(t *testing.T, pf func() port.ParentDriver) { 72 t.Run("TestTCP", func(t *testing.T) { TestProto(t, "tcp", pf()) }) 73 } 74 75 func RunTCP4(t *testing.T, pf func() port.ParentDriver) { 76 t.Run("TestTCP4", func(t *testing.T) { TestProto(t, "tcp4", pf()) }) 77 } 78 79 func RunUDP(t *testing.T, pf func() port.ParentDriver) { 80 t.Run("TestUDP", func(t *testing.T) { TestProto(t, "udp", pf()) }) 81 } 82 83 func RunUDP4(t *testing.T, pf func() port.ParentDriver) { 84 t.Run("TestUDP4", func(t *testing.T) { TestProto(t, "udp4", pf()) }) 85 } 86 87 func TestProto(t *testing.T, proto string, d port.ParentDriver) { 88 ensureDeps(t, "nsenter") 89 t.Logf("creating USER+NET namespace") 90 opaque := d.OpaqueForChild() 91 opaqueJSON, err := json.Marshal(opaque) 92 if err != nil { 93 t.Fatal(err) 94 } 95 pr, pw, err := os.Pipe() 96 if err != nil { 97 t.Fatal(err) 98 } 99 cmd := exec.Command("/proc/self/exe") 100 cmd.Stdout = os.Stderr 101 cmd.Stderr = os.Stderr 102 cmd.Env = append([]string{ 103 reexecKeyMode + "=child", 104 reexecKeyOpaque + "=" + string(opaqueJSON), 105 reexecKeyQuitFD + "=3"}, os.Environ()...) 106 cmd.SysProcAttr = &syscall.SysProcAttr{ 107 Pdeathsig: syscall.SIGKILL, 108 Cloneflags: syscall.CLONE_NEWUSER | syscall.CLONE_NEWNET, 109 UidMappings: []syscall.SysProcIDMap{ 110 { 111 ContainerID: 0, 112 HostID: os.Geteuid(), 113 Size: 1, 114 }, 115 }, 116 GidMappings: []syscall.SysProcIDMap{ 117 { 118 ContainerID: 0, 119 HostID: os.Getegid(), 120 Size: 1, 121 }, 122 }, 123 } 124 cmd.ExtraFiles = []*os.File{pr} 125 if err := cmd.Start(); err != nil { 126 t.Fatal(err) 127 } 128 defer func() { 129 pw.Close() 130 cmd.Wait() 131 }() 132 childPID := cmd.Process.Pid 133 if out, err := nsenterExec(childPID, "ip", "link", "set", "lo", "up"); err != nil { 134 t.Fatalf("%v, out=%s", err, string(out)) 135 } 136 testProtoWithPID(t, proto, d, childPID) 137 } 138 139 func testProtoWithPID(t *testing.T, proto string, d port.ParentDriver, childPID int) { 140 ensureDeps(t, "nsenter", "ip", "nc") 141 // [child]parent 142 pairs := map[int]int{ 143 // FIXME: flaky 144 80: (childPID + 80) % 60000, 145 8080: (childPID + 8080) % 60000, 146 } 147 if proto == "tcp" { 148 for _, parentPort := range pairs { 149 var d net.Dialer 150 d.Timeout = 50 * time.Millisecond 151 _, err := d.Dial(proto, fmt.Sprintf("127.0.0.1:%d", parentPort)) 152 if err == nil { 153 t.Fatalf("port %d is already used?", parentPort) 154 } 155 } 156 } 157 158 t.Logf("namespace pid: %d", childPID) 159 initComplete := make(chan struct{}) 160 quit := make(chan struct{}) 161 driverErr := make(chan error) 162 go func() { 163 cctx := &port.ChildContext{ 164 IP: nil, // we don't have tap device in this test suite 165 } 166 driverErr <- d.RunParentDriver(initComplete, quit, cctx) 167 }() 168 select { 169 case <-initComplete: 170 case err := <-driverErr: 171 t.Fatal(err) 172 } 173 var wg sync.WaitGroup 174 for c, p := range pairs { 175 childP, parentP := c, p 176 wg.Add(1) 177 go func() { 178 testProtoRoutine(t, proto, d, childPID, childP, parentP) 179 wg.Done() 180 }() 181 } 182 wg.Wait() 183 quit <- struct{}{} 184 err := <-driverErr 185 if err != nil { 186 t.Fatal(err) 187 } 188 } 189 190 func nsenterExec(pid int, cmdss ...string) ([]byte, error) { 191 cmd := exec.Command("nsenter", 192 append([]string{"-U", "--preserve-credential", "-n", "-t", strconv.Itoa(pid)}, 193 cmdss...)...) 194 cmd.SysProcAttr = &syscall.SysProcAttr{ 195 Pdeathsig: syscall.SIGKILL, 196 } 197 return cmd.CombinedOutput() 198 } 199 200 // FIXME: support IPv6 201 func testProtoRoutine(t *testing.T, proto string, d port.ParentDriver, childPID, childP, parentP int) { 202 stdoutR, stdoutW := io.Pipe() 203 var ncFlags []string 204 switch proto { 205 case "tcp", "tcp4": 206 // NOP 207 case "udp", "udp4": 208 ncFlags = append(ncFlags, "-u") 209 default: 210 panic("invalid proto") 211 } 212 cmd := exec.Command("nsenter", append( 213 []string{"-U", "--preserve-credential", "-n", "-t", strconv.Itoa(childPID), 214 "nc"}, append(ncFlags, []string{"-l", strconv.Itoa(childP)}...)...)...) 215 cmd.SysProcAttr = &syscall.SysProcAttr{ 216 Pdeathsig: syscall.SIGKILL, 217 } 218 cmd.Stdout = stdoutW 219 cmd.Stderr = os.Stderr 220 if err := cmd.Start(); err != nil { 221 // NOTE: t.Fatal is not thread-safe while t.Error is (see godoc testing) 222 panic(err) 223 } 224 defer cmd.Process.Kill() 225 portStatus, err := d.AddPort(context.TODO(), 226 port.Spec{ 227 Proto: proto, 228 ParentIP: "127.0.0.1", 229 ParentPort: parentP, 230 ChildPort: childP, 231 }) 232 if err != nil { 233 panic(err) 234 } 235 t.Logf("opened port: %+v", portStatus) 236 if proto == "udp" || proto == "udp4" { 237 // Dial does not return an error for UDP even if the port is not exposed yet 238 time.Sleep(1 * time.Second) 239 } 240 var conn net.Conn 241 for i := 0; i < 5; i++ { 242 var dialer net.Dialer 243 conn, err = dialer.Dial(proto, fmt.Sprintf("127.0.0.1:%d", parentP)) 244 if i == 4 && err != nil { 245 panic(err) 246 } 247 if conn != nil && err == nil { 248 break 249 } 250 time.Sleep(time.Duration(i*5) * time.Millisecond) 251 } 252 wBytes := []byte(fmt.Sprintf("test-%s-%d-%d-%d", proto, childPID, childP, parentP)) 253 if _, err := conn.Write(wBytes); err != nil { 254 panic(err) 255 } 256 switch proto { 257 case "tcp", "tcp4": 258 if err := conn.(*net.TCPConn).CloseWrite(); err != nil { 259 panic(err) 260 } 261 case "udp", "udp4": 262 if err := conn.(*net.UDPConn).Close(); err != nil { 263 panic(err) 264 } 265 } 266 rBytes := make([]byte, len(wBytes)) 267 if _, err := stdoutR.Read(rBytes); err != nil { 268 panic(err) 269 } 270 if bytes.Compare(wBytes, rBytes) != 0 { 271 panic(fmt.Errorf("expected %q, got %q", string(wBytes), string(rBytes))) 272 } 273 if proto == "tcp" || proto == "tcp4" { 274 if err := conn.Close(); err != nil { 275 panic(err) 276 } 277 if err := cmd.Wait(); err != nil { 278 panic(err) 279 } 280 } else { 281 // nc -u does not exit automatically 282 syscall.Kill(cmd.Process.Pid, syscall.SIGKILL) 283 } 284 if err := d.RemovePort(context.TODO(), portStatus.ID); err != nil { 285 panic(err) 286 } 287 t.Logf("closed port ID %d", portStatus.ID) 288 } 289 290 func ensureDeps(t testing.TB, deps ...string) { 291 for _, dep := range deps { 292 if _, err := exec.LookPath(dep); err != nil { 293 t.Skipf("%q not found: %v", dep, err) 294 } 295 } 296 } 297 298 func TLogWriter(t testing.TB, s string) io.Writer { 299 return &tLogWriter{t: t, s: s} 300 } 301 302 type tLogWriter struct { 303 t testing.TB 304 s string 305 } 306 307 func (w *tLogWriter) Write(p []byte) (int, error) { 308 w.t.Logf("%s: %s", w.s, strings.TrimSuffix(string(p), "\n")) 309 return len(p), nil 310 }