
     1  package proxyprocess
     3  import (
     4  	"fmt"
     5  	"io/ioutil"
     6  	"log"
     7  	"os"
     8  	"os/exec"
     9  	"os/signal"
    10  	"sort"
    11  	"strconv"
    12  	"strings"
    13  	"syscall"
    14  	"testing"
    15  	"time"
    16  )
    18  // testLogger is a logger that can be used by tests that require a
    19  // *log.Logger instance.
    20  var testLogger = log.New(os.Stderr, "logger: ", log.LstdFlags)
    22  // testTempDir returns a temporary directory and a cleanup function.
    23  func testTempDir(t *testing.T) (string, func()) {
    24  	t.Helper()
    26  	td, err := ioutil.TempDir("", "test-agent-proxy")
    27  	if err != nil {
    28  		t.Fatalf("err: %s", err)
    29  	}
    31  	return td, func() {
    32  		if err := os.RemoveAll(td); err != nil {
    33  			t.Fatalf("err: %s", err)
    34  		}
    35  	}
    36  }
// helperProcessSentinel is a sentinel value that helperProcess places as the
// first argument following "--" on the child's command line.
// TestHelperProcess returns immediately unless it finds this value there, so
// an ordinary test run does not execute any of the helper commands.
const helperProcessSentinel = "WANT_HELPER_PROCESS"
    43  // helperProcess returns an *exec.Cmd that can be used to execute the
    44  // TestHelperProcess function below. This can be used to test multi-process
    45  // interactions.
    46  func helperProcess(s ...string) (*exec.Cmd, func()) {
    47  	cs := []string{"", "--", helperProcessSentinel}
    48  	cs = append(cs, s...)
    50  	cmd := exec.Command(os.Args[0], cs...)
    51  	cmd.Stdout = os.Stdout
    52  	cmd.Stderr = os.Stderr
    53  	destroy := func() {
    54  		if p := cmd.Process; p != nil {
    55  			p.Kill()
    56  		}
    57  	}
    58  	return cmd, destroy
    59  }
    61  // This is not a real test. This is just a helper process kicked off by tests
    62  // using the helperProcess helper function.
    63  func TestHelperProcess(t *testing.T) {
    64  	args := os.Args
    65  	for len(args) > 0 {
    66  		if args[0] == "--" {
    67  			args = args[1:]
    68  			break
    69  		}
    71  		args = args[1:]
    72  	}
    74  	if len(args) == 0 || args[0] != helperProcessSentinel {
    75  		return
    76  	}
    78  	defer os.Exit(0)
    79  	args = args[1:] // strip sentinel value
    80  	cmd, args := args[0], args[1:]
    81  	switch cmd {
    82  	// While running, this creates a file in the given directory (args[0])
    83  	// and deletes it only when it is stopped.
    84  	case "start-stop":
    85  		limitProcessLifetime(2 * time.Minute)
    87  		ch := make(chan os.Signal, 1)
    88  		signal.Notify(ch, os.Interrupt, syscall.SIGTERM)
    89  		defer signal.Stop(ch)
    91  		path := args[0]
    92  		var data []byte
    93  		data = append(data, []byte(os.Getenv(EnvProxyID))...)
    94  		data = append(data, ':')
    95  		data = append(data, []byte(os.Getenv(EnvProxyToken))...)
    97  		if err := ioutil.WriteFile(path, data, 0644); err != nil {
    98  			t.Fatalf("err: %s", err)
    99  		}
   100  		defer os.Remove(path)
   102  		<-ch
   104  	// Restart writes to a file and keeps running while that file still
   105  	// exists. When that file is removed, this process exits. This can be
   106  	// used to test restarting.
   107  	case "restart":
   108  		limitProcessLifetime(2 * time.Minute)
   110  		ch := make(chan os.Signal, 1)
   111  		signal.Notify(ch, os.Interrupt)
   112  		defer signal.Stop(ch)
   114  		// Write the file
   115  		path := args[0]
   116  		if err := ioutil.WriteFile(path, []byte("hello"), 0644); err != nil {
   117  			fmt.Fprintf(os.Stderr, "Error: %s\n", err)
   118  			os.Exit(1)
   119  		}
   121  		// While the file still exists, do nothing. When the file no longer
   122  		// exists, we exit.
   123  		for {
   124  			time.Sleep(25 * time.Millisecond)
   125  			if _, err := os.Stat(path); os.IsNotExist(err) {
   126  				break
   127  			}
   129  			select {
   130  			case <-ch:
   131  				// We received an interrupt, clean exit
   132  				os.Remove(path)
   133  				break
   135  			default:
   136  			}
   137  		}
   138  	case "stop-kill":
   139  		limitProcessLifetime(2 * time.Minute)
   141  		// Setup listeners so it is ignored
   142  		ch := make(chan os.Signal, 1)
   143  		signal.Notify(ch, os.Interrupt)
   144  		defer signal.Stop(ch)
   146  		path := args[0]
   147  		data := []byte(os.Getenv(EnvProxyToken))
   148  		for {
   149  			if err := ioutil.WriteFile(path, data, 0644); err != nil {
   150  				t.Fatalf("err: %s", err)
   151  			}
   152  			time.Sleep(25 * time.Millisecond)
   153  		}
   154  		// Check if the external process can access the enivironmental variables
   155  	case "environ":
   156  		limitProcessLifetime(2 * time.Minute)
   158  		stop := make(chan os.Signal, 1)
   159  		signal.Notify(stop, os.Interrupt)
   160  		defer signal.Stop(stop)
   162  		//Get the path for the file to be written to
   163  		path := args[0]
   164  		var data []byte
   166  		//Get the environmental variables
   167  		envData := os.Environ()
   169  		//Sort the env data for easier comparison
   170  		sort.Strings(envData)
   171  		for _, envVariable := range envData {
   172  			if strings.HasPrefix(envVariable, "CONSUL") || strings.HasPrefix(envVariable, "CONNECT") {
   173  				continue
   174  			}
   175  			data = append(data, envVariable...)
   176  			data = append(data, "\n"...)
   177  		}
   178  		if err := ioutil.WriteFile(path, data, 0644); err != nil {
   179  			t.Fatalf("[Error] File write failed : %s", err)
   180  		}
   182  		// Clean up after we receive the signal to exit
   183  		defer os.Remove(path)
   185  		<-stop
   187  	case "output":
   188  		limitProcessLifetime(2 * time.Minute)
   190  		fmt.Fprintf(os.Stdout, "hello stdout\n")
   191  		fmt.Fprintf(os.Stderr, "hello stderr\n")
   193  		// Sync to be sure it is written out of buffers
   194  		os.Stdout.Sync()
   195  		os.Stderr.Sync()
   197  		// Output a file to signal we've written to stdout/err
   198  		path := args[0]
   199  		if err := ioutil.WriteFile(path, []byte("hello"), 0644); err != nil {
   200  			fmt.Fprintf(os.Stderr, "Error: %s\n", err)
   201  			os.Exit(1)
   202  		}
   204  		<-make(chan struct{})
   206  	// Parent runs the given process in a Daemon and then sleeps until the test
   207  	// code kills it. It exists to test that the Daemon-managed child process
   208  	// survives it's parent exiting which we can't test directly without exiting
   209  	// the test process so we need an extra level of indirection. The test code
   210  	// using this must pass a file path as the first argument for the child
   211  	// processes PID to be written and then must take care to clean up that PID
   212  	// later or the child will be left running forever.
   213  	//
   214  	// If the PID file already exists, it will "adopt" the child rather than
   215  	// launch a new one.
   216  	case "parent":
   217  		limitProcessLifetime(2 * time.Minute)
   219  		// We will write the PID for the child to the file in the first argument
   220  		// then pass rest of args through to command.
   221  		pidFile := args[0]
   223  		cmd, destroyChild := helperProcess(args[1:]...)
   224  		defer destroyChild()
   226  		d := &Daemon{
   227  			Command: cmd,
   228  			Logger:  testLogger,
   229  			PidPath: pidFile,
   230  		}
   232  		_, err := os.Stat(pidFile)
   233  		if err == nil {
   234  			// pidFile exists, read it and "adopt" the process
   235  			bs, err := ioutil.ReadFile(pidFile)
   236  			if err != nil {
   237  				log.Printf("Error: %s", err)
   238  				os.Exit(1)
   239  			}
   240  			pid, err := strconv.Atoi(string(bs))
   241  			if err != nil {
   242  				log.Printf("Error: %s", err)
   243  				os.Exit(1)
   244  			}
   245  			// Make a fake snapshot to load
   246  			snapshot := map[string]interface{}{
   247  				"Pid":         pid,
   248  				"CommandPath": d.Command.Path,
   249  				"CommandArgs": d.Command.Args,
   250  				"CommandDir":  d.Command.Dir,
   251  				"CommandEnv":  d.Command.Env,
   252  				"ProxyToken":  "",
   253  			}
   254  			d.UnmarshalSnapshot(snapshot)
   255  		}
   257  		if err := d.Start(); err != nil {
   258  			log.Printf("Error: %s", err)
   259  			os.Exit(1)
   260  		}
   261  		log.Println("Started child")
   263  		// Wait "forever" (calling test chooses when we exit with signal/Wait to
   264  		// minimize coordination).
   265  		for {
   266  			time.Sleep(time.Hour)
   267  		}
   269  	default:
   270  		fmt.Fprintf(os.Stderr, "Unknown command: %q\n", cmd)
   271  		os.Exit(2)
   272  	}
   273  }
   275  // limitProcessLifetime installs a background goroutine that self-exits after
   276  // the specified duration elapses to prevent leaking processes from tests that
   277  // may spawn them.
   278  func limitProcessLifetime(dur time.Duration) {
   279  	go time.AfterFunc(dur, func() {
   280  		os.Exit(99)
   281  	})
   282  }