github.com/pulumi/pulumi/sdk/v3@v3.108.1/go/common/util/cmdutil/term_test.go (about)

     1  // Copyright 2016-2023, Pulumi Corporation.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package cmdutil
    16  
    17  import (
    18  	"bytes"
    19  	"context"
    20  	"errors"
    21  	"fmt"
    22  	"io"
    23  	"os/exec"
    24  	"path/filepath"
    25  	"runtime"
    26  	"sync"
    27  	"testing"
    28  	"time"
    29  
    30  	ps "github.com/mitchellh/go-ps"
    31  	"github.com/pulumi/pulumi/sdk/v3/go/common/testing/iotest"
    32  	"github.com/stretchr/testify/assert"
    33  	"github.com/stretchr/testify/require"
    34  )
    35  
    36  func TestTerminate_gracefulShutdown(t *testing.T) {
    37  	t.Parallel()
    38  
    39  	// This test runs commands in a child process, signals them,
    40  	// and expects them to shutdown gracefully.
    41  	//
    42  	// The contract for the child process is as follows:
    43  	//
    44  	//   - It MUST print something to stdout when it is ready to receive signals.
    45  	//   - It MUST exit with a zero code if it receives a SIGINT.
    46  	//   - It MUST exit with a non-zero code if the signal wasn't received within 3 seconds.
    47  	//   - It MAY print diagnostic messages to stderr.
    48  
    49  	tests := []struct {
    50  		desc string
    51  		prog testProgram
    52  	}{
    53  		{desc: "go", prog: goTestProgram.From("graceful.go")},
    54  		{desc: "node", prog: nodeTestProgram.From("graceful.js")},
    55  		{desc: "python", prog: pythonTestProgram.From("graceful.py")},
    56  		{desc: "with child", prog: goTestProgram.From("graceful_with_child.go")},
    57  	}
    58  
    59  	for _, tt := range tests {
    60  		tt := tt
    61  		t.Run(tt.desc, func(t *testing.T) {
    62  			t.Parallel()
    63  
    64  			cmd := tt.prog.Build(t)
    65  
    66  			var stdout lockedBuffer
    67  			cmd.Stdout = io.MultiWriter(&stdout, iotest.LogWriterPrefixed(t, "stdout: "))
    68  			cmd.Stderr = iotest.LogWriterPrefixed(t, "stderr: ")
    69  			require.NoError(t, cmd.Start(), "error starting child process")
    70  
    71  			done := make(chan struct{})
    72  			go func() {
    73  				defer close(done)
    74  
    75  				// Wait until the child process is ready to receive signals.
    76  				for stdout.Len() == 0 {
    77  					time.Sleep(10 * time.Millisecond)
    78  				}
    79  
    80  				ok, err := TerminateProcessGroup(cmd.Process, 1*time.Second)
    81  				assert.True(t, ok, "child process did not exit gracefully")
    82  				assert.NoError(t, err, "error terminating child process")
    83  			}()
    84  
    85  			err := cmd.Wait()
    86  			if isWaitAlreadyExited(err) {
    87  				err = nil
    88  			}
    89  			assert.NoError(t, err, "child did not exit cleanly")
    90  
    91  			<-done
    92  		})
    93  	}
    94  }
    95  
    96  func TestTerminate_gracefulShutdown_exitError(t *testing.T) {
    97  	t.Parallel()
    98  
    99  	// This test runs commands in a child process, signals them,
   100  	// and expects them to shutdown gracefully
   101  	// but with a non-zero exit code.
   102  
   103  	cmd := goTestProgram.From("graceful.go").Args("-exit-code", "1").Build(t)
   104  
   105  	var stdout lockedBuffer
   106  	cmd.Stdout = io.MultiWriter(&stdout, iotest.LogWriterPrefixed(t, "stdout: "))
   107  	cmd.Stderr = iotest.LogWriterPrefixed(t, "stderr: ")
   108  	require.NoError(t, cmd.Start(), "error starting child process")
   109  
   110  	// Wait until the child process is ready to receive signals.
   111  	for stdout.Len() == 0 {
   112  		time.Sleep(10 * time.Millisecond)
   113  	}
   114  
   115  	ok, err := TerminateProcessGroup(cmd.Process, 1*time.Second)
   116  	assert.True(t, ok, "child process did not exit gracefully")
   117  	require.Error(t, err, "child process must exit with non-zero code")
   118  
   119  	var exitErr *exec.ExitError
   120  	if assert.ErrorAs(t, err, &exitErr, "expected ExitError from child process") {
   121  		assert.Equal(t, 1, exitErr.ExitCode(), "unexpected exit code from child process")
   122  	}
   123  }
   124  
   125  func TestTerminate_forceKill(t *testing.T) {
   126  	t.Parallel()
   127  
   128  	// This test runs commands in a child process, signals them,
   129  	// and expects them to not exit in a timely manner.
   130  	//
   131  	// The contract for the child process is the same as gracefulShutdown,
   132  	// except:
   133  	//
   134  	//   - It MUST freeze for at least 1 second after it receives a SIGINT.
   135  	//   - It MAY exit with a non-zero code if it receives a SIGINT.
   136  
   137  	tests := []struct {
   138  		desc string
   139  		prog testProgram
   140  	}{
   141  		{desc: "go", prog: goTestProgram.From("frozen.go")},
   142  		{desc: "node", prog: nodeTestProgram.From("frozen.js")},
   143  		{desc: "python", prog: pythonTestProgram.From("frozen.py")},
   144  	}
   145  
   146  	for _, tt := range tests {
   147  		tt := tt
   148  		t.Run(tt.desc, func(t *testing.T) {
   149  			t.Parallel()
   150  
   151  			cmd := tt.prog.Build(t)
   152  
   153  			var stdout lockedBuffer
   154  			cmd.Stdout = io.MultiWriter(&stdout, iotest.LogWriterPrefixed(t, "stdout: "))
   155  			cmd.Stderr = iotest.LogWriterPrefixed(t, "stderr: ")
   156  			require.NoError(t, cmd.Start(), "error starting child process")
   157  
   158  			// Wait until the child process is ready to receive signals.
   159  			for stdout.Len() == 0 {
   160  				time.Sleep(10 * time.Millisecond)
   161  			}
   162  
   163  			pid := cmd.Process.Pid
   164  			done := make(chan struct{})
   165  			go func() {
   166  				defer close(done)
   167  
   168  				ok, err := TerminateProcessGroup(cmd.Process, 50*time.Millisecond)
   169  				assert.False(t, ok, "child process should not exit gracefully")
   170  				assert.NoError(t, err, "error terminating child process")
   171  			}()
   172  
   173  			select {
   174  			case <-done:
   175  				// continue
   176  
   177  			case <-time.After(200 * time.Millisecond):
   178  				// If the process is not killed,
   179  				// cmd.Wait() will block until it exits.
   180  				t.Fatal("Took too long to kill child process")
   181  			}
   182  
   183  			assert.NoError(t,
   184  				waitPidDead(pid, 100*time.Millisecond),
   185  				"error waiting for process to die")
   186  		})
   187  	}
   188  }
   189  
   190  func TestTerminate_forceKill_processGroup(t *testing.T) {
   191  	t.Parallel()
   192  
   193  	// This is a variant of TestTerminate_forceKill
   194  	// that verifies that a child process of the test process
   195  	// is also killed.
   196  
   197  	cmd := goTestProgram.From("frozen_with_child.go").Build(t)
   198  
   199  	var stdout lockedBuffer
   200  	cmd.Stdout = io.MultiWriter(&stdout, iotest.LogWriterPrefixed(t, "stdout: "))
   201  	cmd.Stderr = iotest.LogWriterPrefixed(t, "stderr: ")
   202  	require.NoError(t, cmd.Start(), "error starting child process")
   203  
   204  	// Wait until the child process is ready to receive signals.
   205  	for stdout.Len() == 0 {
   206  		time.Sleep(10 * time.Millisecond)
   207  	}
   208  
   209  	pid := cmd.Process.Pid
   210  	childPid := -1
   211  
   212  	procs, err := ps.Processes()
   213  	require.NoError(t, err, "error listing processes")
   214  	for _, proc := range procs {
   215  		if proc.PPid() == pid {
   216  			childPid = proc.Pid()
   217  			break
   218  		}
   219  	}
   220  	require.NotEqual(t, -1, childPid, "child process not found")
   221  
   222  	done := make(chan struct{})
   223  	go func() {
   224  		defer close(done)
   225  
   226  		ok, err := TerminateProcessGroup(cmd.Process, time.Millisecond)
   227  		assert.False(t, ok, "child process should not exit gracefully")
   228  		assert.NoError(t, err, "error terminating child process")
   229  	}()
   230  
   231  	select {
   232  	case <-done:
   233  		// continue
   234  
   235  	case <-time.After(100 * time.Millisecond):
   236  		// If the child process is not killed,
   237  		// cmd.Wait() will block until it exits.
   238  		t.Fatal("Took too long to kill child process")
   239  	}
   240  
   241  	for _, pid := range []int{pid, childPid} {
   242  		assert.NoError(t,
   243  			waitPidDead(pid, 100*time.Millisecond),
   244  			"error waiting for process to die")
   245  	}
   246  }
   247  
   248  func TestTerminate_unhandledInterrupt(t *testing.T) {
   249  	t.Parallel()
   250  
   251  	// This test runs programs that do not have an interrupt handler.
   252  	// Contract for child process:
   253  	//
   254  	// - It MUST print to stdout when it's ready.
   255  	// - It MUST exit with a non-zero code if it does not get terminated within 3 seconds.
   256  
   257  	tests := []struct {
   258  		desc string
   259  		prog testProgram
   260  	}{
   261  		{desc: "go", prog: goTestProgram.From("unhandled.go")},
   262  		{desc: "node", prog: nodeTestProgram.From("unhandled.js")},
   263  		{desc: "python", prog: pythonTestProgram.From("unhandled.py")},
   264  	}
   265  
   266  	for _, tt := range tests {
   267  		tt := tt
   268  		t.Run(tt.desc, func(t *testing.T) {
   269  			t.Parallel()
   270  
   271  			cmd := tt.prog.Build(t)
   272  
   273  			var stdout lockedBuffer
   274  			cmd.Stdout = io.MultiWriter(&stdout, iotest.LogWriterPrefixed(t, "stdout: "))
   275  			cmd.Stderr = iotest.LogWriterPrefixed(t, "stderr: ")
   276  			require.NoError(t, cmd.Start(), "error starting child process")
   277  
   278  			// Wait until the child process is ready to receive signals.
   279  			for stdout.Len() == 0 {
   280  				time.Sleep(10 * time.Millisecond)
   281  			}
   282  
   283  			pid := cmd.Process.Pid
   284  			done := make(chan struct{})
   285  			go func() {
   286  				defer close(done)
   287  
   288  				ok, err := TerminateProcessGroup(cmd.Process, 200*time.Millisecond)
   289  				assert.True(t, ok, "child process did not exit gracefully")
   290  				assert.Error(t, err, "child process should have exited with an error")
   291  			}()
   292  
   293  			select {
   294  			case <-done:
   295  				// continue
   296  
   297  			case <-time.After(200 * time.Millisecond):
   298  				// Took too long to kill the child process.
   299  				t.Fatal("Took too long to kill child process")
   300  			}
   301  
   302  			assert.NoError(t,
   303  				waitPidDead(pid, 100*time.Millisecond),
   304  				"error waiting for process to die")
   305  		})
   306  	}
   307  }
   308  
   309  type testProgramKind int
   310  
   311  const (
   312  	goTestProgram testProgramKind = iota
   313  	nodeTestProgram
   314  	pythonTestProgram
   315  )
   316  
   317  func (k testProgramKind) String() string {
   318  	switch k {
   319  	case goTestProgram:
   320  		return "go"
   321  	case nodeTestProgram:
   322  		return "node"
   323  	case pythonTestProgram:
   324  		return "python"
   325  	default:
   326  		return fmt.Sprintf("testProgramKind(%d)", int(k))
   327  	}
   328  }
   329  
   330  // From builds a testProgram of this kind
   331  // with the given source file.
   332  //
   333  // Usage:
   334  //
   335  //	goTestProgram.From("main.go")
   336  func (k testProgramKind) From(path string) testProgram {
   337  	return testProgram{
   338  		kind: k,
   339  		src:  path,
   340  	}
   341  }
   342  
   343  // testProgram is a test program inside the testdata directory.
   344  type testProgram struct {
   345  	// kind is the kind of test program.
   346  	kind testProgramKind
   347  
   348  	// src is the path to the source file
   349  	// relative to the testdata directory.
   350  	src string
   351  
   352  	// args specifies additional arguments to pass to the program.
   353  	args []string
   354  }
   355  
   356  func (p testProgram) Args(args ...string) testProgram {
   357  	p.args = args
   358  	return p
   359  }
   360  
   361  // Build builds an exec.Cmd for the test program.
   362  // It skips the test if the program runner is not found.
   363  func (p testProgram) Build(t *testing.T) (cmd *exec.Cmd) {
   364  	t.Helper()
   365  
   366  	defer func() {
   367  		// Make sure that the returned command
   368  		// is part of the process group.
   369  		if cmd != nil {
   370  			RegisterProcessGroup(cmd)
   371  		}
   372  	}()
   373  
   374  	src := filepath.Join("testdata", p.src)
   375  	switch p.kind {
   376  	case goTestProgram:
   377  		goBin := lookPathOrSkip(t, "go")
   378  		bin := filepath.Join(t.TempDir(), "main")
   379  		if runtime.GOOS == "windows" {
   380  			bin += ".exe"
   381  		}
   382  
   383  		buildCmd := exec.Command(goBin, "build", "-o", bin, src)
   384  		buildOutput := iotest.LogWriterPrefixed(t, "build: ")
   385  		buildCmd.Stdout = buildOutput
   386  		buildCmd.Stderr = buildOutput
   387  		require.NoError(t, buildCmd.Run(), "error building test program")
   388  
   389  		return exec.Command(bin, p.args...)
   390  
   391  	case nodeTestProgram:
   392  		nodeBin := lookPathOrSkip(t, "node")
   393  		return exec.Command(nodeBin, append([]string{src}, p.args...)...)
   394  
   395  	case pythonTestProgram:
   396  		pythonBin := lookPathOrSkip(t, "python")
   397  		return exec.Command(pythonBin, append([]string{src}, p.args...)...)
   398  
   399  	default:
   400  		t.Fatalf("unknown test program kind: %v", p.kind)
   401  		return nil
   402  	}
   403  }
   404  
   405  func lookPathOrSkip(t *testing.T, name string) string {
   406  	path, err := exec.LookPath(name)
   407  	if err != nil {
   408  		t.Skipf("Skipping test: %q not found: %v", name, err)
   409  	}
   410  	return path
   411  }
   412  
   413  // lockedBuffer is a thread-safe bytes.Buffer
   414  // that can be used to capture stdout/stderr of a command.
   415  type lockedBuffer struct {
   416  	mu sync.RWMutex
   417  	b  bytes.Buffer
   418  }
   419  
   420  func (b *lockedBuffer) Write(p []byte) (int, error) {
   421  	b.mu.Lock()
   422  	defer b.mu.Unlock()
   423  	return b.b.Write(p)
   424  }
   425  
   426  func (b *lockedBuffer) Len() int {
   427  	b.mu.RLock()
   428  	defer b.mu.RUnlock()
   429  	return b.b.Len()
   430  }
   431  
   432  // Waits until the process with the given pid doesn't exist anymore
   433  // or the given timeout has elapsed.
   434  //
   435  // Returns an error if the timeout has elapsed.
   436  func waitPidDead(pid int, timeout time.Duration) error {
   437  	ctx, cancel := context.WithTimeout(context.Background(), timeout)
   438  	defer cancel()
   439  
   440  	var (
   441  		proc ps.Process
   442  		err  error
   443  	)
   444  	for {
   445  		select {
   446  		case <-ctx.Done():
   447  			var errs []error
   448  			if proc != nil {
   449  				errs = append(errs, fmt.Errorf("process %d still exists: %v", pid, proc))
   450  			}
   451  			if err != nil {
   452  				errs = append(errs, fmt.Errorf("find process: %w", err))
   453  			}
   454  
   455  			return fmt.Errorf("waitPidDead %v: %w", pid, errors.Join(errs...))
   456  
   457  		default:
   458  			proc, err = ps.FindProcess(pid)
   459  			if err == nil && proc == nil {
   460  				return nil
   461  			}
   462  			time.Sleep(10 * time.Millisecond)
   463  		}
   464  	}
   465  }