github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/test/testutil/sh.go (about) 1 // Copyright 2020 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package testutil 16 17 import ( 18 "bytes" 19 "context" 20 "fmt" 21 "io" 22 "os" 23 "os/exec" 24 "strings" 25 "time" 26 27 "github.com/kr/pty" 28 "golang.org/x/sys/unix" 29 ) 30 31 // Prompt is used as shell prompt. 32 // It is meant to be unique enough to not be seen in command outputs. 33 const Prompt = "PROMPT> " 34 35 // Simplistic shell string escape. 36 func shellEscape(s string) string { 37 // specialChars is used to determine whether s needs quoting at all. 38 const specialChars = "\\'\"`${[|&;<>()*?! \t\n" 39 // If s needs quoting, escapedChars is the set of characters that are 40 // escaped with a backslash. 41 const escapedChars = "\\\"$`" 42 if len(s) == 0 { 43 return "''" 44 } 45 if !strings.ContainsAny(s, specialChars) { 46 return s 47 } 48 var b bytes.Buffer 49 b.WriteString("\"") 50 for _, c := range s { 51 if strings.ContainsAny(string(c), escapedChars) { 52 b.WriteString("\\") 53 } 54 b.WriteRune(c) 55 } 56 b.WriteString("\"") 57 return b.String() 58 } 59 60 type byteOrError struct { 61 b byte 62 err error 63 } 64 65 // Shell manages a /bin/sh invocation with convenience functions to handle I/O. 66 // The shell is run in its own interactive TTY and should present its prompt. 67 type Shell struct { 68 // cmd is a reference to the underlying sh process. 69 cmd *exec.Cmd 70 // cmdFinished is closed when cmd exits. 71 cmdFinished chan struct{} 72 73 // echo is whether the shell will echo input back to us. 74 // This helps setting expectations of getting feedback of written bytes. 75 echo bool 76 // Control characters we expect to see in the shell. 77 controlCharIntr string 78 controlCharEOF string 79 80 // ptyMaster and ptyReplica are the TTY pair associated with the shell. 81 ptyMaster *os.File 82 ptyReplica *os.File 83 // readCh is a channel where everything read from ptyMaster is written. 84 readCh chan byteOrError 85 86 // logger is used for logging. It may be nil. 87 logger Logger 88 } 89 90 // cleanup kills the shell process and closes the TTY. 91 // Users of this library get a reference to this function with NewShell. 92 func (s *Shell) cleanup() { 93 s.logf("cleanup", "Shell cleanup started.") 94 if s.cmd.ProcessState == nil { 95 if err := s.cmd.Process.Kill(); err != nil { 96 s.logf("cleanup", "cannot kill shell process: %v", err) 97 } 98 // We don't log the error returned by Wait because the monitorExit 99 // goroutine will already do so. 100 s.cmd.Wait() 101 } 102 s.ptyReplica.Close() 103 s.ptyMaster.Close() 104 // Wait for monitorExit goroutine to write exit status to the debug log. 105 <-s.cmdFinished 106 // Empty out everything in the readCh, but don't wait too long for it. 107 var extraBytes bytes.Buffer 108 unreadTimeout := time.After(100 * time.Millisecond) 109 unreadLoop: 110 for { 111 select { 112 case r, ok := <-s.readCh: 113 if !ok { 114 break unreadLoop 115 } else if r.err == nil { 116 extraBytes.WriteByte(r.b) 117 } 118 case <-unreadTimeout: 119 break unreadLoop 120 } 121 } 122 if extraBytes.Len() > 0 { 123 s.logIO("unread", extraBytes.Bytes(), nil) 124 } 125 s.logf("cleanup", "Shell cleanup complete.") 126 } 127 128 // logIO logs byte I/O to both standard logging and the test log, if provided. 129 func (s *Shell) logIO(prefix string, b []byte, err error) { 130 var sb strings.Builder 131 if len(b) > 0 { 132 sb.WriteString(fmt.Sprintf("%q", b)) 133 } else { 134 sb.WriteString("(nothing)") 135 } 136 if err != nil { 137 sb.WriteString(fmt.Sprintf(" [error: %v]", err)) 138 } 139 s.logf(prefix, "%s", sb.String()) 140 } 141 142 // logf logs something to both standard logging and the test log, if provided. 143 func (s *Shell) logf(prefix, format string, values ...interface{}) { 144 if s.logger != nil { 145 s.logger.Logf("[%s] %s", prefix, fmt.Sprintf(format, values...)) 146 } 147 } 148 149 // monitorExit waits for the shell process to exit and logs the exit result. 150 func (s *Shell) monitorExit() { 151 if err := s.cmd.Wait(); err != nil { 152 s.logf("cmd", "shell process terminated: %v", err) 153 } else { 154 s.logf("cmd", "shell process terminated successfully") 155 } 156 close(s.cmdFinished) 157 } 158 159 // reader continuously reads the shell output and populates readCh. 160 func (s *Shell) reader(ctx context.Context) { 161 b := make([]byte, 4096) 162 defer close(s.readCh) 163 for { 164 select { 165 case <-s.cmdFinished: 166 // Shell process terminated; stop trying to read. 167 return 168 case <-ctx.Done(): 169 // Shell process will also have terminated in this case; 170 // stop trying to read. 171 // We don't print an error here because doing so would print this in the 172 // normal case where the context passed to NewShell is canceled at the 173 // end of a successful test. 174 return 175 default: 176 // Shell still running, try reading. 177 } 178 if got, err := s.ptyMaster.Read(b); err != nil { 179 s.readCh <- byteOrError{err: err} 180 if err == io.EOF { 181 return 182 } 183 } else { 184 for i := 0; i < got; i++ { 185 s.readCh <- byteOrError{b: b[i]} 186 } 187 } 188 } 189 } 190 191 // readByte reads a single byte, respecting the context. 192 func (s *Shell) readByte(ctx context.Context) (byte, error) { 193 select { 194 case <-ctx.Done(): 195 return 0, ctx.Err() 196 case r := <-s.readCh: 197 return r.b, r.err 198 } 199 } 200 201 // readLoop reads as many bytes as possible until the context expires, b is 202 // full, or a short time passes. It returns how many bytes it has successfully 203 // read. 204 func (s *Shell) readLoop(ctx context.Context, b []byte) (int, error) { 205 soonCtx, soonCancel := context.WithTimeout(ctx, 5*time.Second) 206 defer soonCancel() 207 var i int 208 for i = 0; i < len(b) && soonCtx.Err() == nil; i++ { 209 next, err := s.readByte(soonCtx) 210 if err != nil { 211 if i > 0 { 212 s.logIO("read", b[:i-1], err) 213 } else { 214 s.logIO("read", nil, err) 215 } 216 return i, err 217 } 218 b[i] = next 219 } 220 s.logIO("read", b[:i], soonCtx.Err()) 221 return i, soonCtx.Err() 222 } 223 224 // readLine reads a single line. Strips out all \r characters for convenience. 225 // Upon error, it will still return what it has read so far. 226 // It will also exit quickly if the line content it has read so far (without a 227 // line break) matches `prompt`. 228 func (s *Shell) readLine(ctx context.Context, prompt string) ([]byte, error) { 229 soonCtx, soonCancel := context.WithTimeout(ctx, 5*time.Second) 230 defer soonCancel() 231 var lineData bytes.Buffer 232 var b byte 233 var err error 234 for soonCtx.Err() == nil && b != '\n' { 235 b, err = s.readByte(soonCtx) 236 if err != nil { 237 data := lineData.Bytes() 238 s.logIO("read", data, err) 239 return data, err 240 } 241 if b != '\r' { 242 lineData.WriteByte(b) 243 } 244 if bytes.Equal(lineData.Bytes(), []byte(prompt)) { 245 // Assume that there will not be any further output if we get the prompt. 246 // This avoids waiting for the read deadline just to read the prompt. 247 break 248 } 249 } 250 data := lineData.Bytes() 251 s.logIO("read", data, soonCtx.Err()) 252 return data, soonCtx.Err() 253 } 254 255 // Expect verifies that the next `len(want)` bytes we read match `want`. 256 func (s *Shell) Expect(ctx context.Context, want []byte) error { 257 errPrefix := fmt.Sprintf("want(%q)", want) 258 b := make([]byte, len(want)) 259 got, err := s.readLoop(ctx, b) 260 if err != nil { 261 if ctx.Err() != nil { 262 return fmt.Errorf("%s: context done (%w), got: %q", errPrefix, err, b[:got]) 263 } 264 return fmt.Errorf("%s: %w", errPrefix, err) 265 } 266 if got < len(want) { 267 return fmt.Errorf("%s: short read (read %d bytes, expected %d): %q", errPrefix, got, len(want), b[:got]) 268 } 269 if !bytes.Equal(b, want) { 270 return fmt.Errorf("got %q want %q", b, want) 271 } 272 return nil 273 } 274 275 // ExpectString verifies that the next `len(want)` bytes we read match `want`. 276 func (s *Shell) ExpectString(ctx context.Context, want string) error { 277 return s.Expect(ctx, []byte(want)) 278 } 279 280 // ExpectPrompt verifies that the next few bytes we read are the shell prompt. 281 func (s *Shell) ExpectPrompt(ctx context.Context) error { 282 return s.ExpectString(ctx, Prompt) 283 } 284 285 // ExpectEmptyLine verifies that the next few bytes we read are an empty line, 286 // as defined by any number of carriage or line break characters. 287 func (s *Shell) ExpectEmptyLine(ctx context.Context) error { 288 line, err := s.readLine(ctx, Prompt) 289 if err != nil { 290 return fmt.Errorf("cannot read line: %w", err) 291 } 292 if strings.Trim(string(line), "\r\n") != "" { 293 return fmt.Errorf("line was not empty: %q", line) 294 } 295 return nil 296 } 297 298 // ExpectLine verifies that the next `len(want)` bytes we read match `want`, 299 // followed by carriage returns or newline characters. 300 func (s *Shell) ExpectLine(ctx context.Context, want string) error { 301 if err := s.ExpectString(ctx, want); err != nil { 302 return err 303 } 304 if err := s.ExpectEmptyLine(ctx); err != nil { 305 return fmt.Errorf("ExpectLine(%q): no line break: %w", want, err) 306 } 307 return nil 308 } 309 310 // Write writes `b` to the shell and verifies that all of them get written. 311 func (s *Shell) Write(b []byte) error { 312 written, err := s.ptyMaster.Write(b) 313 s.logIO("write", b[:written], err) 314 if err != nil { 315 return fmt.Errorf("write(%q): %w", b, err) 316 } 317 if written != len(b) { 318 return fmt.Errorf("write(%q): wrote %d of %d bytes (%q)", b, written, len(b), b[:written]) 319 } 320 return nil 321 } 322 323 // WriteLine writes `line` (to which \n will be appended) to the shell. 324 // If the shell is in `echo` mode, it will also check that we got these bytes 325 // back to read. 326 func (s *Shell) WriteLine(ctx context.Context, line string) error { 327 if err := s.Write([]byte(line + "\n")); err != nil { 328 return err 329 } 330 if s.echo { 331 // We expect to see everything we've typed. 332 if err := s.ExpectLine(ctx, line); err != nil { 333 return fmt.Errorf("echo: %w", err) 334 } 335 } 336 return nil 337 } 338 339 // StartCommand is a convenience wrapper for WriteLine that mimics entering a 340 // command line and pressing Enter. It does some basic shell argument escaping. 341 func (s *Shell) StartCommand(ctx context.Context, cmd ...string) error { 342 escaped := make([]string, len(cmd)) 343 for i, arg := range cmd { 344 escaped[i] = shellEscape(arg) 345 } 346 return s.WriteLine(ctx, strings.Join(escaped, " ")) 347 } 348 349 // GetCommandOutput gets all following bytes until the prompt is encountered. 350 // This is useful for matching the output of a command. 351 // All \r are removed for ease of matching. 352 func (s *Shell) GetCommandOutput(ctx context.Context) ([]byte, error) { 353 return s.ReadUntil(ctx, Prompt) 354 } 355 356 // ReadUntil gets all following bytes until a certain line is encountered. 357 // This final line is not returned as part of the output, but everything before 358 // it (including the \n) is included. 359 // This is useful for matching the output of a command. 360 // All \r are removed for ease of matching. 361 func (s *Shell) ReadUntil(ctx context.Context, finalLine string) ([]byte, error) { 362 var output bytes.Buffer 363 for ctx.Err() == nil { 364 line, err := s.readLine(ctx, finalLine) 365 if err != nil { 366 return nil, err 367 } 368 if bytes.Equal(line, []byte(finalLine)) { 369 break 370 } 371 // readLine ensures that `line` either matches `finalLine` or contains \n. 372 // Thus we can be confident that `line` has a \n here. 373 output.Write(line) 374 } 375 return output.Bytes(), ctx.Err() 376 } 377 378 // RunCommand is a convenience wrapper for StartCommand + GetCommandOutput. 379 func (s *Shell) RunCommand(ctx context.Context, cmd ...string) ([]byte, error) { 380 if err := s.StartCommand(ctx, cmd...); err != nil { 381 return nil, err 382 } 383 return s.GetCommandOutput(ctx) 384 } 385 386 // RefreshSTTY interprets output from `stty -a` to check whether we are in echo 387 // mode and other settings. 388 // It will assume that any line matching `expectPrompt` means the end of 389 // the `stty -a` output. 390 // Why do this rather than using `tcgets`? Because this function can be used in 391 // conjunction with sub-shell processes that can allocate their own TTYs. 392 func (s *Shell) RefreshSTTY(ctx context.Context, expectPrompt string) error { 393 // Temporarily assume we will not get any output. 394 // If echo is actually on, we'll get the "stty -a" line as if it was command 395 // output. This is OK because we parse the output generously. 396 s.echo = false 397 if err := s.WriteLine(ctx, "stty -a"); err != nil { 398 return fmt.Errorf("could not run `stty -a`: %w", err) 399 } 400 sttyOutput, err := s.ReadUntil(ctx, expectPrompt) 401 if err != nil { 402 return fmt.Errorf("cannot get `stty -a` output: %w", err) 403 } 404 405 // Set default control characters in case we can't see them in the output. 406 s.controlCharIntr = "^C" 407 s.controlCharEOF = "^D" 408 // stty output has two general notations: 409 // `a = b;` (for control characters), and `option` vs `-option` (for boolean 410 // options). We parse both kinds here. 411 // For `a = b;`, `controlChar` contains `a`, and `previousToken` is used to 412 // set `controlChar` to `previousToken` when we see an "=" token. 413 var previousToken, controlChar string 414 for _, token := range strings.Fields(string(sttyOutput)) { 415 if controlChar != "" { 416 value := strings.TrimSuffix(token, ";") 417 switch controlChar { 418 case "intr": 419 s.controlCharIntr = value 420 case "eof": 421 s.controlCharEOF = value 422 } 423 controlChar = "" 424 } else { 425 switch token { 426 case "=": 427 controlChar = previousToken 428 case "-echo": 429 s.echo = false 430 case "echo": 431 s.echo = true 432 } 433 } 434 previousToken = token 435 } 436 s.logf("stty", "refreshed settings: echo=%v, intr=%q, eof=%q", s.echo, s.controlCharIntr, s.controlCharEOF) 437 return nil 438 } 439 440 // sendControlCode sends `code` to the shell and expects to see `repr`. 441 // If `expectLinebreak` is true, it also expects to see a linebreak. 442 func (s *Shell) sendControlCode(ctx context.Context, code byte, repr string, expectLinebreak bool) error { 443 if err := s.Write([]byte{code}); err != nil { 444 return fmt.Errorf("cannot send %q: %w", code, err) 445 } 446 if err := s.ExpectString(ctx, repr); err != nil { 447 return fmt.Errorf("did not see %s: %w", repr, err) 448 } 449 if expectLinebreak { 450 if err := s.ExpectEmptyLine(ctx); err != nil { 451 return fmt.Errorf("linebreak after %s: %v", repr, err) 452 } 453 } 454 return nil 455 } 456 457 // SendInterrupt sends the \x03 (Ctrl+C) control character to the shell. 458 func (s *Shell) SendInterrupt(ctx context.Context, expectLinebreak bool) error { 459 return s.sendControlCode(ctx, 0x03, s.controlCharIntr, expectLinebreak) 460 } 461 462 // SendEOF sends the \x04 (Ctrl+D) control character to the shell. 463 func (s *Shell) SendEOF(ctx context.Context, expectLinebreak bool) error { 464 return s.sendControlCode(ctx, 0x04, s.controlCharEOF, expectLinebreak) 465 } 466 467 // NewShell returns a new managed sh process along with a cleanup function. 468 // The caller is expected to call this function once it no longer needs the 469 // shell. 470 // The optional passed-in logger will be used for logging. 471 func NewShell(ctx context.Context, logger Logger) (*Shell, func(), error) { 472 ptyMaster, ptyReplica, err := pty.Open() 473 if err != nil { 474 return nil, nil, fmt.Errorf("cannot create PTY: %w", err) 475 } 476 cmd := exec.CommandContext(ctx, "/bin/sh", "--noprofile", "--norc", "-i") 477 cmd.Stdin = ptyReplica 478 cmd.Stdout = ptyReplica 479 cmd.Stderr = ptyReplica 480 cmd.SysProcAttr = &unix.SysProcAttr{ 481 Setsid: true, 482 Setctty: true, 483 Ctty: 0, 484 } 485 cmd.Env = append(cmd.Env, fmt.Sprintf("PS1=%s", Prompt)) 486 if err := cmd.Start(); err != nil { 487 return nil, nil, fmt.Errorf("cannot start shell: %w", err) 488 } 489 s := &Shell{ 490 cmd: cmd, 491 cmdFinished: make(chan struct{}), 492 ptyMaster: ptyMaster, 493 ptyReplica: ptyReplica, 494 readCh: make(chan byteOrError, 1<<20), 495 logger: logger, 496 } 497 s.logf("creation", "Shell spawned.") 498 go s.monitorExit() 499 go s.reader(ctx) 500 setupCtx, setupCancel := context.WithTimeout(ctx, 5*time.Second) 501 defer setupCancel() 502 // We expect to see the prompt immediately on startup, 503 // since the shell is started in interactive mode. 504 if err := s.ExpectPrompt(setupCtx); err != nil { 505 s.cleanup() 506 return nil, nil, fmt.Errorf("did not get initial prompt: %w", err) 507 } 508 s.logf("creation", "Initial prompt observed.") 509 // Get initial TTY settings. 510 if err := s.RefreshSTTY(setupCtx, Prompt); err != nil { 511 s.cleanup() 512 return nil, nil, fmt.Errorf("cannot get initial STTY settings: %w", err) 513 } 514 return s, s.cleanup, nil 515 }