github.com/bigcommerce/nomad@v0.9.3-bc/client/allocrunner/taskrunner/task_runner_test.go (about)

     1  package taskrunner
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io/ioutil"
     7  	"net/http"
     8  	"net/http/httptest"
     9  	"os"
    10  	"path/filepath"
    11  	"strings"
    12  	"testing"
    13  	"time"
    14  
    15  	"github.com/golang/snappy"
    16  	"github.com/hashicorp/nomad/client/allocdir"
    17  	"github.com/hashicorp/nomad/client/allocrunner/interfaces"
    18  	"github.com/hashicorp/nomad/client/config"
    19  	"github.com/hashicorp/nomad/client/consul"
    20  	consulapi "github.com/hashicorp/nomad/client/consul"
    21  	"github.com/hashicorp/nomad/client/devicemanager"
    22  	"github.com/hashicorp/nomad/client/pluginmanager/drivermanager"
    23  	cstate "github.com/hashicorp/nomad/client/state"
    24  	ctestutil "github.com/hashicorp/nomad/client/testutil"
    25  	"github.com/hashicorp/nomad/client/vaultclient"
    26  	agentconsul "github.com/hashicorp/nomad/command/agent/consul"
    27  	mockdriver "github.com/hashicorp/nomad/drivers/mock"
    28  	"github.com/hashicorp/nomad/drivers/rawexec"
    29  	"github.com/hashicorp/nomad/helper/testlog"
    30  	"github.com/hashicorp/nomad/nomad/mock"
    31  	"github.com/hashicorp/nomad/nomad/structs"
    32  	"github.com/hashicorp/nomad/plugins/device"
    33  	"github.com/hashicorp/nomad/plugins/drivers"
    34  	"github.com/hashicorp/nomad/testutil"
    35  	"github.com/kr/pretty"
    36  	"github.com/stretchr/testify/assert"
    37  	"github.com/stretchr/testify/require"
    38  )
    39  
    40  type MockTaskStateUpdater struct {
    41  	ch chan struct{}
    42  }
    43  
    44  func NewMockTaskStateUpdater() *MockTaskStateUpdater {
    45  	return &MockTaskStateUpdater{
    46  		ch: make(chan struct{}, 1),
    47  	}
    48  }
    49  
    50  func (m *MockTaskStateUpdater) TaskStateUpdated() {
    51  	select {
    52  	case m.ch <- struct{}{}:
    53  	default:
    54  	}
    55  }
    56  
    57  // testTaskRunnerConfig returns a taskrunner.Config for the given alloc+task
    58  // plus a cleanup func.
    59  func testTaskRunnerConfig(t *testing.T, alloc *structs.Allocation, taskName string) (*Config, func()) {
    60  	logger := testlog.HCLogger(t)
    61  	clientConf, cleanup := config.TestClientConfig(t)
    62  
    63  	// Find the task
    64  	var thisTask *structs.Task
    65  	for _, tg := range alloc.Job.TaskGroups {
    66  		for _, task := range tg.Tasks {
    67  			if task.Name == taskName {
    68  				if thisTask != nil {
    69  					cleanup()
    70  					t.Fatalf("multiple tasks named %q; cannot use this helper", taskName)
    71  				}
    72  				thisTask = task
    73  			}
    74  		}
    75  	}
    76  	if thisTask == nil {
    77  		cleanup()
    78  		t.Fatalf("could not find task %q", taskName)
    79  	}
    80  
    81  	// Create the alloc dir + task dir
    82  	allocPath := filepath.Join(clientConf.AllocDir, alloc.ID)
    83  	allocDir := allocdir.NewAllocDir(logger, allocPath)
    84  	if err := allocDir.Build(); err != nil {
    85  		cleanup()
    86  		t.Fatalf("error building alloc dir: %v", err)
    87  	}
    88  	taskDir := allocDir.NewTaskDir(taskName)
    89  
    90  	trCleanup := func() {
    91  		if err := allocDir.Destroy(); err != nil {
    92  			t.Logf("error destroying alloc dir: %v", err)
    93  		}
    94  		cleanup()
    95  	}
    96  
    97  	conf := &Config{
    98  		Alloc:              alloc,
    99  		ClientConfig:       clientConf,
   100  		Consul:             consulapi.NewMockConsulServiceClient(t, logger),
   101  		Task:               thisTask,
   102  		TaskDir:            taskDir,
   103  		Logger:             clientConf.Logger,
   104  		Vault:              vaultclient.NewMockVaultClient(),
   105  		StateDB:            cstate.NoopDB{},
   106  		StateUpdater:       NewMockTaskStateUpdater(),
   107  		DeviceManager:      devicemanager.NoopMockManager(),
   108  		DriverManager:      drivermanager.TestDriverManager(t),
   109  		ServersContactedCh: make(chan struct{}),
   110  	}
   111  	return conf, trCleanup
   112  }
   113  
   114  // runTestTaskRunner runs a TaskRunner and returns its configuration as well as
   115  // a cleanup function that ensures the runner is stopped and cleaned up. Tests
   116  // which need to change the Config *must* use testTaskRunnerConfig instead.
   117  func runTestTaskRunner(t *testing.T, alloc *structs.Allocation, taskName string) (*TaskRunner, *Config, func()) {
   118  	config, cleanup := testTaskRunnerConfig(t, alloc, taskName)
   119  
   120  	tr, err := NewTaskRunner(config)
   121  	require.NoError(t, err)
   122  	go tr.Run()
   123  
   124  	return tr, config, func() {
   125  		tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
   126  		cleanup()
   127  	}
   128  }
   129  
   130  // TestTaskRunner_Restore_Running asserts restoring a running task does not
   131  // rerun the task.
   132  func TestTaskRunner_Restore_Running(t *testing.T) {
   133  	t.Parallel()
   134  	require := require.New(t)
   135  
   136  	alloc := mock.BatchAlloc()
   137  	alloc.Job.TaskGroups[0].Count = 1
   138  	task := alloc.Job.TaskGroups[0].Tasks[0]
   139  	task.Driver = "mock_driver"
   140  	task.Config = map[string]interface{}{
   141  		"run_for": "2s",
   142  	}
   143  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
   144  	conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between task runners
   145  	defer cleanup()
   146  
   147  	// Run the first TaskRunner
   148  	origTR, err := NewTaskRunner(conf)
   149  	require.NoError(err)
   150  	go origTR.Run()
   151  	defer origTR.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
   152  
   153  	// Wait for it to be running
   154  	testWaitForTaskToStart(t, origTR)
   155  
   156  	// Cause TR to exit without shutting down task
   157  	origTR.Shutdown()
   158  
   159  	// Start a new TaskRunner and make sure it does not rerun the task
   160  	newTR, err := NewTaskRunner(conf)
   161  	require.NoError(err)
   162  
   163  	// Do the Restore
   164  	require.NoError(newTR.Restore())
   165  
   166  	go newTR.Run()
   167  	defer newTR.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
   168  
   169  	// Wait for new task runner to exit when the process does
   170  	<-newTR.WaitCh()
   171  
   172  	// Assert that the process was only started once
   173  	started := 0
   174  	state := newTR.TaskState()
   175  	require.Equal(structs.TaskStateDead, state.State)
   176  	for _, ev := range state.Events {
   177  		if ev.Type == structs.TaskStarted {
   178  			started++
   179  		}
   180  	}
   181  	assert.Equal(t, 1, started)
   182  }
   183  
   184  // setupRestoreFailureTest starts a service, shuts down the task runner, and
   185  // kills the task before restarting a new TaskRunner. The new TaskRunner is
   186  // returned once it is running and waiting in pending along with a cleanup
   187  // func.
   188  func setupRestoreFailureTest(t *testing.T, alloc *structs.Allocation) (*TaskRunner, *Config, func()) {
   189  	t.Parallel()
   190  
   191  	task := alloc.Job.TaskGroups[0].Tasks[0]
   192  	task.Driver = "raw_exec"
   193  	task.Config = map[string]interface{}{
   194  		"command": "sleep",
   195  		"args":    []string{"30"},
   196  	}
   197  	conf, cleanup1 := testTaskRunnerConfig(t, alloc, task.Name)
   198  	conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between runs
   199  
   200  	// Run the first TaskRunner
   201  	origTR, err := NewTaskRunner(conf)
   202  	require.NoError(t, err)
   203  	go origTR.Run()
   204  	cleanup2 := func() {
   205  		origTR.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
   206  		cleanup1()
   207  	}
   208  
   209  	// Wait for it to be running
   210  	testWaitForTaskToStart(t, origTR)
   211  
   212  	handle := origTR.getDriverHandle()
   213  	require.NotNil(t, handle)
   214  	taskID := handle.taskID
   215  
   216  	// Cause TR to exit without shutting down task
   217  	origTR.Shutdown()
   218  
   219  	// Get the driver
   220  	driverPlugin, err := conf.DriverManager.Dispense(rawexec.PluginID.Name)
   221  	require.NoError(t, err)
   222  	rawexecDriver := driverPlugin.(*rawexec.Driver)
   223  
   224  	// Assert the task is still running despite TR having exited
   225  	taskStatus, err := rawexecDriver.InspectTask(taskID)
   226  	require.NoError(t, err)
   227  	require.Equal(t, drivers.TaskStateRunning, taskStatus.State)
   228  
   229  	// Kill the task so it fails to recover when restore is called
   230  	require.NoError(t, rawexecDriver.DestroyTask(taskID, true))
   231  	_, err = rawexecDriver.InspectTask(taskID)
   232  	require.EqualError(t, err, drivers.ErrTaskNotFound.Error())
   233  
   234  	// Create a new TaskRunner and Restore the task
   235  	conf.ServersContactedCh = make(chan struct{})
   236  	newTR, err := NewTaskRunner(conf)
   237  	require.NoError(t, err)
   238  
   239  	// Assert the TR will wait on servers because reattachment failed
   240  	require.NoError(t, newTR.Restore())
   241  	require.True(t, newTR.waitOnServers)
   242  
   243  	// Start new TR
   244  	go newTR.Run()
   245  	cleanup3 := func() {
   246  		newTR.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
   247  		cleanup2()
   248  		cleanup1()
   249  	}
   250  
   251  	// Assert task has not been restarted
   252  	_, err = rawexecDriver.InspectTask(taskID)
   253  	require.EqualError(t, err, drivers.ErrTaskNotFound.Error())
   254  	ts := newTR.TaskState()
   255  	require.Equal(t, structs.TaskStatePending, ts.State)
   256  
   257  	return newTR, conf, cleanup3
   258  }
   259  
   260  // TestTaskRunner_Restore_Restart asserts restoring a dead task blocks until
   261  // MarkAlive is called. #1795
   262  func TestTaskRunner_Restore_Restart(t *testing.T) {
   263  	newTR, conf, cleanup := setupRestoreFailureTest(t, mock.Alloc())
   264  	defer cleanup()
   265  
   266  	// Fake contacting the server by closing the chan
   267  	close(conf.ServersContactedCh)
   268  
   269  	testutil.WaitForResult(func() (bool, error) {
   270  		ts := newTR.TaskState().State
   271  		return ts == structs.TaskStateRunning, fmt.Errorf("expected task to be running but found %q", ts)
   272  	}, func(err error) {
   273  		require.NoError(t, err)
   274  	})
   275  }
   276  
   277  // TestTaskRunner_Restore_Kill asserts restoring a dead task blocks until
   278  // the task is killed. #1795
   279  func TestTaskRunner_Restore_Kill(t *testing.T) {
   280  	newTR, _, cleanup := setupRestoreFailureTest(t, mock.Alloc())
   281  	defer cleanup()
   282  
   283  	// Sending the task a terminal update shouldn't kill it or unblock it
   284  	alloc := newTR.Alloc().Copy()
   285  	alloc.DesiredStatus = structs.AllocDesiredStatusStop
   286  	newTR.Update(alloc)
   287  
   288  	require.Equal(t, structs.TaskStatePending, newTR.TaskState().State)
   289  
   290  	// AllocRunner will immediately kill tasks after sending a terminal
   291  	// update.
   292  	newTR.Kill(context.Background(), structs.NewTaskEvent(structs.TaskKilling))
   293  
   294  	select {
   295  	case <-newTR.WaitCh():
   296  		// It died as expected!
   297  	case <-time.After(10 * time.Second):
   298  		require.Fail(t, "timeout waiting for task to die")
   299  	}
   300  }
   301  
   302  // TestTaskRunner_Restore_Update asserts restoring a dead task blocks until
   303  // Update is called. #1795
   304  func TestTaskRunner_Restore_Update(t *testing.T) {
   305  	newTR, conf, cleanup := setupRestoreFailureTest(t, mock.Alloc())
   306  	defer cleanup()
   307  
   308  	// Fake Client.runAllocs behavior by calling Update then closing chan
   309  	alloc := newTR.Alloc().Copy()
   310  	newTR.Update(alloc)
   311  
   312  	// Update alone should not unblock the test
   313  	require.Equal(t, structs.TaskStatePending, newTR.TaskState().State)
   314  
   315  	// Fake Client.runAllocs behavior of closing chan after Update
   316  	close(conf.ServersContactedCh)
   317  
   318  	testutil.WaitForResult(func() (bool, error) {
   319  		ts := newTR.TaskState().State
   320  		return ts == structs.TaskStateRunning, fmt.Errorf("expected task to be running but found %q", ts)
   321  	}, func(err error) {
   322  		require.NoError(t, err)
   323  	})
   324  }
   325  
   326  // TestTaskRunner_Restore_System asserts restoring a dead system task does not
   327  // block.
   328  func TestTaskRunner_Restore_System(t *testing.T) {
   329  	t.Parallel()
   330  
   331  	alloc := mock.Alloc()
   332  	alloc.Job.Type = structs.JobTypeSystem
   333  	task := alloc.Job.TaskGroups[0].Tasks[0]
   334  	task.Driver = "raw_exec"
   335  	task.Config = map[string]interface{}{
   336  		"command": "sleep",
   337  		"args":    []string{"30"},
   338  	}
   339  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
   340  	defer cleanup()
   341  	conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between runs
   342  
   343  	// Run the first TaskRunner
   344  	origTR, err := NewTaskRunner(conf)
   345  	require.NoError(t, err)
   346  	go origTR.Run()
   347  	defer origTR.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
   348  
   349  	// Wait for it to be running
   350  	testWaitForTaskToStart(t, origTR)
   351  
   352  	handle := origTR.getDriverHandle()
   353  	require.NotNil(t, handle)
   354  	taskID := handle.taskID
   355  
   356  	// Cause TR to exit without shutting down task
   357  	origTR.Shutdown()
   358  
   359  	// Get the driver
   360  	driverPlugin, err := conf.DriverManager.Dispense(rawexec.PluginID.Name)
   361  	require.NoError(t, err)
   362  	rawexecDriver := driverPlugin.(*rawexec.Driver)
   363  
   364  	// Assert the task is still running despite TR having exited
   365  	taskStatus, err := rawexecDriver.InspectTask(taskID)
   366  	require.NoError(t, err)
   367  	require.Equal(t, drivers.TaskStateRunning, taskStatus.State)
   368  
   369  	// Kill the task so it fails to recover when restore is called
   370  	require.NoError(t, rawexecDriver.DestroyTask(taskID, true))
   371  	_, err = rawexecDriver.InspectTask(taskID)
   372  	require.EqualError(t, err, drivers.ErrTaskNotFound.Error())
   373  
   374  	// Create a new TaskRunner and Restore the task
   375  	conf.ServersContactedCh = make(chan struct{})
   376  	newTR, err := NewTaskRunner(conf)
   377  	require.NoError(t, err)
   378  
   379  	// Assert the TR will not wait on servers even though reattachment
   380  	// failed because it is a system task.
   381  	require.NoError(t, newTR.Restore())
   382  	require.False(t, newTR.waitOnServers)
   383  
   384  	// Nothing should have closed the chan
   385  	select {
   386  	case <-conf.ServersContactedCh:
   387  		require.Fail(t, "serversContactedCh was closed but should not have been")
   388  	default:
   389  	}
   390  
   391  	testutil.WaitForResult(func() (bool, error) {
   392  		ts := newTR.TaskState().State
   393  		return ts == structs.TaskStateRunning, fmt.Errorf("expected task to be running but found %q", ts)
   394  	}, func(err error) {
   395  		require.NoError(t, err)
   396  	})
   397  }
   398  
   399  // TestTaskRunner_TaskEnv_Interpolated asserts driver configurations are
   400  // interpolated.
   401  func TestTaskRunner_TaskEnv_Interpolated(t *testing.T) {
   402  	t.Parallel()
   403  	require := require.New(t)
   404  
   405  	alloc := mock.BatchAlloc()
   406  	alloc.Job.TaskGroups[0].Meta = map[string]string{
   407  		"common_user": "somebody",
   408  	}
   409  	task := alloc.Job.TaskGroups[0].Tasks[0]
   410  	task.Meta = map[string]string{
   411  		"foo": "bar",
   412  	}
   413  
   414  	// Use interpolation from both node attributes and meta vars
   415  	task.Config = map[string]interface{}{
   416  		"run_for":       "1ms",
   417  		"stdout_string": `${node.region} ${NOMAD_META_foo} ${NOMAD_META_common_user}`,
   418  	}
   419  
   420  	tr, conf, cleanup := runTestTaskRunner(t, alloc, task.Name)
   421  	defer cleanup()
   422  
   423  	// Wait for task to complete
   424  	select {
   425  	case <-tr.WaitCh():
   426  	case <-time.After(3 * time.Second):
   427  		require.Fail("timeout waiting for task to exit")
   428  	}
   429  
   430  	// Get the mock driver plugin
   431  	driverPlugin, err := conf.DriverManager.Dispense(mockdriver.PluginID.Name)
   432  	require.NoError(err)
   433  	mockDriver := driverPlugin.(*mockdriver.Driver)
   434  
   435  	// Assert its config has been properly interpolated
   436  	driverCfg, mockCfg := mockDriver.GetTaskConfig()
   437  	require.NotNil(driverCfg)
   438  	require.NotNil(mockCfg)
   439  	assert.Equal(t, "global bar somebody", mockCfg.StdoutString)
   440  }
   441  
   442  // TestTaskRunner_TaskEnv_Chroot asserts chroot drivers use chroot paths and
   443  // not host paths.
   444  func TestTaskRunner_TaskEnv_Chroot(t *testing.T) {
   445  	ctestutil.ExecCompatible(t)
   446  	t.Parallel()
   447  	require := require.New(t)
   448  
   449  	alloc := mock.BatchAlloc()
   450  	task := alloc.Job.TaskGroups[0].Tasks[0]
   451  	task.Driver = "exec"
   452  	task.Config = map[string]interface{}{
   453  		"command": "bash",
   454  		"args": []string{"-c", "echo $NOMAD_ALLOC_DIR; " +
   455  			"echo $NOMAD_TASK_DIR; " +
   456  			"echo $NOMAD_SECRETS_DIR; " +
   457  			"echo $PATH; ",
   458  		},
   459  	}
   460  
   461  	// Expect chroot paths and host $PATH
   462  	exp := fmt.Sprintf(`/alloc
   463  /local
   464  /secrets
   465  %s
   466  `, os.Getenv("PATH"))
   467  
   468  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
   469  	defer cleanup()
   470  
   471  	// Remove /sbin and /usr from chroot
   472  	conf.ClientConfig.ChrootEnv = map[string]string{
   473  		"/bin":            "/bin",
   474  		"/etc":            "/etc",
   475  		"/lib":            "/lib",
   476  		"/lib32":          "/lib32",
   477  		"/lib64":          "/lib64",
   478  		"/run/resolvconf": "/run/resolvconf",
   479  	}
   480  
   481  	tr, err := NewTaskRunner(conf)
   482  	require.NoError(err)
   483  	go tr.Run()
   484  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
   485  
   486  	// Wait for task to exit
   487  	select {
   488  	case <-tr.WaitCh():
   489  	case <-time.After(15 * time.Second):
   490  		require.Fail("timeout waiting for task to exit")
   491  	}
   492  
   493  	// Read stdout
   494  	p := filepath.Join(conf.TaskDir.LogDir, task.Name+".stdout.0")
   495  	stdout, err := ioutil.ReadFile(p)
   496  	require.NoError(err)
   497  	require.Equalf(exp, string(stdout), "expected: %s\n\nactual: %s\n", exp, stdout)
   498  }
   499  
   500  // TestTaskRunner_TaskEnv_Image asserts image drivers use chroot paths and
   501  // not host paths. Host env vars should also be excluded.
   502  func TestTaskRunner_TaskEnv_Image(t *testing.T) {
   503  	ctestutil.DockerCompatible(t)
   504  	t.Parallel()
   505  	require := require.New(t)
   506  
   507  	alloc := mock.BatchAlloc()
   508  	task := alloc.Job.TaskGroups[0].Tasks[0]
   509  	task.Driver = "docker"
   510  	task.Config = map[string]interface{}{
   511  		"image":        "redis:3.2-alpine",
   512  		"network_mode": "none",
   513  		"command":      "sh",
   514  		"args": []string{"-c", "echo $NOMAD_ALLOC_DIR; " +
   515  			"echo $NOMAD_TASK_DIR; " +
   516  			"echo $NOMAD_SECRETS_DIR; " +
   517  			"echo $PATH",
   518  		},
   519  	}
   520  
   521  	// Expect chroot paths and image specific PATH
   522  	exp := `/alloc
   523  /local
   524  /secrets
   525  /usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
   526  `
   527  
   528  	tr, conf, cleanup := runTestTaskRunner(t, alloc, task.Name)
   529  	defer cleanup()
   530  
   531  	// Wait for task to exit
   532  	select {
   533  	case <-tr.WaitCh():
   534  	case <-time.After(15 * time.Second):
   535  		require.Fail("timeout waiting for task to exit")
   536  	}
   537  
   538  	// Read stdout
   539  	p := filepath.Join(conf.TaskDir.LogDir, task.Name+".stdout.0")
   540  	stdout, err := ioutil.ReadFile(p)
   541  	require.NoError(err)
   542  	require.Equalf(exp, string(stdout), "expected: %s\n\nactual: %s\n", exp, stdout)
   543  }
   544  
   545  // TestTaskRunner_TaskEnv_None asserts raw_exec uses host paths and env vars.
   546  func TestTaskRunner_TaskEnv_None(t *testing.T) {
   547  	t.Parallel()
   548  	require := require.New(t)
   549  
   550  	alloc := mock.BatchAlloc()
   551  	task := alloc.Job.TaskGroups[0].Tasks[0]
   552  	task.Driver = "raw_exec"
   553  	task.Config = map[string]interface{}{
   554  		"command": "sh",
   555  		"args": []string{"-c", "echo $NOMAD_ALLOC_DIR; " +
   556  			"echo $NOMAD_TASK_DIR; " +
   557  			"echo $NOMAD_SECRETS_DIR; " +
   558  			"echo $PATH",
   559  		},
   560  	}
   561  
   562  	tr, conf, cleanup := runTestTaskRunner(t, alloc, task.Name)
   563  	defer cleanup()
   564  
   565  	// Expect host paths
   566  	root := filepath.Join(conf.ClientConfig.AllocDir, alloc.ID)
   567  	taskDir := filepath.Join(root, task.Name)
   568  	exp := fmt.Sprintf(`%s/alloc
   569  %s/local
   570  %s/secrets
   571  %s
   572  `, root, taskDir, taskDir, os.Getenv("PATH"))
   573  
   574  	// Wait for task to exit
   575  	select {
   576  	case <-tr.WaitCh():
   577  	case <-time.After(15 * time.Second):
   578  		require.Fail("timeout waiting for task to exit")
   579  	}
   580  
   581  	// Read stdout
   582  	p := filepath.Join(conf.TaskDir.LogDir, task.Name+".stdout.0")
   583  	stdout, err := ioutil.ReadFile(p)
   584  	require.NoError(err)
   585  	require.Equalf(exp, string(stdout), "expected: %s\n\nactual: %s\n", exp, stdout)
   586  }
   587  
   588  // Test that devices get sent to the driver
   589  func TestTaskRunner_DevicePropogation(t *testing.T) {
   590  	t.Parallel()
   591  	require := require.New(t)
   592  
   593  	// Create a mock alloc that has a gpu
   594  	alloc := mock.BatchAlloc()
   595  	alloc.Job.TaskGroups[0].Count = 1
   596  	task := alloc.Job.TaskGroups[0].Tasks[0]
   597  	task.Driver = "mock_driver"
   598  	task.Config = map[string]interface{}{
   599  		"run_for": "100ms",
   600  	}
   601  	tRes := alloc.AllocatedResources.Tasks[task.Name]
   602  	tRes.Devices = append(tRes.Devices, &structs.AllocatedDeviceResource{Type: "mock"})
   603  
   604  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
   605  	conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between task runners
   606  	defer cleanup()
   607  
   608  	// Setup the devicemanager
   609  	dm, ok := conf.DeviceManager.(*devicemanager.MockManager)
   610  	require.True(ok)
   611  
   612  	dm.ReserveF = func(d *structs.AllocatedDeviceResource) (*device.ContainerReservation, error) {
   613  		res := &device.ContainerReservation{
   614  			Envs: map[string]string{
   615  				"ABC": "123",
   616  			},
   617  			Mounts: []*device.Mount{
   618  				{
   619  					ReadOnly: true,
   620  					TaskPath: "foo",
   621  					HostPath: "bar",
   622  				},
   623  			},
   624  			Devices: []*device.DeviceSpec{
   625  				{
   626  					TaskPath:    "foo",
   627  					HostPath:    "bar",
   628  					CgroupPerms: "123",
   629  				},
   630  			},
   631  		}
   632  		return res, nil
   633  	}
   634  
   635  	// Run the TaskRunner
   636  	tr, err := NewTaskRunner(conf)
   637  	require.NoError(err)
   638  	go tr.Run()
   639  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
   640  
   641  	// Wait for task to complete
   642  	select {
   643  	case <-tr.WaitCh():
   644  	case <-time.After(3 * time.Second):
   645  	}
   646  
   647  	// Get the mock driver plugin
   648  	driverPlugin, err := conf.DriverManager.Dispense(mockdriver.PluginID.Name)
   649  	require.NoError(err)
   650  	mockDriver := driverPlugin.(*mockdriver.Driver)
   651  
   652  	// Assert its config has been properly interpolated
   653  	driverCfg, _ := mockDriver.GetTaskConfig()
   654  	require.NotNil(driverCfg)
   655  	require.Len(driverCfg.Devices, 1)
   656  	require.Equal(driverCfg.Devices[0].Permissions, "123")
   657  	require.Len(driverCfg.Mounts, 1)
   658  	require.Equal(driverCfg.Mounts[0].TaskPath, "foo")
   659  	require.Contains(driverCfg.Env, "ABC")
   660  }
   661  
   662  // mockEnvHook is a test hook that sets an env var and done=true. It fails if
   663  // it's called more than once.
   664  type mockEnvHook struct {
   665  	called int
   666  }
   667  
   668  func (*mockEnvHook) Name() string {
   669  	return "mock_env_hook"
   670  }
   671  
   672  func (h *mockEnvHook) Prestart(ctx context.Context, req *interfaces.TaskPrestartRequest, resp *interfaces.TaskPrestartResponse) error {
   673  	h.called++
   674  
   675  	resp.Done = true
   676  	resp.Env = map[string]string{
   677  		"mock_hook": "1",
   678  	}
   679  
   680  	return nil
   681  }
   682  
   683  // TestTaskRunner_Restore_HookEnv asserts that re-running prestart hooks with
   684  // hook environments set restores the environment without re-running done
   685  // hooks.
   686  func TestTaskRunner_Restore_HookEnv(t *testing.T) {
   687  	t.Parallel()
   688  	require := require.New(t)
   689  
   690  	alloc := mock.BatchAlloc()
   691  	task := alloc.Job.TaskGroups[0].Tasks[0]
   692  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
   693  	conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between prestart calls
   694  	defer cleanup()
   695  
   696  	tr, err := NewTaskRunner(conf)
   697  	require.NoError(err)
   698  
   699  	// Override the default hooks to only run the mock hook
   700  	mockHook := &mockEnvHook{}
   701  	tr.runnerHooks = []interfaces.TaskHook{mockHook}
   702  
   703  	// Manually run prestart hooks
   704  	require.NoError(tr.prestart())
   705  
   706  	// Assert env was called
   707  	require.Equal(1, mockHook.called)
   708  
   709  	// Re-running prestart hooks should *not* call done mock hook
   710  	require.NoError(tr.prestart())
   711  
   712  	// Assert env was called
   713  	require.Equal(1, mockHook.called)
   714  
   715  	// Assert the env is still set
   716  	env := tr.envBuilder.Build().All()
   717  	require.Contains(env, "mock_hook")
   718  	require.Equal("1", env["mock_hook"])
   719  }
   720  
   721  // This test asserts that we can recover from an "external" plugin exiting by
   722  // retrieving a new instance of the driver and recovering the task.
   723  func TestTaskRunner_RecoverFromDriverExiting(t *testing.T) {
   724  	t.Parallel()
   725  	require := require.New(t)
   726  
   727  	// Create an allocation using the mock driver that exits simulating the
   728  	// driver crashing. We can then test that the task runner recovers from this
   729  	alloc := mock.BatchAlloc()
   730  	task := alloc.Job.TaskGroups[0].Tasks[0]
   731  	task.Driver = "mock_driver"
   732  	task.Config = map[string]interface{}{
   733  		"plugin_exit_after": "1s",
   734  		"run_for":           "5s",
   735  	}
   736  
   737  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
   738  	conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between prestart calls
   739  	defer cleanup()
   740  
   741  	tr, err := NewTaskRunner(conf)
   742  	require.NoError(err)
   743  
   744  	start := time.Now()
   745  	go tr.Run()
   746  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
   747  
   748  	// Wait for the task to be running
   749  	testWaitForTaskToStart(t, tr)
   750  
   751  	// Get the task ID
   752  	tr.stateLock.RLock()
   753  	l := tr.localState.TaskHandle
   754  	require.NotNil(l)
   755  	require.NotNil(l.Config)
   756  	require.NotEmpty(l.Config.ID)
   757  	id := l.Config.ID
   758  	tr.stateLock.RUnlock()
   759  
   760  	// Get the mock driver plugin
   761  	driverPlugin, err := conf.DriverManager.Dispense(mockdriver.PluginID.Name)
   762  	require.NoError(err)
   763  	mockDriver := driverPlugin.(*mockdriver.Driver)
   764  
   765  	// Wait for the task to start
   766  	testutil.WaitForResult(func() (bool, error) {
   767  		// Get the handle and check that it was recovered
   768  		handle := mockDriver.GetHandle(id)
   769  		if handle == nil {
   770  			return false, fmt.Errorf("nil handle")
   771  		}
   772  		if !handle.Recovered {
   773  			return false, fmt.Errorf("handle not recovered")
   774  		}
   775  		return true, nil
   776  	}, func(err error) {
   777  		t.Fatal(err.Error())
   778  	})
   779  
   780  	// Wait for task to complete
   781  	select {
   782  	case <-tr.WaitCh():
   783  	case <-time.After(10 * time.Second):
   784  	}
   785  
   786  	// Ensure that we actually let the task complete
   787  	require.True(time.Now().Sub(start) > 5*time.Second)
   788  
   789  	// Check it finished successfully
   790  	state := tr.TaskState()
   791  	require.True(state.Successful())
   792  }
   793  
   794  // TestTaskRunner_ShutdownDelay asserts services are removed from Consul
   795  // ${shutdown_delay} seconds before killing the process.
   796  func TestTaskRunner_ShutdownDelay(t *testing.T) {
   797  	t.Parallel()
   798  
   799  	alloc := mock.Alloc()
   800  	task := alloc.Job.TaskGroups[0].Tasks[0]
   801  	task.Services[0].Tags = []string{"tag1"}
   802  	task.Services = task.Services[:1] // only need 1 for this test
   803  	task.Driver = "mock_driver"
   804  	task.Config = map[string]interface{}{
   805  		"run_for": "1000s",
   806  	}
   807  
   808  	// No shutdown escape hatch for this delay, so don't set it too high
   809  	task.ShutdownDelay = 1000 * time.Duration(testutil.TestMultiplier()) * time.Millisecond
   810  
   811  	tr, conf, cleanup := runTestTaskRunner(t, alloc, task.Name)
   812  	defer cleanup()
   813  
   814  	mockConsul := conf.Consul.(*consul.MockConsulServiceClient)
   815  
   816  	// Wait for the task to start
   817  	testWaitForTaskToStart(t, tr)
   818  
   819  	testutil.WaitForResult(func() (bool, error) {
   820  		ops := mockConsul.GetOps()
   821  		if n := len(ops); n != 1 {
   822  			return false, fmt.Errorf("expected 1 consul operation. Found %d", n)
   823  		}
   824  		return ops[0].Op == "add", fmt.Errorf("consul operation was not a registration: %#v", ops[0])
   825  	}, func(err error) {
   826  		t.Fatalf("err: %v", err)
   827  	})
   828  
   829  	// Asynchronously kill task
   830  	killSent := time.Now()
   831  	killed := make(chan struct{})
   832  	go func() {
   833  		defer close(killed)
   834  		assert.NoError(t, tr.Kill(context.Background(), structs.NewTaskEvent("test")))
   835  	}()
   836  
   837  	// Wait for *2* deregistration calls (due to needing to remove both
   838  	// canary tag variants)
   839  WAIT:
   840  	for {
   841  		ops := mockConsul.GetOps()
   842  		switch n := len(ops); n {
   843  		case 1, 2:
   844  			// Waiting for both deregistration calls
   845  		case 3:
   846  			require.Equalf(t, "remove", ops[1].Op, "expected deregistration but found: %#v", ops[1])
   847  			require.Equalf(t, "remove", ops[2].Op, "expected deregistration but found: %#v", ops[2])
   848  			break WAIT
   849  		default:
   850  			// ?!
   851  			t.Fatalf("unexpected number of consul operations: %d\n%s", n, pretty.Sprint(ops))
   852  
   853  		}
   854  
   855  		select {
   856  		case <-killed:
   857  			t.Fatal("killed while service still registered")
   858  		case <-time.After(10 * time.Millisecond):
   859  		}
   860  	}
   861  
   862  	// Wait for actual exit
   863  	select {
   864  	case <-tr.WaitCh():
   865  	case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
   866  		t.Fatalf("timeout")
   867  	}
   868  
   869  	<-killed
   870  	killDur := time.Now().Sub(killSent)
   871  	if killDur < task.ShutdownDelay {
   872  		t.Fatalf("task killed before shutdown_delay (killed_after: %s; shutdown_delay: %s",
   873  			killDur, task.ShutdownDelay,
   874  		)
   875  	}
   876  }
   877  
   878  // TestTaskRunner_Dispatch_Payload asserts that a dispatch job runs and the
   879  // payload was written to disk.
   880  func TestTaskRunner_Dispatch_Payload(t *testing.T) {
   881  	t.Parallel()
   882  
   883  	alloc := mock.BatchAlloc()
   884  	task := alloc.Job.TaskGroups[0].Tasks[0]
   885  	task.Driver = "mock_driver"
   886  	task.Config = map[string]interface{}{
   887  		"run_for": "1s",
   888  	}
   889  
   890  	fileName := "test"
   891  	task.DispatchPayload = &structs.DispatchPayloadConfig{
   892  		File: fileName,
   893  	}
   894  	alloc.Job.ParameterizedJob = &structs.ParameterizedJobConfig{}
   895  
   896  	// Add a payload (they're snappy encoded bytes)
   897  	expected := []byte("hello world")
   898  	compressed := snappy.Encode(nil, expected)
   899  	alloc.Job.Payload = compressed
   900  
   901  	tr, _, cleanup := runTestTaskRunner(t, alloc, task.Name)
   902  	defer cleanup()
   903  
   904  	// Wait for it to finish
   905  	testutil.WaitForResult(func() (bool, error) {
   906  		ts := tr.TaskState()
   907  		return ts.State == structs.TaskStateDead, fmt.Errorf("%v", ts.State)
   908  	}, func(err error) {
   909  		require.NoError(t, err)
   910  	})
   911  
   912  	// Should have exited successfully
   913  	ts := tr.TaskState()
   914  	require.False(t, ts.Failed)
   915  	require.Zero(t, ts.Restarts)
   916  
   917  	// Check that the file was written to disk properly
   918  	payloadPath := filepath.Join(tr.taskDir.LocalDir, fileName)
   919  	data, err := ioutil.ReadFile(payloadPath)
   920  	require.NoError(t, err)
   921  	require.Equal(t, expected, data)
   922  }
   923  
   924  // TestTaskRunner_SignalFailure asserts that signal errors are properly
   925  // propagated from the driver to TaskRunner.
   926  func TestTaskRunner_SignalFailure(t *testing.T) {
   927  	t.Parallel()
   928  
   929  	alloc := mock.Alloc()
   930  	task := alloc.Job.TaskGroups[0].Tasks[0]
   931  	task.Driver = "mock_driver"
   932  	errMsg := "test forcing failure"
   933  	task.Config = map[string]interface{}{
   934  		"run_for":      "10m",
   935  		"signal_error": errMsg,
   936  	}
   937  
   938  	tr, _, cleanup := runTestTaskRunner(t, alloc, task.Name)
   939  	defer cleanup()
   940  
   941  	testWaitForTaskToStart(t, tr)
   942  
   943  	require.EqualError(t, tr.Signal(&structs.TaskEvent{}, "SIGINT"), errMsg)
   944  }
   945  
   946  // TestTaskRunner_RestartTask asserts that restarting a task works and emits a
   947  // Restarting event.
   948  func TestTaskRunner_RestartTask(t *testing.T) {
   949  	t.Parallel()
   950  
   951  	alloc := mock.Alloc()
   952  	task := alloc.Job.TaskGroups[0].Tasks[0]
   953  	task.Driver = "mock_driver"
   954  	task.Config = map[string]interface{}{
   955  		"run_for": "10m",
   956  	}
   957  
   958  	tr, _, cleanup := runTestTaskRunner(t, alloc, task.Name)
   959  	defer cleanup()
   960  
   961  	testWaitForTaskToStart(t, tr)
   962  
   963  	// Restart task. Send a RestartSignal event like check watcher. Restart
   964  	// handler emits the Restarting event.
   965  	event := structs.NewTaskEvent(structs.TaskRestartSignal).SetRestartReason("test")
   966  	const fail = false
   967  	tr.Restart(context.Background(), event.Copy(), fail)
   968  
   969  	// Wait for it to restart and be running again
   970  	testutil.WaitForResult(func() (bool, error) {
   971  		ts := tr.TaskState()
   972  		if ts.Restarts != 1 {
   973  			return false, fmt.Errorf("expected 1 restart but found %d\nevents: %s",
   974  				ts.Restarts, pretty.Sprint(ts.Events))
   975  		}
   976  		if ts.State != structs.TaskStateRunning {
   977  			return false, fmt.Errorf("expected running but received %s", ts.State)
   978  		}
   979  		return true, nil
   980  	}, func(err error) {
   981  		require.NoError(t, err)
   982  	})
   983  
   984  	// Assert the expected Restarting event was emitted
   985  	found := false
   986  	events := tr.TaskState().Events
   987  	for _, e := range events {
   988  		if e.Type == structs.TaskRestartSignal {
   989  			found = true
   990  			require.Equal(t, event.Time, e.Time)
   991  			require.Equal(t, event.RestartReason, e.RestartReason)
   992  			require.Contains(t, e.DisplayMessage, event.RestartReason)
   993  		}
   994  	}
   995  	require.True(t, found, "restarting task event not found", pretty.Sprint(events))
   996  }
   997  
   998  // TestTaskRunner_CheckWatcher_Restart asserts that when enabled an unhealthy
   999  // Consul check will cause a task to restart following restart policy rules.
  1000  func TestTaskRunner_CheckWatcher_Restart(t *testing.T) {
  1001  	t.Parallel()
  1002  
  1003  	alloc := mock.Alloc()
  1004  
  1005  	// Make the restart policy fail within this test
  1006  	tg := alloc.Job.TaskGroups[0]
  1007  	tg.RestartPolicy.Attempts = 2
  1008  	tg.RestartPolicy.Interval = 1 * time.Minute
  1009  	tg.RestartPolicy.Delay = 10 * time.Millisecond
  1010  	tg.RestartPolicy.Mode = structs.RestartPolicyModeFail
  1011  
  1012  	task := tg.Tasks[0]
  1013  	task.Driver = "mock_driver"
  1014  	task.Config = map[string]interface{}{
  1015  		"run_for": "10m",
  1016  	}
  1017  
  1018  	// Make the task register a check that fails
  1019  	task.Services[0].Checks[0] = &structs.ServiceCheck{
  1020  		Name:     "test-restarts",
  1021  		Type:     structs.ServiceCheckTCP,
  1022  		Interval: 50 * time.Millisecond,
  1023  		CheckRestart: &structs.CheckRestart{
  1024  			Limit: 2,
  1025  			Grace: 100 * time.Millisecond,
  1026  		},
  1027  	}
  1028  
  1029  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
  1030  	defer cleanup()
  1031  
  1032  	// Replace mock Consul ServiceClient, with the real ServiceClient
  1033  	// backed by a mock consul whose checks are always unhealthy.
  1034  	consulAgent := agentconsul.NewMockAgent()
  1035  	consulAgent.SetStatus("critical")
  1036  	consulClient := agentconsul.NewServiceClient(consulAgent, conf.Logger, true)
  1037  	go consulClient.Run()
  1038  	defer consulClient.Shutdown()
  1039  
  1040  	conf.Consul = consulClient
  1041  
  1042  	tr, err := NewTaskRunner(conf)
  1043  	require.NoError(t, err)
  1044  
  1045  	expectedEvents := []string{
  1046  		"Received",
  1047  		"Task Setup",
  1048  		"Started",
  1049  		"Restart Signaled",
  1050  		"Terminated",
  1051  		"Restarting",
  1052  		"Started",
  1053  		"Restart Signaled",
  1054  		"Terminated",
  1055  		"Restarting",
  1056  		"Started",
  1057  		"Restart Signaled",
  1058  		"Terminated",
  1059  		"Not Restarting",
  1060  	}
  1061  
  1062  	// Bump maxEvents so task events aren't dropped
  1063  	tr.maxEvents = 100
  1064  
  1065  	go tr.Run()
  1066  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
  1067  
  1068  	// Wait until the task exits. Don't simply wait for it to run as it may
  1069  	// get restarted and terminated before the test is able to observe it
  1070  	// running.
  1071  	select {
  1072  	case <-tr.WaitCh():
  1073  	case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
  1074  		require.Fail(t, "timeout")
  1075  	}
  1076  
  1077  	state := tr.TaskState()
  1078  	actualEvents := make([]string, len(state.Events))
  1079  	for i, e := range state.Events {
  1080  		actualEvents[i] = string(e.Type)
  1081  	}
  1082  	require.Equal(t, actualEvents, expectedEvents)
  1083  
  1084  	require.Equal(t, structs.TaskStateDead, state.State)
  1085  	require.True(t, state.Failed, pretty.Sprint(state))
  1086  }
  1087  
  1088  // TestTaskRunner_BlockForVault asserts tasks do not start until a vault token
  1089  // is derived.
  1090  func TestTaskRunner_BlockForVault(t *testing.T) {
  1091  	t.Parallel()
  1092  
  1093  	alloc := mock.BatchAlloc()
  1094  	task := alloc.Job.TaskGroups[0].Tasks[0]
  1095  	task.Config = map[string]interface{}{
  1096  		"run_for": "0s",
  1097  	}
  1098  	task.Vault = &structs.Vault{Policies: []string{"default"}}
  1099  
  1100  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
  1101  	defer cleanup()
  1102  
  1103  	// Control when we get a Vault token
  1104  	token := "1234"
  1105  	waitCh := make(chan struct{})
  1106  	handler := func(*structs.Allocation, []string) (map[string]string, error) {
  1107  		<-waitCh
  1108  		return map[string]string{task.Name: token}, nil
  1109  	}
  1110  	vaultClient := conf.Vault.(*vaultclient.MockVaultClient)
  1111  	vaultClient.DeriveTokenFn = handler
  1112  
  1113  	tr, err := NewTaskRunner(conf)
  1114  	require.NoError(t, err)
  1115  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
  1116  	go tr.Run()
  1117  
  1118  	// Assert TR blocks on vault token (does *not* exit)
  1119  	select {
  1120  	case <-tr.WaitCh():
  1121  		require.Fail(t, "tr exited before vault unblocked")
  1122  	case <-time.After(1 * time.Second):
  1123  	}
  1124  
  1125  	// Assert task state is still Pending
  1126  	require.Equal(t, structs.TaskStatePending, tr.TaskState().State)
  1127  
  1128  	// Unblock vault token
  1129  	close(waitCh)
  1130  
  1131  	// TR should exit now that it's unblocked by vault as its a batch job
  1132  	// with 0 sleeping.
  1133  	select {
  1134  	case <-tr.WaitCh():
  1135  	case <-time.After(15 * time.Second * time.Duration(testutil.TestMultiplier())):
  1136  		require.Fail(t, "timed out waiting for batch task to exit")
  1137  	}
  1138  
  1139  	// Assert task exited successfully
  1140  	finalState := tr.TaskState()
  1141  	require.Equal(t, structs.TaskStateDead, finalState.State)
  1142  	require.False(t, finalState.Failed)
  1143  
  1144  	// Check that the token is on disk
  1145  	tokenPath := filepath.Join(conf.TaskDir.SecretsDir, vaultTokenFile)
  1146  	data, err := ioutil.ReadFile(tokenPath)
  1147  	require.NoError(t, err)
  1148  	require.Equal(t, token, string(data))
  1149  
  1150  	// Check the token was revoked
  1151  	testutil.WaitForResult(func() (bool, error) {
  1152  		if len(vaultClient.StoppedTokens()) != 1 {
  1153  			return false, fmt.Errorf("Expected a stopped token %q but found: %v", token, vaultClient.StoppedTokens())
  1154  		}
  1155  
  1156  		if a := vaultClient.StoppedTokens()[0]; a != token {
  1157  			return false, fmt.Errorf("got stopped token %q; want %q", a, token)
  1158  		}
  1159  		return true, nil
  1160  	}, func(err error) {
  1161  		require.Fail(t, err.Error())
  1162  	})
  1163  }
  1164  
  1165  // TestTaskRunner_DeriveToken_Retry asserts that if a recoverable error is
  1166  // returned when deriving a vault token a task will continue to block while
  1167  // it's retried.
  1168  func TestTaskRunner_DeriveToken_Retry(t *testing.T) {
  1169  	t.Parallel()
  1170  	alloc := mock.BatchAlloc()
  1171  	task := alloc.Job.TaskGroups[0].Tasks[0]
  1172  	task.Vault = &structs.Vault{Policies: []string{"default"}}
  1173  
  1174  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
  1175  	defer cleanup()
  1176  
  1177  	// Fail on the first attempt to derive a vault token
  1178  	token := "1234"
  1179  	count := 0
  1180  	handler := func(*structs.Allocation, []string) (map[string]string, error) {
  1181  		if count > 0 {
  1182  			return map[string]string{task.Name: token}, nil
  1183  		}
  1184  
  1185  		count++
  1186  		return nil, structs.NewRecoverableError(fmt.Errorf("Want a retry"), true)
  1187  	}
  1188  	vaultClient := conf.Vault.(*vaultclient.MockVaultClient)
  1189  	vaultClient.DeriveTokenFn = handler
  1190  
  1191  	tr, err := NewTaskRunner(conf)
  1192  	require.NoError(t, err)
  1193  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
  1194  	go tr.Run()
  1195  
  1196  	// Wait for TR to exit and check its state
  1197  	select {
  1198  	case <-tr.WaitCh():
  1199  	case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
  1200  		require.Fail(t, "timed out waiting for task runner to exit")
  1201  	}
  1202  
  1203  	state := tr.TaskState()
  1204  	require.Equal(t, structs.TaskStateDead, state.State)
  1205  	require.False(t, state.Failed)
  1206  
  1207  	require.Equal(t, 1, count)
  1208  
  1209  	// Check that the token is on disk
  1210  	tokenPath := filepath.Join(conf.TaskDir.SecretsDir, vaultTokenFile)
  1211  	data, err := ioutil.ReadFile(tokenPath)
  1212  	require.NoError(t, err)
  1213  	require.Equal(t, token, string(data))
  1214  
  1215  	// Check the token was revoked
  1216  	testutil.WaitForResult(func() (bool, error) {
  1217  		if len(vaultClient.StoppedTokens()) != 1 {
  1218  			return false, fmt.Errorf("Expected a stopped token: %v", vaultClient.StoppedTokens())
  1219  		}
  1220  
  1221  		if a := vaultClient.StoppedTokens()[0]; a != token {
  1222  			return false, fmt.Errorf("got stopped token %q; want %q", a, token)
  1223  		}
  1224  		return true, nil
  1225  	}, func(err error) {
  1226  		require.Fail(t, err.Error())
  1227  	})
  1228  }
  1229  
  1230  // TestTaskRunner_DeriveToken_Unrecoverable asserts that an unrecoverable error
  1231  // from deriving a vault token will fail a task.
  1232  func TestTaskRunner_DeriveToken_Unrecoverable(t *testing.T) {
  1233  	t.Parallel()
  1234  
  1235  	// Use a batch job with no restarts
  1236  	alloc := mock.BatchAlloc()
  1237  	tg := alloc.Job.TaskGroups[0]
  1238  	tg.RestartPolicy.Attempts = 0
  1239  	tg.RestartPolicy.Interval = 0
  1240  	tg.RestartPolicy.Delay = 0
  1241  	tg.RestartPolicy.Mode = structs.RestartPolicyModeFail
  1242  	task := tg.Tasks[0]
  1243  	task.Config = map[string]interface{}{
  1244  		"run_for": "0s",
  1245  	}
  1246  	task.Vault = &structs.Vault{Policies: []string{"default"}}
  1247  
  1248  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
  1249  	defer cleanup()
  1250  
  1251  	// Error the token derivation
  1252  	vaultClient := conf.Vault.(*vaultclient.MockVaultClient)
  1253  	vaultClient.SetDeriveTokenError(alloc.ID, []string{task.Name}, fmt.Errorf("Non recoverable"))
  1254  
  1255  	tr, err := NewTaskRunner(conf)
  1256  	require.NoError(t, err)
  1257  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
  1258  	go tr.Run()
  1259  
  1260  	// Wait for the task to die
  1261  	select {
  1262  	case <-tr.WaitCh():
  1263  	case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
  1264  		require.Fail(t, "timed out waiting for task runner to fail")
  1265  	}
  1266  
  1267  	// Task should be dead and last event should have failed task
  1268  	state := tr.TaskState()
  1269  	require.Equal(t, structs.TaskStateDead, state.State)
  1270  	require.True(t, state.Failed)
  1271  	require.Len(t, state.Events, 3)
  1272  	require.True(t, state.Events[2].FailsTask)
  1273  }
  1274  
  1275  // TestTaskRunner_Download_ChrootExec asserts that downloaded artifacts may be
  1276  // executed in a chroot.
  1277  func TestTaskRunner_Download_ChrootExec(t *testing.T) {
  1278  	t.Parallel()
  1279  	ctestutil.ExecCompatible(t)
  1280  
  1281  	ts := httptest.NewServer(http.FileServer(http.Dir(filepath.Dir("."))))
  1282  	defer ts.Close()
  1283  
  1284  	// Create a task that downloads a script and executes it.
  1285  	alloc := mock.BatchAlloc()
  1286  	alloc.Job.TaskGroups[0].RestartPolicy = &structs.RestartPolicy{}
  1287  	task := alloc.Job.TaskGroups[0].Tasks[0]
  1288  	task.Driver = "exec"
  1289  	task.Config = map[string]interface{}{
  1290  		"command": "noop.sh",
  1291  	}
  1292  	task.Artifacts = []*structs.TaskArtifact{
  1293  		{
  1294  			GetterSource: fmt.Sprintf("%s/testdata/noop.sh", ts.URL),
  1295  			GetterMode:   "file",
  1296  			RelativeDest: "noop.sh",
  1297  		},
  1298  	}
  1299  
  1300  	tr, _, cleanup := runTestTaskRunner(t, alloc, task.Name)
  1301  	defer cleanup()
  1302  
  1303  	// Wait for task to run and exit
  1304  	select {
  1305  	case <-tr.WaitCh():
  1306  	case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
  1307  		require.Fail(t, "timed out waiting for task runner to exit")
  1308  	}
  1309  
  1310  	state := tr.TaskState()
  1311  	require.Equal(t, structs.TaskStateDead, state.State)
  1312  	require.False(t, state.Failed)
  1313  }
  1314  
  1315  // TestTaskRunner_Download_Exec asserts that downloaded artifacts may be
  1316  // executed in a driver without filesystem isolation.
  1317  func TestTaskRunner_Download_RawExec(t *testing.T) {
  1318  	t.Parallel()
  1319  
  1320  	ts := httptest.NewServer(http.FileServer(http.Dir(filepath.Dir("."))))
  1321  	defer ts.Close()
  1322  
  1323  	// Create a task that downloads a script and executes it.
  1324  	alloc := mock.BatchAlloc()
  1325  	alloc.Job.TaskGroups[0].RestartPolicy = &structs.RestartPolicy{}
  1326  	task := alloc.Job.TaskGroups[0].Tasks[0]
  1327  	task.Driver = "raw_exec"
  1328  	task.Config = map[string]interface{}{
  1329  		"command": "noop.sh",
  1330  	}
  1331  	task.Artifacts = []*structs.TaskArtifact{
  1332  		{
  1333  			GetterSource: fmt.Sprintf("%s/testdata/noop.sh", ts.URL),
  1334  			GetterMode:   "file",
  1335  			RelativeDest: "noop.sh",
  1336  		},
  1337  	}
  1338  
  1339  	tr, _, cleanup := runTestTaskRunner(t, alloc, task.Name)
  1340  	defer cleanup()
  1341  
  1342  	// Wait for task to run and exit
  1343  	select {
  1344  	case <-tr.WaitCh():
  1345  	case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
  1346  		require.Fail(t, "timed out waiting for task runner to exit")
  1347  	}
  1348  
  1349  	state := tr.TaskState()
  1350  	require.Equal(t, structs.TaskStateDead, state.State)
  1351  	require.False(t, state.Failed)
  1352  }
  1353  
  1354  // TestTaskRunner_Download_List asserts that multiple artificats are downloaded
  1355  // before a task is run.
  1356  func TestTaskRunner_Download_List(t *testing.T) {
  1357  	t.Parallel()
  1358  	ts := httptest.NewServer(http.FileServer(http.Dir(filepath.Dir("."))))
  1359  	defer ts.Close()
  1360  
  1361  	// Create an allocation that has a task with a list of artifacts.
  1362  	alloc := mock.BatchAlloc()
  1363  	task := alloc.Job.TaskGroups[0].Tasks[0]
  1364  	f1 := "task_runner_test.go"
  1365  	f2 := "task_runner.go"
  1366  	artifact1 := structs.TaskArtifact{
  1367  		GetterSource: fmt.Sprintf("%s/%s", ts.URL, f1),
  1368  	}
  1369  	artifact2 := structs.TaskArtifact{
  1370  		GetterSource: fmt.Sprintf("%s/%s", ts.URL, f2),
  1371  	}
  1372  	task.Artifacts = []*structs.TaskArtifact{&artifact1, &artifact2}
  1373  
  1374  	tr, conf, cleanup := runTestTaskRunner(t, alloc, task.Name)
  1375  	defer cleanup()
  1376  
  1377  	// Wait for task to run and exit
  1378  	select {
  1379  	case <-tr.WaitCh():
  1380  	case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
  1381  		require.Fail(t, "timed out waiting for task runner to exit")
  1382  	}
  1383  
  1384  	state := tr.TaskState()
  1385  	require.Equal(t, structs.TaskStateDead, state.State)
  1386  	require.False(t, state.Failed)
  1387  
  1388  	require.Len(t, state.Events, 5)
  1389  	assert.Equal(t, structs.TaskReceived, state.Events[0].Type)
  1390  	assert.Equal(t, structs.TaskSetup, state.Events[1].Type)
  1391  	assert.Equal(t, structs.TaskDownloadingArtifacts, state.Events[2].Type)
  1392  	assert.Equal(t, structs.TaskStarted, state.Events[3].Type)
  1393  	assert.Equal(t, structs.TaskTerminated, state.Events[4].Type)
  1394  
  1395  	// Check that both files exist.
  1396  	_, err := os.Stat(filepath.Join(conf.TaskDir.Dir, f1))
  1397  	require.NoErrorf(t, err, "%v not downloaded", f1)
  1398  
  1399  	_, err = os.Stat(filepath.Join(conf.TaskDir.Dir, f2))
  1400  	require.NoErrorf(t, err, "%v not downloaded", f2)
  1401  }
  1402  
  1403  // TestTaskRunner_Download_Retries asserts that failed artifact downloads are
  1404  // retried according to the task's restart policy.
  1405  func TestTaskRunner_Download_Retries(t *testing.T) {
  1406  	t.Parallel()
  1407  
  1408  	// Create an allocation that has a task with bad artifacts.
  1409  	alloc := mock.BatchAlloc()
  1410  	task := alloc.Job.TaskGroups[0].Tasks[0]
  1411  	artifact := structs.TaskArtifact{
  1412  		GetterSource: "http://127.0.0.1:0/foo/bar/baz",
  1413  	}
  1414  	task.Artifacts = []*structs.TaskArtifact{&artifact}
  1415  
  1416  	// Make the restart policy retry once
  1417  	alloc.Job.TaskGroups[0].RestartPolicy = &structs.RestartPolicy{
  1418  		Attempts: 1,
  1419  		Interval: 10 * time.Minute,
  1420  		Delay:    1 * time.Second,
  1421  		Mode:     structs.RestartPolicyModeFail,
  1422  	}
  1423  
  1424  	tr, _, cleanup := runTestTaskRunner(t, alloc, task.Name)
  1425  	defer cleanup()
  1426  
  1427  	select {
  1428  	case <-tr.WaitCh():
  1429  	case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
  1430  		require.Fail(t, "timed out waiting for task to exit")
  1431  	}
  1432  
  1433  	state := tr.TaskState()
  1434  	require.Equal(t, structs.TaskStateDead, state.State)
  1435  	require.True(t, state.Failed)
  1436  	require.Len(t, state.Events, 8, pretty.Sprint(state.Events))
  1437  	require.Equal(t, structs.TaskReceived, state.Events[0].Type)
  1438  	require.Equal(t, structs.TaskSetup, state.Events[1].Type)
  1439  	require.Equal(t, structs.TaskDownloadingArtifacts, state.Events[2].Type)
  1440  	require.Equal(t, structs.TaskArtifactDownloadFailed, state.Events[3].Type)
  1441  	require.Equal(t, structs.TaskRestarting, state.Events[4].Type)
  1442  	require.Equal(t, structs.TaskDownloadingArtifacts, state.Events[5].Type)
  1443  	require.Equal(t, structs.TaskArtifactDownloadFailed, state.Events[6].Type)
  1444  	require.Equal(t, structs.TaskNotRestarting, state.Events[7].Type)
  1445  }
  1446  
  1447  // TestTaskRunner_DriverNetwork asserts that a driver's network is properly
  1448  // used in services and checks.
  1449  func TestTaskRunner_DriverNetwork(t *testing.T) {
  1450  	t.Parallel()
  1451  
  1452  	alloc := mock.Alloc()
  1453  	task := alloc.Job.TaskGroups[0].Tasks[0]
  1454  	task.Driver = "mock_driver"
  1455  	task.Config = map[string]interface{}{
  1456  		"run_for":         "100s",
  1457  		"driver_ip":       "10.1.2.3",
  1458  		"driver_port_map": "http:80",
  1459  	}
  1460  
  1461  	// Create services and checks with custom address modes to exercise
  1462  	// address detection logic
  1463  	task.Services = []*structs.Service{
  1464  		{
  1465  			Name:        "host-service",
  1466  			PortLabel:   "http",
  1467  			AddressMode: "host",
  1468  			Checks: []*structs.ServiceCheck{
  1469  				{
  1470  					Name:        "driver-check",
  1471  					Type:        "tcp",
  1472  					PortLabel:   "1234",
  1473  					AddressMode: "driver",
  1474  				},
  1475  			},
  1476  		},
  1477  		{
  1478  			Name:        "driver-service",
  1479  			PortLabel:   "5678",
  1480  			AddressMode: "driver",
  1481  			Checks: []*structs.ServiceCheck{
  1482  				{
  1483  					Name:      "host-check",
  1484  					Type:      "tcp",
  1485  					PortLabel: "http",
  1486  				},
  1487  				{
  1488  					Name:        "driver-label-check",
  1489  					Type:        "tcp",
  1490  					PortLabel:   "http",
  1491  					AddressMode: "driver",
  1492  				},
  1493  			},
  1494  		},
  1495  	}
  1496  
  1497  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
  1498  	defer cleanup()
  1499  
  1500  	// Use a mock agent to test for services
  1501  	consulAgent := agentconsul.NewMockAgent()
  1502  	consulClient := agentconsul.NewServiceClient(consulAgent, conf.Logger, true)
  1503  	defer consulClient.Shutdown()
  1504  	go consulClient.Run()
  1505  
  1506  	conf.Consul = consulClient
  1507  
  1508  	tr, err := NewTaskRunner(conf)
  1509  	require.NoError(t, err)
  1510  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
  1511  	go tr.Run()
  1512  
  1513  	// Wait for the task to start
  1514  	testWaitForTaskToStart(t, tr)
  1515  
  1516  	testutil.WaitForResult(func() (bool, error) {
  1517  		services, _ := consulAgent.Services()
  1518  		if n := len(services); n != 2 {
  1519  			return false, fmt.Errorf("expected 2 services, but found %d", n)
  1520  		}
  1521  		for _, s := range services {
  1522  			switch s.Service {
  1523  			case "host-service":
  1524  				if expected := "192.168.0.100"; s.Address != expected {
  1525  					return false, fmt.Errorf("expected host-service to have IP=%s but found %s",
  1526  						expected, s.Address)
  1527  				}
  1528  			case "driver-service":
  1529  				if expected := "10.1.2.3"; s.Address != expected {
  1530  					return false, fmt.Errorf("expected driver-service to have IP=%s but found %s",
  1531  						expected, s.Address)
  1532  				}
  1533  				if expected := 5678; s.Port != expected {
  1534  					return false, fmt.Errorf("expected driver-service to have port=%d but found %d",
  1535  						expected, s.Port)
  1536  				}
  1537  			default:
  1538  				return false, fmt.Errorf("unexpected service: %q", s.Service)
  1539  			}
  1540  
  1541  		}
  1542  
  1543  		checks := consulAgent.CheckRegs()
  1544  		if n := len(checks); n != 3 {
  1545  			return false, fmt.Errorf("expected 3 checks, but found %d", n)
  1546  		}
  1547  		for _, check := range checks {
  1548  			switch check.Name {
  1549  			case "driver-check":
  1550  				if expected := "10.1.2.3:1234"; check.TCP != expected {
  1551  					return false, fmt.Errorf("expected driver-check to have address %q but found %q", expected, check.TCP)
  1552  				}
  1553  			case "driver-label-check":
  1554  				if expected := "10.1.2.3:80"; check.TCP != expected {
  1555  					return false, fmt.Errorf("expected driver-label-check to have address %q but found %q", expected, check.TCP)
  1556  				}
  1557  			case "host-check":
  1558  				if expected := "192.168.0.100:"; !strings.HasPrefix(check.TCP, expected) {
  1559  					return false, fmt.Errorf("expected host-check to have address start with %q but found %q", expected, check.TCP)
  1560  				}
  1561  			default:
  1562  				return false, fmt.Errorf("unexpected check: %q", check.Name)
  1563  			}
  1564  		}
  1565  
  1566  		return true, nil
  1567  	}, func(err error) {
  1568  		services, _ := consulAgent.Services()
  1569  		for _, s := range services {
  1570  			t.Logf(pretty.Sprint("Service: ", s))
  1571  		}
  1572  		for _, c := range consulAgent.CheckRegs() {
  1573  			t.Logf(pretty.Sprint("Check:   ", c))
  1574  		}
  1575  		require.NoError(t, err)
  1576  	})
  1577  }
  1578  
  1579  // TestTaskRunner_RestartSignalTask_NotRunning asserts resilience to failures
  1580  // when a restart or signal is triggered and the task is not running.
  1581  func TestTaskRunner_RestartSignalTask_NotRunning(t *testing.T) {
  1582  	t.Parallel()
  1583  
  1584  	alloc := mock.BatchAlloc()
  1585  	task := alloc.Job.TaskGroups[0].Tasks[0]
  1586  	task.Driver = "mock_driver"
  1587  	task.Config = map[string]interface{}{
  1588  		"run_for": "0s",
  1589  	}
  1590  
  1591  	// Use vault to block the start
  1592  	task.Vault = &structs.Vault{Policies: []string{"default"}}
  1593  
  1594  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
  1595  	defer cleanup()
  1596  
  1597  	// Control when we get a Vault token
  1598  	waitCh := make(chan struct{}, 1)
  1599  	defer close(waitCh)
  1600  	handler := func(*structs.Allocation, []string) (map[string]string, error) {
  1601  		<-waitCh
  1602  		return map[string]string{task.Name: "1234"}, nil
  1603  	}
  1604  	vaultClient := conf.Vault.(*vaultclient.MockVaultClient)
  1605  	vaultClient.DeriveTokenFn = handler
  1606  
  1607  	tr, err := NewTaskRunner(conf)
  1608  	require.NoError(t, err)
  1609  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
  1610  	go tr.Run()
  1611  
  1612  	select {
  1613  	case <-tr.WaitCh():
  1614  		require.Fail(t, "unexpected exit")
  1615  	case <-time.After(1 * time.Second):
  1616  	}
  1617  
  1618  	// Send a signal and restart
  1619  	err = tr.Signal(structs.NewTaskEvent("don't panic"), "QUIT")
  1620  	require.EqualError(t, err, ErrTaskNotRunning.Error())
  1621  
  1622  	// Send a restart
  1623  	err = tr.Restart(context.Background(), structs.NewTaskEvent("don't panic"), false)
  1624  	require.EqualError(t, err, ErrTaskNotRunning.Error())
  1625  
  1626  	// Unblock and let it finish
  1627  	waitCh <- struct{}{}
  1628  
  1629  	select {
  1630  	case <-tr.WaitCh():
  1631  	case <-time.After(10 * time.Second):
  1632  		require.Fail(t, "timed out waiting for task to complete")
  1633  	}
  1634  
  1635  	// Assert the task ran and never restarted
  1636  	state := tr.TaskState()
  1637  	require.Equal(t, structs.TaskStateDead, state.State)
  1638  	require.False(t, state.Failed)
  1639  	require.Len(t, state.Events, 4, pretty.Sprint(state.Events))
  1640  	require.Equal(t, structs.TaskReceived, state.Events[0].Type)
  1641  	require.Equal(t, structs.TaskSetup, state.Events[1].Type)
  1642  	require.Equal(t, structs.TaskStarted, state.Events[2].Type)
  1643  	require.Equal(t, structs.TaskTerminated, state.Events[3].Type)
  1644  }
  1645  
  1646  // TestTaskRunner_Run_RecoverableStartError asserts tasks are restarted if they
  1647  // return a recoverable error from StartTask.
  1648  func TestTaskRunner_Run_RecoverableStartError(t *testing.T) {
  1649  	t.Parallel()
  1650  
  1651  	alloc := mock.BatchAlloc()
  1652  	task := alloc.Job.TaskGroups[0].Tasks[0]
  1653  	task.Config = map[string]interface{}{
  1654  		"start_error":             "driver failure",
  1655  		"start_error_recoverable": true,
  1656  	}
  1657  
  1658  	// Make the restart policy retry once
  1659  	alloc.Job.TaskGroups[0].RestartPolicy = &structs.RestartPolicy{
  1660  		Attempts: 1,
  1661  		Interval: 10 * time.Minute,
  1662  		Delay:    0,
  1663  		Mode:     structs.RestartPolicyModeFail,
  1664  	}
  1665  
  1666  	tr, _, cleanup := runTestTaskRunner(t, alloc, task.Name)
  1667  	defer cleanup()
  1668  
  1669  	select {
  1670  	case <-tr.WaitCh():
  1671  	case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
  1672  		require.Fail(t, "timed out waiting for task to exit")
  1673  	}
  1674  
  1675  	state := tr.TaskState()
  1676  	require.Equal(t, structs.TaskStateDead, state.State)
  1677  	require.True(t, state.Failed)
  1678  	require.Len(t, state.Events, 6, pretty.Sprint(state.Events))
  1679  	require.Equal(t, structs.TaskReceived, state.Events[0].Type)
  1680  	require.Equal(t, structs.TaskSetup, state.Events[1].Type)
  1681  	require.Equal(t, structs.TaskDriverFailure, state.Events[2].Type)
  1682  	require.Equal(t, structs.TaskRestarting, state.Events[3].Type)
  1683  	require.Equal(t, structs.TaskDriverFailure, state.Events[4].Type)
  1684  	require.Equal(t, structs.TaskNotRestarting, state.Events[5].Type)
  1685  }
  1686  
  1687  // TestTaskRunner_Template_Artifact asserts that tasks can use artifacts as templates.
  1688  func TestTaskRunner_Template_Artifact(t *testing.T) {
  1689  	t.Parallel()
  1690  
  1691  	ts := httptest.NewServer(http.FileServer(http.Dir(".")))
  1692  	defer ts.Close()
  1693  
  1694  	alloc := mock.BatchAlloc()
  1695  	task := alloc.Job.TaskGroups[0].Tasks[0]
  1696  	f1 := "task_runner.go"
  1697  	f2 := "test"
  1698  	task.Artifacts = []*structs.TaskArtifact{
  1699  		{GetterSource: fmt.Sprintf("%s/%s", ts.URL, f1)},
  1700  	}
  1701  	task.Templates = []*structs.Template{
  1702  		{
  1703  			SourcePath: f1,
  1704  			DestPath:   "local/test",
  1705  			ChangeMode: structs.TemplateChangeModeNoop,
  1706  		},
  1707  	}
  1708  
  1709  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
  1710  	defer cleanup()
  1711  
  1712  	tr, err := NewTaskRunner(conf)
  1713  	require.NoError(t, err)
  1714  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
  1715  	go tr.Run()
  1716  
  1717  	// Wait for task to run and exit
  1718  	select {
  1719  	case <-tr.WaitCh():
  1720  	case <-time.After(15 * time.Second * time.Duration(testutil.TestMultiplier())):
  1721  		require.Fail(t, "timed out waiting for task runner to exit")
  1722  	}
  1723  
  1724  	state := tr.TaskState()
  1725  	require.Equal(t, structs.TaskStateDead, state.State)
  1726  	require.True(t, state.Successful())
  1727  	require.False(t, state.Failed)
  1728  
  1729  	artifactsDownloaded := false
  1730  	for _, e := range state.Events {
  1731  		if e.Type == structs.TaskDownloadingArtifacts {
  1732  			artifactsDownloaded = true
  1733  		}
  1734  	}
  1735  	assert.True(t, artifactsDownloaded, "expected artifacts downloaded events")
  1736  
  1737  	// Check that both files exist.
  1738  	_, err = os.Stat(filepath.Join(conf.TaskDir.Dir, f1))
  1739  	require.NoErrorf(t, err, "%v not downloaded", f1)
  1740  
  1741  	_, err = os.Stat(filepath.Join(conf.TaskDir.LocalDir, f2))
  1742  	require.NoErrorf(t, err, "%v not rendered", f2)
  1743  }
  1744  
  1745  // TestTaskRunner_Template_NewVaultToken asserts that a new vault token is
  1746  // created when rendering template and that it is revoked on alloc completion
  1747  func TestTaskRunner_Template_NewVaultToken(t *testing.T) {
  1748  	t.Parallel()
  1749  
  1750  	alloc := mock.BatchAlloc()
  1751  	task := alloc.Job.TaskGroups[0].Tasks[0]
  1752  	task.Templates = []*structs.Template{
  1753  		{
  1754  			EmbeddedTmpl: `{{key "foo"}}`,
  1755  			DestPath:     "local/test",
  1756  			ChangeMode:   structs.TemplateChangeModeNoop,
  1757  		},
  1758  	}
  1759  	task.Vault = &structs.Vault{Policies: []string{"default"}}
  1760  
  1761  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
  1762  	defer cleanup()
  1763  
  1764  	tr, err := NewTaskRunner(conf)
  1765  	require.NoError(t, err)
  1766  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
  1767  	go tr.Run()
  1768  
  1769  	// Wait for a Vault token
  1770  	var token string
  1771  	testutil.WaitForResult(func() (bool, error) {
  1772  		token = tr.getVaultToken()
  1773  
  1774  		if token == "" {
  1775  			return false, fmt.Errorf("No Vault token")
  1776  		}
  1777  
  1778  		return true, nil
  1779  	}, func(err error) {
  1780  		require.NoError(t, err)
  1781  	})
  1782  
  1783  	vault := conf.Vault.(*vaultclient.MockVaultClient)
  1784  	renewalCh, ok := vault.RenewTokens()[token]
  1785  	require.True(t, ok, "no renewal channel for token")
  1786  
  1787  	renewalCh <- fmt.Errorf("Test killing")
  1788  	close(renewalCh)
  1789  
  1790  	var token2 string
  1791  	testutil.WaitForResult(func() (bool, error) {
  1792  		token2 = tr.getVaultToken()
  1793  
  1794  		if token2 == "" {
  1795  			return false, fmt.Errorf("No Vault token")
  1796  		}
  1797  
  1798  		if token2 == token {
  1799  			return false, fmt.Errorf("token wasn't recreated")
  1800  		}
  1801  
  1802  		return true, nil
  1803  	}, func(err error) {
  1804  		require.NoError(t, err)
  1805  	})
  1806  
  1807  	// Check the token was revoked
  1808  	testutil.WaitForResult(func() (bool, error) {
  1809  		if len(vault.StoppedTokens()) != 1 {
  1810  			return false, fmt.Errorf("Expected a stopped token: %v", vault.StoppedTokens())
  1811  		}
  1812  
  1813  		if a := vault.StoppedTokens()[0]; a != token {
  1814  			return false, fmt.Errorf("got stopped token %q; want %q", a, token)
  1815  		}
  1816  
  1817  		return true, nil
  1818  	}, func(err error) {
  1819  		require.NoError(t, err)
  1820  	})
  1821  
  1822  }
  1823  
  1824  // TestTaskRunner_VaultManager_Restart asserts that the alloc is restarted when the alloc
  1825  // derived vault token expires, when task is configured with Restart change mode
  1826  func TestTaskRunner_VaultManager_Restart(t *testing.T) {
  1827  	t.Parallel()
  1828  
  1829  	alloc := mock.BatchAlloc()
  1830  	task := alloc.Job.TaskGroups[0].Tasks[0]
  1831  	task.Config = map[string]interface{}{
  1832  		"run_for": "10s",
  1833  	}
  1834  	task.Vault = &structs.Vault{
  1835  		Policies:   []string{"default"},
  1836  		ChangeMode: structs.VaultChangeModeRestart,
  1837  	}
  1838  
  1839  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
  1840  	defer cleanup()
  1841  
  1842  	tr, err := NewTaskRunner(conf)
  1843  	require.NoError(t, err)
  1844  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
  1845  	go tr.Run()
  1846  
  1847  	testWaitForTaskToStart(t, tr)
  1848  
  1849  	tr.vaultTokenLock.Lock()
  1850  	token := tr.vaultToken
  1851  	tr.vaultTokenLock.Unlock()
  1852  
  1853  	require.NotEmpty(t, token)
  1854  
  1855  	vault := conf.Vault.(*vaultclient.MockVaultClient)
  1856  	renewalCh, ok := vault.RenewTokens()[token]
  1857  	require.True(t, ok, "no renewal channel for token")
  1858  
  1859  	renewalCh <- fmt.Errorf("Test killing")
  1860  	close(renewalCh)
  1861  
  1862  	testutil.WaitForResult(func() (bool, error) {
  1863  		state := tr.TaskState()
  1864  
  1865  		if len(state.Events) == 0 {
  1866  			return false, fmt.Errorf("no events yet")
  1867  		}
  1868  
  1869  		foundRestartSignal, foundRestarting := false, false
  1870  		for _, e := range state.Events {
  1871  			switch e.Type {
  1872  			case structs.TaskRestartSignal:
  1873  				foundRestartSignal = true
  1874  			case structs.TaskRestarting:
  1875  				foundRestarting = true
  1876  			}
  1877  		}
  1878  
  1879  		if !foundRestartSignal {
  1880  			return false, fmt.Errorf("no restart signal event yet: %#v", state.Events)
  1881  		}
  1882  
  1883  		if !foundRestarting {
  1884  			return false, fmt.Errorf("no restarting event yet: %#v", state.Events)
  1885  		}
  1886  
  1887  		lastEvent := state.Events[len(state.Events)-1]
  1888  		if lastEvent.Type != structs.TaskStarted {
  1889  			return false, fmt.Errorf("expected last event to be task starting but was %#v", lastEvent)
  1890  		}
  1891  		return true, nil
  1892  	}, func(err error) {
  1893  		require.NoError(t, err)
  1894  	})
  1895  }
  1896  
  1897  // TestTaskRunner_VaultManager_Signal asserts that the alloc is signalled when the alloc
  1898  // derived vault token expires, when task is configured with signal change mode
  1899  func TestTaskRunner_VaultManager_Signal(t *testing.T) {
  1900  	t.Parallel()
  1901  
  1902  	alloc := mock.BatchAlloc()
  1903  	task := alloc.Job.TaskGroups[0].Tasks[0]
  1904  	task.Config = map[string]interface{}{
  1905  		"run_for": "10s",
  1906  	}
  1907  	task.Vault = &structs.Vault{
  1908  		Policies:     []string{"default"},
  1909  		ChangeMode:   structs.VaultChangeModeSignal,
  1910  		ChangeSignal: "SIGUSR1",
  1911  	}
  1912  
  1913  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
  1914  	defer cleanup()
  1915  
  1916  	tr, err := NewTaskRunner(conf)
  1917  	require.NoError(t, err)
  1918  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
  1919  	go tr.Run()
  1920  
  1921  	testWaitForTaskToStart(t, tr)
  1922  
  1923  	tr.vaultTokenLock.Lock()
  1924  	token := tr.vaultToken
  1925  	tr.vaultTokenLock.Unlock()
  1926  
  1927  	require.NotEmpty(t, token)
  1928  
  1929  	vault := conf.Vault.(*vaultclient.MockVaultClient)
  1930  	renewalCh, ok := vault.RenewTokens()[token]
  1931  	require.True(t, ok, "no renewal channel for token")
  1932  
  1933  	renewalCh <- fmt.Errorf("Test killing")
  1934  	close(renewalCh)
  1935  
  1936  	testutil.WaitForResult(func() (bool, error) {
  1937  		state := tr.TaskState()
  1938  
  1939  		if len(state.Events) == 0 {
  1940  			return false, fmt.Errorf("no events yet")
  1941  		}
  1942  
  1943  		foundSignaling := false
  1944  		for _, e := range state.Events {
  1945  			if e.Type == structs.TaskSignaling {
  1946  				foundSignaling = true
  1947  			}
  1948  		}
  1949  
  1950  		if !foundSignaling {
  1951  			return false, fmt.Errorf("no signaling event yet: %#v", state.Events)
  1952  		}
  1953  
  1954  		return true, nil
  1955  	}, func(err error) {
  1956  		require.NoError(t, err)
  1957  	})
  1958  
  1959  }
  1960  
  1961  // TestTaskRunner_UnregisterConsul_Retries asserts a task is unregistered from
  1962  // Consul when waiting to be retried.
  1963  func TestTaskRunner_UnregisterConsul_Retries(t *testing.T) {
  1964  	t.Parallel()
  1965  
  1966  	alloc := mock.Alloc()
  1967  	// Make the restart policy try one ctx.update
  1968  	alloc.Job.TaskGroups[0].RestartPolicy = &structs.RestartPolicy{
  1969  		Attempts: 1,
  1970  		Interval: 10 * time.Minute,
  1971  		Delay:    time.Nanosecond,
  1972  		Mode:     structs.RestartPolicyModeFail,
  1973  	}
  1974  	task := alloc.Job.TaskGroups[0].Tasks[0]
  1975  	task.Driver = "mock_driver"
  1976  	task.Config = map[string]interface{}{
  1977  		"exit_code": "1",
  1978  		"run_for":   "1ns",
  1979  	}
  1980  
  1981  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
  1982  	defer cleanup()
  1983  
  1984  	tr, err := NewTaskRunner(conf)
  1985  	require.NoError(t, err)
  1986  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
  1987  	tr.Run()
  1988  
  1989  	state := tr.TaskState()
  1990  	require.Equal(t, structs.TaskStateDead, state.State)
  1991  
  1992  	consul := conf.Consul.(*consulapi.MockConsulServiceClient)
  1993  	consulOps := consul.GetOps()
  1994  	require.Len(t, consulOps, 8)
  1995  
  1996  	// Initial add
  1997  	require.Equal(t, "add", consulOps[0].Op)
  1998  
  1999  	// Removing canary and non-canary entries on first exit
  2000  	require.Equal(t, "remove", consulOps[1].Op)
  2001  	require.Equal(t, "remove", consulOps[2].Op)
  2002  
  2003  	// Second add on retry
  2004  	require.Equal(t, "add", consulOps[3].Op)
  2005  
  2006  	// Removing canary and non-canary entries on retry
  2007  	require.Equal(t, "remove", consulOps[4].Op)
  2008  	require.Equal(t, "remove", consulOps[5].Op)
  2009  
  2010  	// Removing canary and non-canary entries on stop
  2011  	require.Equal(t, "remove", consulOps[6].Op)
  2012  	require.Equal(t, "remove", consulOps[7].Op)
  2013  }
  2014  
  2015  // testWaitForTaskToStart waits for the task to be running or fails the test
  2016  func testWaitForTaskToStart(t *testing.T, tr *TaskRunner) {
  2017  	testutil.WaitForResult(func() (bool, error) {
  2018  		ts := tr.TaskState()
  2019  		return ts.State == structs.TaskStateRunning, fmt.Errorf("%v", ts.State)
  2020  	}, func(err error) {
  2021  		require.NoError(t, err)
  2022  	})
  2023  }