github.com/smithx10/nomad@v0.9.1-rc1/client/allocrunner/taskrunner/task_runner_test.go (about)

     1  package taskrunner
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io/ioutil"
     7  	"net/http"
     8  	"net/http/httptest"
     9  	"os"
    10  	"path/filepath"
    11  	"strings"
    12  	"testing"
    13  	"time"
    14  
    15  	"github.com/golang/snappy"
    16  	"github.com/hashicorp/nomad/client/allocdir"
    17  	"github.com/hashicorp/nomad/client/allocrunner/interfaces"
    18  	"github.com/hashicorp/nomad/client/config"
    19  	"github.com/hashicorp/nomad/client/consul"
    20  	consulapi "github.com/hashicorp/nomad/client/consul"
    21  	"github.com/hashicorp/nomad/client/devicemanager"
    22  	"github.com/hashicorp/nomad/client/pluginmanager/drivermanager"
    23  	cstate "github.com/hashicorp/nomad/client/state"
    24  	ctestutil "github.com/hashicorp/nomad/client/testutil"
    25  	"github.com/hashicorp/nomad/client/vaultclient"
    26  	agentconsul "github.com/hashicorp/nomad/command/agent/consul"
    27  	mockdriver "github.com/hashicorp/nomad/drivers/mock"
    28  	"github.com/hashicorp/nomad/helper/testlog"
    29  	"github.com/hashicorp/nomad/nomad/mock"
    30  	"github.com/hashicorp/nomad/nomad/structs"
    31  	"github.com/hashicorp/nomad/plugins/device"
    32  	"github.com/hashicorp/nomad/testutil"
    33  	"github.com/kr/pretty"
    34  	"github.com/stretchr/testify/assert"
    35  	"github.com/stretchr/testify/require"
    36  )
    37  
    38  type MockTaskStateUpdater struct {
    39  	ch chan struct{}
    40  }
    41  
    42  func NewMockTaskStateUpdater() *MockTaskStateUpdater {
    43  	return &MockTaskStateUpdater{
    44  		ch: make(chan struct{}, 1),
    45  	}
    46  }
    47  
    48  func (m *MockTaskStateUpdater) TaskStateUpdated() {
    49  	select {
    50  	case m.ch <- struct{}{}:
    51  	default:
    52  	}
    53  }
    54  
    55  // testTaskRunnerConfig returns a taskrunner.Config for the given alloc+task
    56  // plus a cleanup func.
    57  func testTaskRunnerConfig(t *testing.T, alloc *structs.Allocation, taskName string) (*Config, func()) {
    58  	logger := testlog.HCLogger(t)
    59  	clientConf, cleanup := config.TestClientConfig(t)
    60  
    61  	// Find the task
    62  	var thisTask *structs.Task
    63  	for _, tg := range alloc.Job.TaskGroups {
    64  		for _, task := range tg.Tasks {
    65  			if task.Name == taskName {
    66  				if thisTask != nil {
    67  					cleanup()
    68  					t.Fatalf("multiple tasks named %q; cannot use this helper", taskName)
    69  				}
    70  				thisTask = task
    71  			}
    72  		}
    73  	}
    74  	if thisTask == nil {
    75  		cleanup()
    76  		t.Fatalf("could not find task %q", taskName)
    77  	}
    78  
    79  	// Create the alloc dir + task dir
    80  	allocPath := filepath.Join(clientConf.AllocDir, alloc.ID)
    81  	allocDir := allocdir.NewAllocDir(logger, allocPath)
    82  	if err := allocDir.Build(); err != nil {
    83  		cleanup()
    84  		t.Fatalf("error building alloc dir: %v", err)
    85  	}
    86  	taskDir := allocDir.NewTaskDir(taskName)
    87  
    88  	trCleanup := func() {
    89  		if err := allocDir.Destroy(); err != nil {
    90  			t.Logf("error destroying alloc dir: %v", err)
    91  		}
    92  		cleanup()
    93  	}
    94  
    95  	conf := &Config{
    96  		Alloc:         alloc,
    97  		ClientConfig:  clientConf,
    98  		Consul:        consulapi.NewMockConsulServiceClient(t, logger),
    99  		Task:          thisTask,
   100  		TaskDir:       taskDir,
   101  		Logger:        clientConf.Logger,
   102  		Vault:         vaultclient.NewMockVaultClient(),
   103  		StateDB:       cstate.NoopDB{},
   104  		StateUpdater:  NewMockTaskStateUpdater(),
   105  		DeviceManager: devicemanager.NoopMockManager(),
   106  		DriverManager: drivermanager.TestDriverManager(t),
   107  	}
   108  	return conf, trCleanup
   109  }
   110  
   111  // runTestTaskRunner runs a TaskRunner and returns its configuration as well as
   112  // a cleanup function that ensures the runner is stopped and cleaned up. Tests
   113  // which need to change the Config *must* use testTaskRunnerConfig instead.
   114  func runTestTaskRunner(t *testing.T, alloc *structs.Allocation, taskName string) (*TaskRunner, *Config, func()) {
   115  	config, cleanup := testTaskRunnerConfig(t, alloc, taskName)
   116  
   117  	tr, err := NewTaskRunner(config)
   118  	require.NoError(t, err)
   119  	go tr.Run()
   120  
   121  	return tr, config, func() {
   122  		tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
   123  		cleanup()
   124  	}
   125  }
   126  
   127  // TestTaskRunner_Restore asserts restoring a running task does not rerun the
   128  // task.
   129  func TestTaskRunner_Restore_Running(t *testing.T) {
   130  	t.Parallel()
   131  	require := require.New(t)
   132  
   133  	alloc := mock.BatchAlloc()
   134  	alloc.Job.TaskGroups[0].Count = 1
   135  	task := alloc.Job.TaskGroups[0].Tasks[0]
   136  	task.Driver = "mock_driver"
   137  	task.Config = map[string]interface{}{
   138  		"run_for": "2s",
   139  	}
   140  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
   141  	conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between task runners
   142  	defer cleanup()
   143  
   144  	// Run the first TaskRunner
   145  	origTR, err := NewTaskRunner(conf)
   146  	require.NoError(err)
   147  	go origTR.Run()
   148  	defer origTR.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
   149  
   150  	// Wait for it to be running
   151  	testWaitForTaskToStart(t, origTR)
   152  
   153  	// Cause TR to exit without shutting down task
   154  	origTR.Shutdown()
   155  
   156  	// Start a new TaskRunner and make sure it does not rerun the task
   157  	newTR, err := NewTaskRunner(conf)
   158  	require.NoError(err)
   159  
   160  	// Do the Restore
   161  	require.NoError(newTR.Restore())
   162  
   163  	go newTR.Run()
   164  	defer newTR.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
   165  
   166  	// Wait for new task runner to exit when the process does
   167  	<-newTR.WaitCh()
   168  
   169  	// Assert that the process was only started once
   170  	started := 0
   171  	state := newTR.TaskState()
   172  	require.Equal(structs.TaskStateDead, state.State)
   173  	for _, ev := range state.Events {
   174  		if ev.Type == structs.TaskStarted {
   175  			started++
   176  		}
   177  	}
   178  	assert.Equal(t, 1, started)
   179  }
   180  
   181  // TestTaskRunner_TaskEnv_Interpolated asserts driver configurations are
   182  // interpolated.
   183  func TestTaskRunner_TaskEnv_Interpolated(t *testing.T) {
   184  	t.Parallel()
   185  	require := require.New(t)
   186  
   187  	alloc := mock.BatchAlloc()
   188  	alloc.Job.TaskGroups[0].Meta = map[string]string{
   189  		"common_user": "somebody",
   190  	}
   191  	task := alloc.Job.TaskGroups[0].Tasks[0]
   192  	task.Meta = map[string]string{
   193  		"foo": "bar",
   194  	}
   195  
   196  	// Use interpolation from both node attributes and meta vars
   197  	task.Config = map[string]interface{}{
   198  		"run_for":       "1ms",
   199  		"stdout_string": `${node.region} ${NOMAD_META_foo} ${NOMAD_META_common_user}`,
   200  	}
   201  
   202  	tr, conf, cleanup := runTestTaskRunner(t, alloc, task.Name)
   203  	defer cleanup()
   204  
   205  	// Wait for task to complete
   206  	select {
   207  	case <-tr.WaitCh():
   208  	case <-time.After(3 * time.Second):
   209  		require.Fail("timeout waiting for task to exit")
   210  	}
   211  
   212  	// Get the mock driver plugin
   213  	driverPlugin, err := conf.DriverManager.Dispense(mockdriver.PluginID.Name)
   214  	require.NoError(err)
   215  	mockDriver := driverPlugin.(*mockdriver.Driver)
   216  
   217  	// Assert its config has been properly interpolated
   218  	driverCfg, mockCfg := mockDriver.GetTaskConfig()
   219  	require.NotNil(driverCfg)
   220  	require.NotNil(mockCfg)
   221  	assert.Equal(t, "global bar somebody", mockCfg.StdoutString)
   222  }
   223  
   224  // TestTaskRunner_TaskEnv_Chroot asserts chroot drivers use chroot paths and
   225  // not host paths.
   226  func TestTaskRunner_TaskEnv_Chroot(t *testing.T) {
   227  	ctestutil.ExecCompatible(t)
   228  	t.Parallel()
   229  	require := require.New(t)
   230  
   231  	alloc := mock.BatchAlloc()
   232  	task := alloc.Job.TaskGroups[0].Tasks[0]
   233  	task.Driver = "exec"
   234  	task.Config = map[string]interface{}{
   235  		"command": "bash",
   236  		"args": []string{"-c", "echo $NOMAD_ALLOC_DIR; " +
   237  			"echo $NOMAD_TASK_DIR; " +
   238  			"echo $NOMAD_SECRETS_DIR; " +
   239  			"echo $PATH; ",
   240  		},
   241  	}
   242  
   243  	// Expect chroot paths and host $PATH
   244  	exp := fmt.Sprintf(`/alloc
   245  /local
   246  /secrets
   247  %s
   248  `, os.Getenv("PATH"))
   249  
   250  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
   251  	defer cleanup()
   252  
   253  	// Remove /sbin and /usr from chroot
   254  	conf.ClientConfig.ChrootEnv = map[string]string{
   255  		"/bin":            "/bin",
   256  		"/etc":            "/etc",
   257  		"/lib":            "/lib",
   258  		"/lib32":          "/lib32",
   259  		"/lib64":          "/lib64",
   260  		"/run/resolvconf": "/run/resolvconf",
   261  	}
   262  
   263  	tr, err := NewTaskRunner(conf)
   264  	require.NoError(err)
   265  	go tr.Run()
   266  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
   267  
   268  	// Wait for task to exit
   269  	select {
   270  	case <-tr.WaitCh():
   271  	case <-time.After(15 * time.Second):
   272  		require.Fail("timeout waiting for task to exit")
   273  	}
   274  
   275  	// Read stdout
   276  	p := filepath.Join(conf.TaskDir.LogDir, task.Name+".stdout.0")
   277  	stdout, err := ioutil.ReadFile(p)
   278  	require.NoError(err)
   279  	require.Equalf(exp, string(stdout), "expected: %s\n\nactual: %s\n", exp, stdout)
   280  }
   281  
   282  // TestTaskRunner_TaskEnv_Image asserts image drivers use chroot paths and
   283  // not host paths. Host env vars should also be excluded.
   284  func TestTaskRunner_TaskEnv_Image(t *testing.T) {
   285  	ctestutil.DockerCompatible(t)
   286  	t.Parallel()
   287  	require := require.New(t)
   288  
   289  	alloc := mock.BatchAlloc()
   290  	task := alloc.Job.TaskGroups[0].Tasks[0]
   291  	task.Driver = "docker"
   292  	task.Config = map[string]interface{}{
   293  		"image":        "redis:3.2-alpine",
   294  		"network_mode": "none",
   295  		"command":      "sh",
   296  		"args": []string{"-c", "echo $NOMAD_ALLOC_DIR; " +
   297  			"echo $NOMAD_TASK_DIR; " +
   298  			"echo $NOMAD_SECRETS_DIR; " +
   299  			"echo $PATH",
   300  		},
   301  	}
   302  
   303  	// Expect chroot paths and image specific PATH
   304  	exp := `/alloc
   305  /local
   306  /secrets
   307  /usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
   308  `
   309  
   310  	tr, conf, cleanup := runTestTaskRunner(t, alloc, task.Name)
   311  	defer cleanup()
   312  
   313  	// Wait for task to exit
   314  	select {
   315  	case <-tr.WaitCh():
   316  	case <-time.After(15 * time.Second):
   317  		require.Fail("timeout waiting for task to exit")
   318  	}
   319  
   320  	// Read stdout
   321  	p := filepath.Join(conf.TaskDir.LogDir, task.Name+".stdout.0")
   322  	stdout, err := ioutil.ReadFile(p)
   323  	require.NoError(err)
   324  	require.Equalf(exp, string(stdout), "expected: %s\n\nactual: %s\n", exp, stdout)
   325  }
   326  
   327  // TestTaskRunner_TaskEnv_None asserts raw_exec uses host paths and env vars.
   328  func TestTaskRunner_TaskEnv_None(t *testing.T) {
   329  	t.Parallel()
   330  	require := require.New(t)
   331  
   332  	alloc := mock.BatchAlloc()
   333  	task := alloc.Job.TaskGroups[0].Tasks[0]
   334  	task.Driver = "raw_exec"
   335  	task.Config = map[string]interface{}{
   336  		"command": "sh",
   337  		"args": []string{"-c", "echo $NOMAD_ALLOC_DIR; " +
   338  			"echo $NOMAD_TASK_DIR; " +
   339  			"echo $NOMAD_SECRETS_DIR; " +
   340  			"echo $PATH",
   341  		},
   342  	}
   343  
   344  	tr, conf, cleanup := runTestTaskRunner(t, alloc, task.Name)
   345  	defer cleanup()
   346  
   347  	// Expect host paths
   348  	root := filepath.Join(conf.ClientConfig.AllocDir, alloc.ID)
   349  	taskDir := filepath.Join(root, task.Name)
   350  	exp := fmt.Sprintf(`%s/alloc
   351  %s/local
   352  %s/secrets
   353  %s
   354  `, root, taskDir, taskDir, os.Getenv("PATH"))
   355  
   356  	// Wait for task to exit
   357  	select {
   358  	case <-tr.WaitCh():
   359  	case <-time.After(15 * time.Second):
   360  		require.Fail("timeout waiting for task to exit")
   361  	}
   362  
   363  	// Read stdout
   364  	p := filepath.Join(conf.TaskDir.LogDir, task.Name+".stdout.0")
   365  	stdout, err := ioutil.ReadFile(p)
   366  	require.NoError(err)
   367  	require.Equalf(exp, string(stdout), "expected: %s\n\nactual: %s\n", exp, stdout)
   368  }
   369  
   370  // Test that devices get sent to the driver
   371  func TestTaskRunner_DevicePropogation(t *testing.T) {
   372  	t.Parallel()
   373  	require := require.New(t)
   374  
   375  	// Create a mock alloc that has a gpu
   376  	alloc := mock.BatchAlloc()
   377  	alloc.Job.TaskGroups[0].Count = 1
   378  	task := alloc.Job.TaskGroups[0].Tasks[0]
   379  	task.Driver = "mock_driver"
   380  	task.Config = map[string]interface{}{
   381  		"run_for": "100ms",
   382  	}
   383  	tRes := alloc.AllocatedResources.Tasks[task.Name]
   384  	tRes.Devices = append(tRes.Devices, &structs.AllocatedDeviceResource{Type: "mock"})
   385  
   386  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
   387  	conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between task runners
   388  	defer cleanup()
   389  
   390  	// Setup the devicemanager
   391  	dm, ok := conf.DeviceManager.(*devicemanager.MockManager)
   392  	require.True(ok)
   393  
   394  	dm.ReserveF = func(d *structs.AllocatedDeviceResource) (*device.ContainerReservation, error) {
   395  		res := &device.ContainerReservation{
   396  			Envs: map[string]string{
   397  				"ABC": "123",
   398  			},
   399  			Mounts: []*device.Mount{
   400  				{
   401  					ReadOnly: true,
   402  					TaskPath: "foo",
   403  					HostPath: "bar",
   404  				},
   405  			},
   406  			Devices: []*device.DeviceSpec{
   407  				{
   408  					TaskPath:    "foo",
   409  					HostPath:    "bar",
   410  					CgroupPerms: "123",
   411  				},
   412  			},
   413  		}
   414  		return res, nil
   415  	}
   416  
   417  	// Run the TaskRunner
   418  	tr, err := NewTaskRunner(conf)
   419  	require.NoError(err)
   420  	go tr.Run()
   421  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
   422  
   423  	// Wait for task to complete
   424  	select {
   425  	case <-tr.WaitCh():
   426  	case <-time.After(3 * time.Second):
   427  	}
   428  
   429  	// Get the mock driver plugin
   430  	driverPlugin, err := conf.DriverManager.Dispense(mockdriver.PluginID.Name)
   431  	require.NoError(err)
   432  	mockDriver := driverPlugin.(*mockdriver.Driver)
   433  
   434  	// Assert its config has been properly interpolated
   435  	driverCfg, _ := mockDriver.GetTaskConfig()
   436  	require.NotNil(driverCfg)
   437  	require.Len(driverCfg.Devices, 1)
   438  	require.Equal(driverCfg.Devices[0].Permissions, "123")
   439  	require.Len(driverCfg.Mounts, 1)
   440  	require.Equal(driverCfg.Mounts[0].TaskPath, "foo")
   441  	require.Contains(driverCfg.Env, "ABC")
   442  }
   443  
   444  // mockEnvHook is a test hook that sets an env var and done=true. It fails if
   445  // it's called more than once.
   446  type mockEnvHook struct {
   447  	called int
   448  }
   449  
   450  func (*mockEnvHook) Name() string {
   451  	return "mock_env_hook"
   452  }
   453  
   454  func (h *mockEnvHook) Prestart(ctx context.Context, req *interfaces.TaskPrestartRequest, resp *interfaces.TaskPrestartResponse) error {
   455  	h.called++
   456  
   457  	resp.Done = true
   458  	resp.Env = map[string]string{
   459  		"mock_hook": "1",
   460  	}
   461  
   462  	return nil
   463  }
   464  
   465  // TestTaskRunner_Restore_HookEnv asserts that re-running prestart hooks with
   466  // hook environments set restores the environment without re-running done
   467  // hooks.
   468  func TestTaskRunner_Restore_HookEnv(t *testing.T) {
   469  	t.Parallel()
   470  	require := require.New(t)
   471  
   472  	alloc := mock.BatchAlloc()
   473  	task := alloc.Job.TaskGroups[0].Tasks[0]
   474  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
   475  	conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between prestart calls
   476  	defer cleanup()
   477  
   478  	tr, err := NewTaskRunner(conf)
   479  	require.NoError(err)
   480  
   481  	// Override the default hooks to only run the mock hook
   482  	mockHook := &mockEnvHook{}
   483  	tr.runnerHooks = []interfaces.TaskHook{mockHook}
   484  
   485  	// Manually run prestart hooks
   486  	require.NoError(tr.prestart())
   487  
   488  	// Assert env was called
   489  	require.Equal(1, mockHook.called)
   490  
   491  	// Re-running prestart hooks should *not* call done mock hook
   492  	require.NoError(tr.prestart())
   493  
   494  	// Assert env was called
   495  	require.Equal(1, mockHook.called)
   496  
   497  	// Assert the env is still set
   498  	env := tr.envBuilder.Build().All()
   499  	require.Contains(env, "mock_hook")
   500  	require.Equal("1", env["mock_hook"])
   501  }
   502  
   503  // This test asserts that we can recover from an "external" plugin exiting by
   504  // retrieving a new instance of the driver and recovering the task.
   505  func TestTaskRunner_RecoverFromDriverExiting(t *testing.T) {
   506  	t.Parallel()
   507  	require := require.New(t)
   508  
   509  	// Create an allocation using the mock driver that exits simulating the
   510  	// driver crashing. We can then test that the task runner recovers from this
   511  	alloc := mock.BatchAlloc()
   512  	task := alloc.Job.TaskGroups[0].Tasks[0]
   513  	task.Driver = "mock_driver"
   514  	task.Config = map[string]interface{}{
   515  		"plugin_exit_after": "1s",
   516  		"run_for":           "5s",
   517  	}
   518  
   519  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
   520  	conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between prestart calls
   521  	defer cleanup()
   522  
   523  	tr, err := NewTaskRunner(conf)
   524  	require.NoError(err)
   525  
   526  	start := time.Now()
   527  	go tr.Run()
   528  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
   529  
   530  	// Wait for the task to be running
   531  	testWaitForTaskToStart(t, tr)
   532  
   533  	// Get the task ID
   534  	tr.stateLock.RLock()
   535  	l := tr.localState.TaskHandle
   536  	require.NotNil(l)
   537  	require.NotNil(l.Config)
   538  	require.NotEmpty(l.Config.ID)
   539  	id := l.Config.ID
   540  	tr.stateLock.RUnlock()
   541  
   542  	// Get the mock driver plugin
   543  	driverPlugin, err := conf.DriverManager.Dispense(mockdriver.PluginID.Name)
   544  	require.NoError(err)
   545  	mockDriver := driverPlugin.(*mockdriver.Driver)
   546  
   547  	// Wait for the task to start
   548  	testutil.WaitForResult(func() (bool, error) {
   549  		// Get the handle and check that it was recovered
   550  		handle := mockDriver.GetHandle(id)
   551  		if handle == nil {
   552  			return false, fmt.Errorf("nil handle")
   553  		}
   554  		if !handle.Recovered {
   555  			return false, fmt.Errorf("handle not recovered")
   556  		}
   557  		return true, nil
   558  	}, func(err error) {
   559  		t.Fatal(err.Error())
   560  	})
   561  
   562  	// Wait for task to complete
   563  	select {
   564  	case <-tr.WaitCh():
   565  	case <-time.After(10 * time.Second):
   566  	}
   567  
   568  	// Ensure that we actually let the task complete
   569  	require.True(time.Now().Sub(start) > 5*time.Second)
   570  
   571  	// Check it finished successfully
   572  	state := tr.TaskState()
   573  	require.True(state.Successful())
   574  }
   575  
   576  // TestTaskRunner_ShutdownDelay asserts services are removed from Consul
   577  // ${shutdown_delay} seconds before killing the process.
   578  func TestTaskRunner_ShutdownDelay(t *testing.T) {
   579  	t.Parallel()
   580  
   581  	alloc := mock.Alloc()
   582  	task := alloc.Job.TaskGroups[0].Tasks[0]
   583  	task.Services[0].Tags = []string{"tag1"}
   584  	task.Services = task.Services[:1] // only need 1 for this test
   585  	task.Driver = "mock_driver"
   586  	task.Config = map[string]interface{}{
   587  		"run_for": "1000s",
   588  	}
   589  
   590  	// No shutdown escape hatch for this delay, so don't set it too high
   591  	task.ShutdownDelay = 1000 * time.Duration(testutil.TestMultiplier()) * time.Millisecond
   592  
   593  	tr, conf, cleanup := runTestTaskRunner(t, alloc, task.Name)
   594  	defer cleanup()
   595  
   596  	mockConsul := conf.Consul.(*consul.MockConsulServiceClient)
   597  
   598  	// Wait for the task to start
   599  	testWaitForTaskToStart(t, tr)
   600  
   601  	testutil.WaitForResult(func() (bool, error) {
   602  		ops := mockConsul.GetOps()
   603  		if n := len(ops); n != 1 {
   604  			return false, fmt.Errorf("expected 1 consul operation. Found %d", n)
   605  		}
   606  		return ops[0].Op == "add", fmt.Errorf("consul operation was not a registration: %#v", ops[0])
   607  	}, func(err error) {
   608  		t.Fatalf("err: %v", err)
   609  	})
   610  
   611  	// Asynchronously kill task
   612  	killSent := time.Now()
   613  	killed := make(chan struct{})
   614  	go func() {
   615  		defer close(killed)
   616  		assert.NoError(t, tr.Kill(context.Background(), structs.NewTaskEvent("test")))
   617  	}()
   618  
   619  	// Wait for *2* deregistration calls (due to needing to remove both
   620  	// canary tag variants)
   621  WAIT:
   622  	for {
   623  		ops := mockConsul.GetOps()
   624  		switch n := len(ops); n {
   625  		case 1, 2:
   626  			// Waiting for both deregistration calls
   627  		case 3:
   628  			require.Equalf(t, "remove", ops[1].Op, "expected deregistration but found: %#v", ops[1])
   629  			require.Equalf(t, "remove", ops[2].Op, "expected deregistration but found: %#v", ops[2])
   630  			break WAIT
   631  		default:
   632  			// ?!
   633  			t.Fatalf("unexpected number of consul operations: %d\n%s", n, pretty.Sprint(ops))
   634  
   635  		}
   636  
   637  		select {
   638  		case <-killed:
   639  			t.Fatal("killed while service still registered")
   640  		case <-time.After(10 * time.Millisecond):
   641  		}
   642  	}
   643  
   644  	// Wait for actual exit
   645  	select {
   646  	case <-tr.WaitCh():
   647  	case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
   648  		t.Fatalf("timeout")
   649  	}
   650  
   651  	<-killed
   652  	killDur := time.Now().Sub(killSent)
   653  	if killDur < task.ShutdownDelay {
   654  		t.Fatalf("task killed before shutdown_delay (killed_after: %s; shutdown_delay: %s",
   655  			killDur, task.ShutdownDelay,
   656  		)
   657  	}
   658  }
   659  
   660  // TestTaskRunner_Dispatch_Payload asserts that a dispatch job runs and the
   661  // payload was written to disk.
   662  func TestTaskRunner_Dispatch_Payload(t *testing.T) {
   663  	t.Parallel()
   664  
   665  	alloc := mock.BatchAlloc()
   666  	task := alloc.Job.TaskGroups[0].Tasks[0]
   667  	task.Driver = "mock_driver"
   668  	task.Config = map[string]interface{}{
   669  		"run_for": "1s",
   670  	}
   671  
   672  	fileName := "test"
   673  	task.DispatchPayload = &structs.DispatchPayloadConfig{
   674  		File: fileName,
   675  	}
   676  	alloc.Job.ParameterizedJob = &structs.ParameterizedJobConfig{}
   677  
   678  	// Add a payload (they're snappy encoded bytes)
   679  	expected := []byte("hello world")
   680  	compressed := snappy.Encode(nil, expected)
   681  	alloc.Job.Payload = compressed
   682  
   683  	tr, _, cleanup := runTestTaskRunner(t, alloc, task.Name)
   684  	defer cleanup()
   685  
   686  	// Wait for it to finish
   687  	testutil.WaitForResult(func() (bool, error) {
   688  		ts := tr.TaskState()
   689  		return ts.State == structs.TaskStateDead, fmt.Errorf("%v", ts.State)
   690  	}, func(err error) {
   691  		require.NoError(t, err)
   692  	})
   693  
   694  	// Should have exited successfully
   695  	ts := tr.TaskState()
   696  	require.False(t, ts.Failed)
   697  	require.Zero(t, ts.Restarts)
   698  
   699  	// Check that the file was written to disk properly
   700  	payloadPath := filepath.Join(tr.taskDir.LocalDir, fileName)
   701  	data, err := ioutil.ReadFile(payloadPath)
   702  	require.NoError(t, err)
   703  	require.Equal(t, expected, data)
   704  }
   705  
   706  // TestTaskRunner_SignalFailure asserts that signal errors are properly
   707  // propagated from the driver to TaskRunner.
   708  func TestTaskRunner_SignalFailure(t *testing.T) {
   709  	t.Parallel()
   710  
   711  	alloc := mock.Alloc()
   712  	task := alloc.Job.TaskGroups[0].Tasks[0]
   713  	task.Driver = "mock_driver"
   714  	errMsg := "test forcing failure"
   715  	task.Config = map[string]interface{}{
   716  		"run_for":      "10m",
   717  		"signal_error": errMsg,
   718  	}
   719  
   720  	tr, _, cleanup := runTestTaskRunner(t, alloc, task.Name)
   721  	defer cleanup()
   722  
   723  	testWaitForTaskToStart(t, tr)
   724  
   725  	require.EqualError(t, tr.Signal(&structs.TaskEvent{}, "SIGINT"), errMsg)
   726  }
   727  
   728  // TestTaskRunner_RestartTask asserts that restarting a task works and emits a
   729  // Restarting event.
   730  func TestTaskRunner_RestartTask(t *testing.T) {
   731  	t.Parallel()
   732  
   733  	alloc := mock.Alloc()
   734  	task := alloc.Job.TaskGroups[0].Tasks[0]
   735  	task.Driver = "mock_driver"
   736  	task.Config = map[string]interface{}{
   737  		"run_for": "10m",
   738  	}
   739  
   740  	tr, _, cleanup := runTestTaskRunner(t, alloc, task.Name)
   741  	defer cleanup()
   742  
   743  	testWaitForTaskToStart(t, tr)
   744  
   745  	// Restart task. Send a RestartSignal event like check watcher. Restart
   746  	// handler emits the Restarting event.
   747  	event := structs.NewTaskEvent(structs.TaskRestartSignal).SetRestartReason("test")
   748  	const fail = false
   749  	tr.Restart(context.Background(), event.Copy(), fail)
   750  
   751  	// Wait for it to restart and be running again
   752  	testutil.WaitForResult(func() (bool, error) {
   753  		ts := tr.TaskState()
   754  		if ts.Restarts != 1 {
   755  			return false, fmt.Errorf("expected 1 restart but found %d\nevents: %s",
   756  				ts.Restarts, pretty.Sprint(ts.Events))
   757  		}
   758  		if ts.State != structs.TaskStateRunning {
   759  			return false, fmt.Errorf("expected running but received %s", ts.State)
   760  		}
   761  		return true, nil
   762  	}, func(err error) {
   763  		require.NoError(t, err)
   764  	})
   765  
   766  	// Assert the expected Restarting event was emitted
   767  	found := false
   768  	events := tr.TaskState().Events
   769  	for _, e := range events {
   770  		if e.Type == structs.TaskRestartSignal {
   771  			found = true
   772  			require.Equal(t, event.Time, e.Time)
   773  			require.Equal(t, event.RestartReason, e.RestartReason)
   774  			require.Contains(t, e.DisplayMessage, event.RestartReason)
   775  		}
   776  	}
   777  	require.True(t, found, "restarting task event not found", pretty.Sprint(events))
   778  }
   779  
   780  // TestTaskRunner_CheckWatcher_Restart asserts that when enabled an unhealthy
   781  // Consul check will cause a task to restart following restart policy rules.
   782  func TestTaskRunner_CheckWatcher_Restart(t *testing.T) {
   783  	t.Parallel()
   784  
   785  	alloc := mock.Alloc()
   786  
   787  	// Make the restart policy fail within this test
   788  	tg := alloc.Job.TaskGroups[0]
   789  	tg.RestartPolicy.Attempts = 2
   790  	tg.RestartPolicy.Interval = 1 * time.Minute
   791  	tg.RestartPolicy.Delay = 10 * time.Millisecond
   792  	tg.RestartPolicy.Mode = structs.RestartPolicyModeFail
   793  
   794  	task := tg.Tasks[0]
   795  	task.Driver = "mock_driver"
   796  	task.Config = map[string]interface{}{
   797  		"run_for": "10m",
   798  	}
   799  
   800  	// Make the task register a check that fails
   801  	task.Services[0].Checks[0] = &structs.ServiceCheck{
   802  		Name:     "test-restarts",
   803  		Type:     structs.ServiceCheckTCP,
   804  		Interval: 50 * time.Millisecond,
   805  		CheckRestart: &structs.CheckRestart{
   806  			Limit: 2,
   807  			Grace: 100 * time.Millisecond,
   808  		},
   809  	}
   810  
   811  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
   812  	defer cleanup()
   813  
   814  	// Replace mock Consul ServiceClient, with the real ServiceClient
   815  	// backed by a mock consul whose checks are always unhealthy.
   816  	consulAgent := agentconsul.NewMockAgent()
   817  	consulAgent.SetStatus("critical")
   818  	consulClient := agentconsul.NewServiceClient(consulAgent, conf.Logger, true)
   819  	go consulClient.Run()
   820  	defer consulClient.Shutdown()
   821  
   822  	conf.Consul = consulClient
   823  
   824  	tr, err := NewTaskRunner(conf)
   825  	require.NoError(t, err)
   826  
   827  	expectedEvents := []string{
   828  		"Received",
   829  		"Task Setup",
   830  		"Started",
   831  		"Restart Signaled",
   832  		"Terminated",
   833  		"Restarting",
   834  		"Started",
   835  		"Restart Signaled",
   836  		"Terminated",
   837  		"Restarting",
   838  		"Started",
   839  		"Restart Signaled",
   840  		"Terminated",
   841  		"Not Restarting",
   842  	}
   843  
   844  	// Bump maxEvents so task events aren't dropped
   845  	tr.maxEvents = 100
   846  
   847  	go tr.Run()
   848  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
   849  
   850  	// Wait until the task exits. Don't simply wait for it to run as it may
   851  	// get restarted and terminated before the test is able to observe it
   852  	// running.
   853  	select {
   854  	case <-tr.WaitCh():
   855  	case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
   856  		require.Fail(t, "timeout")
   857  	}
   858  
   859  	state := tr.TaskState()
   860  	actualEvents := make([]string, len(state.Events))
   861  	for i, e := range state.Events {
   862  		actualEvents[i] = string(e.Type)
   863  	}
   864  	require.Equal(t, actualEvents, expectedEvents)
   865  
   866  	require.Equal(t, structs.TaskStateDead, state.State)
   867  	require.True(t, state.Failed, pretty.Sprint(state))
   868  }
   869  
   870  // TestTaskRunner_BlockForVault asserts tasks do not start until a vault token
   871  // is derived.
   872  func TestTaskRunner_BlockForVault(t *testing.T) {
   873  	t.Parallel()
   874  
   875  	alloc := mock.BatchAlloc()
   876  	task := alloc.Job.TaskGroups[0].Tasks[0]
   877  	task.Config = map[string]interface{}{
   878  		"run_for": "0s",
   879  	}
   880  	task.Vault = &structs.Vault{Policies: []string{"default"}}
   881  
   882  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
   883  	defer cleanup()
   884  
   885  	// Control when we get a Vault token
   886  	token := "1234"
   887  	waitCh := make(chan struct{})
   888  	handler := func(*structs.Allocation, []string) (map[string]string, error) {
   889  		<-waitCh
   890  		return map[string]string{task.Name: token}, nil
   891  	}
   892  	vaultClient := conf.Vault.(*vaultclient.MockVaultClient)
   893  	vaultClient.DeriveTokenFn = handler
   894  
   895  	tr, err := NewTaskRunner(conf)
   896  	require.NoError(t, err)
   897  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
   898  	go tr.Run()
   899  
   900  	// Assert TR blocks on vault token (does *not* exit)
   901  	select {
   902  	case <-tr.WaitCh():
   903  		require.Fail(t, "tr exited before vault unblocked")
   904  	case <-time.After(1 * time.Second):
   905  	}
   906  
   907  	// Assert task state is still Pending
   908  	require.Equal(t, structs.TaskStatePending, tr.TaskState().State)
   909  
   910  	// Unblock vault token
   911  	close(waitCh)
   912  
   913  	// TR should exit now that it's unblocked by vault as its a batch job
   914  	// with 0 sleeping.
   915  	select {
   916  	case <-tr.WaitCh():
   917  	case <-time.After(15 * time.Second * time.Duration(testutil.TestMultiplier())):
   918  		require.Fail(t, "timed out waiting for batch task to exit")
   919  	}
   920  
   921  	// Assert task exited successfully
   922  	finalState := tr.TaskState()
   923  	require.Equal(t, structs.TaskStateDead, finalState.State)
   924  	require.False(t, finalState.Failed)
   925  
   926  	// Check that the token is on disk
   927  	tokenPath := filepath.Join(conf.TaskDir.SecretsDir, vaultTokenFile)
   928  	data, err := ioutil.ReadFile(tokenPath)
   929  	require.NoError(t, err)
   930  	require.Equal(t, token, string(data))
   931  
   932  	// Check the token was revoked
   933  	testutil.WaitForResult(func() (bool, error) {
   934  		if len(vaultClient.StoppedTokens()) != 1 {
   935  			return false, fmt.Errorf("Expected a stopped token %q but found: %v", token, vaultClient.StoppedTokens())
   936  		}
   937  
   938  		if a := vaultClient.StoppedTokens()[0]; a != token {
   939  			return false, fmt.Errorf("got stopped token %q; want %q", a, token)
   940  		}
   941  		return true, nil
   942  	}, func(err error) {
   943  		require.Fail(t, err.Error())
   944  	})
   945  }
   946  
   947  // TestTaskRunner_DeriveToken_Retry asserts that if a recoverable error is
   948  // returned when deriving a vault token a task will continue to block while
   949  // it's retried.
   950  func TestTaskRunner_DeriveToken_Retry(t *testing.T) {
   951  	t.Parallel()
   952  	alloc := mock.BatchAlloc()
   953  	task := alloc.Job.TaskGroups[0].Tasks[0]
   954  	task.Vault = &structs.Vault{Policies: []string{"default"}}
   955  
   956  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
   957  	defer cleanup()
   958  
   959  	// Fail on the first attempt to derive a vault token
   960  	token := "1234"
   961  	count := 0
   962  	handler := func(*structs.Allocation, []string) (map[string]string, error) {
   963  		if count > 0 {
   964  			return map[string]string{task.Name: token}, nil
   965  		}
   966  
   967  		count++
   968  		return nil, structs.NewRecoverableError(fmt.Errorf("Want a retry"), true)
   969  	}
   970  	vaultClient := conf.Vault.(*vaultclient.MockVaultClient)
   971  	vaultClient.DeriveTokenFn = handler
   972  
   973  	tr, err := NewTaskRunner(conf)
   974  	require.NoError(t, err)
   975  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
   976  	go tr.Run()
   977  
   978  	// Wait for TR to exit and check its state
   979  	select {
   980  	case <-tr.WaitCh():
   981  	case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
   982  		require.Fail(t, "timed out waiting for task runner to exit")
   983  	}
   984  
   985  	state := tr.TaskState()
   986  	require.Equal(t, structs.TaskStateDead, state.State)
   987  	require.False(t, state.Failed)
   988  
   989  	require.Equal(t, 1, count)
   990  
   991  	// Check that the token is on disk
   992  	tokenPath := filepath.Join(conf.TaskDir.SecretsDir, vaultTokenFile)
   993  	data, err := ioutil.ReadFile(tokenPath)
   994  	require.NoError(t, err)
   995  	require.Equal(t, token, string(data))
   996  
   997  	// Check the token was revoked
   998  	testutil.WaitForResult(func() (bool, error) {
   999  		if len(vaultClient.StoppedTokens()) != 1 {
  1000  			return false, fmt.Errorf("Expected a stopped token: %v", vaultClient.StoppedTokens())
  1001  		}
  1002  
  1003  		if a := vaultClient.StoppedTokens()[0]; a != token {
  1004  			return false, fmt.Errorf("got stopped token %q; want %q", a, token)
  1005  		}
  1006  		return true, nil
  1007  	}, func(err error) {
  1008  		require.Fail(t, err.Error())
  1009  	})
  1010  }
  1011  
  1012  // TestTaskRunner_DeriveToken_Unrecoverable asserts that an unrecoverable error
  1013  // from deriving a vault token will fail a task.
  1014  func TestTaskRunner_DeriveToken_Unrecoverable(t *testing.T) {
  1015  	t.Parallel()
  1016  
  1017  	// Use a batch job with no restarts
  1018  	alloc := mock.BatchAlloc()
  1019  	tg := alloc.Job.TaskGroups[0]
  1020  	tg.RestartPolicy.Attempts = 0
  1021  	tg.RestartPolicy.Interval = 0
  1022  	tg.RestartPolicy.Delay = 0
  1023  	tg.RestartPolicy.Mode = structs.RestartPolicyModeFail
  1024  	task := tg.Tasks[0]
  1025  	task.Config = map[string]interface{}{
  1026  		"run_for": "0s",
  1027  	}
  1028  	task.Vault = &structs.Vault{Policies: []string{"default"}}
  1029  
  1030  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
  1031  	defer cleanup()
  1032  
  1033  	// Error the token derivation
  1034  	vaultClient := conf.Vault.(*vaultclient.MockVaultClient)
  1035  	vaultClient.SetDeriveTokenError(alloc.ID, []string{task.Name}, fmt.Errorf("Non recoverable"))
  1036  
  1037  	tr, err := NewTaskRunner(conf)
  1038  	require.NoError(t, err)
  1039  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
  1040  	go tr.Run()
  1041  
  1042  	// Wait for the task to die
  1043  	select {
  1044  	case <-tr.WaitCh():
  1045  	case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
  1046  		require.Fail(t, "timed out waiting for task runner to fail")
  1047  	}
  1048  
  1049  	// Task should be dead and last event should have failed task
  1050  	state := tr.TaskState()
  1051  	require.Equal(t, structs.TaskStateDead, state.State)
  1052  	require.True(t, state.Failed)
  1053  	require.Len(t, state.Events, 3)
  1054  	require.True(t, state.Events[2].FailsTask)
  1055  }
  1056  
  1057  // TestTaskRunner_Download_ChrootExec asserts that downloaded artifacts may be
  1058  // executed in a chroot.
  1059  func TestTaskRunner_Download_ChrootExec(t *testing.T) {
  1060  	t.Parallel()
  1061  	ctestutil.ExecCompatible(t)
  1062  
  1063  	ts := httptest.NewServer(http.FileServer(http.Dir(filepath.Dir("."))))
  1064  	defer ts.Close()
  1065  
  1066  	// Create a task that downloads a script and executes it.
  1067  	alloc := mock.BatchAlloc()
  1068  	alloc.Job.TaskGroups[0].RestartPolicy = &structs.RestartPolicy{}
  1069  	task := alloc.Job.TaskGroups[0].Tasks[0]
  1070  	task.Driver = "exec"
  1071  	task.Config = map[string]interface{}{
  1072  		"command": "noop.sh",
  1073  	}
  1074  	task.Artifacts = []*structs.TaskArtifact{
  1075  		{
  1076  			GetterSource: fmt.Sprintf("%s/testdata/noop.sh", ts.URL),
  1077  			GetterMode:   "file",
  1078  			RelativeDest: "noop.sh",
  1079  		},
  1080  	}
  1081  
  1082  	tr, _, cleanup := runTestTaskRunner(t, alloc, task.Name)
  1083  	defer cleanup()
  1084  
  1085  	// Wait for task to run and exit
  1086  	select {
  1087  	case <-tr.WaitCh():
  1088  	case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
  1089  		require.Fail(t, "timed out waiting for task runner to exit")
  1090  	}
  1091  
  1092  	state := tr.TaskState()
  1093  	require.Equal(t, structs.TaskStateDead, state.State)
  1094  	require.False(t, state.Failed)
  1095  }
  1096  
  1097  // TestTaskRunner_Download_Exec asserts that downloaded artifacts may be
  1098  // executed in a driver without filesystem isolation.
  1099  func TestTaskRunner_Download_RawExec(t *testing.T) {
  1100  	t.Parallel()
  1101  
  1102  	ts := httptest.NewServer(http.FileServer(http.Dir(filepath.Dir("."))))
  1103  	defer ts.Close()
  1104  
  1105  	// Create a task that downloads a script and executes it.
  1106  	alloc := mock.BatchAlloc()
  1107  	alloc.Job.TaskGroups[0].RestartPolicy = &structs.RestartPolicy{}
  1108  	task := alloc.Job.TaskGroups[0].Tasks[0]
  1109  	task.Driver = "raw_exec"
  1110  	task.Config = map[string]interface{}{
  1111  		"command": "noop.sh",
  1112  	}
  1113  	task.Artifacts = []*structs.TaskArtifact{
  1114  		{
  1115  			GetterSource: fmt.Sprintf("%s/testdata/noop.sh", ts.URL),
  1116  			GetterMode:   "file",
  1117  			RelativeDest: "noop.sh",
  1118  		},
  1119  	}
  1120  
  1121  	tr, _, cleanup := runTestTaskRunner(t, alloc, task.Name)
  1122  	defer cleanup()
  1123  
  1124  	// Wait for task to run and exit
  1125  	select {
  1126  	case <-tr.WaitCh():
  1127  	case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
  1128  		require.Fail(t, "timed out waiting for task runner to exit")
  1129  	}
  1130  
  1131  	state := tr.TaskState()
  1132  	require.Equal(t, structs.TaskStateDead, state.State)
  1133  	require.False(t, state.Failed)
  1134  }
  1135  
  1136  // TestTaskRunner_Download_List asserts that multiple artificats are downloaded
  1137  // before a task is run.
  1138  func TestTaskRunner_Download_List(t *testing.T) {
  1139  	t.Parallel()
  1140  	ts := httptest.NewServer(http.FileServer(http.Dir(filepath.Dir("."))))
  1141  	defer ts.Close()
  1142  
  1143  	// Create an allocation that has a task with a list of artifacts.
  1144  	alloc := mock.BatchAlloc()
  1145  	task := alloc.Job.TaskGroups[0].Tasks[0]
  1146  	f1 := "task_runner_test.go"
  1147  	f2 := "task_runner.go"
  1148  	artifact1 := structs.TaskArtifact{
  1149  		GetterSource: fmt.Sprintf("%s/%s", ts.URL, f1),
  1150  	}
  1151  	artifact2 := structs.TaskArtifact{
  1152  		GetterSource: fmt.Sprintf("%s/%s", ts.URL, f2),
  1153  	}
  1154  	task.Artifacts = []*structs.TaskArtifact{&artifact1, &artifact2}
  1155  
  1156  	tr, conf, cleanup := runTestTaskRunner(t, alloc, task.Name)
  1157  	defer cleanup()
  1158  
  1159  	// Wait for task to run and exit
  1160  	select {
  1161  	case <-tr.WaitCh():
  1162  	case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
  1163  		require.Fail(t, "timed out waiting for task runner to exit")
  1164  	}
  1165  
  1166  	state := tr.TaskState()
  1167  	require.Equal(t, structs.TaskStateDead, state.State)
  1168  	require.False(t, state.Failed)
  1169  
  1170  	require.Len(t, state.Events, 5)
  1171  	assert.Equal(t, structs.TaskReceived, state.Events[0].Type)
  1172  	assert.Equal(t, structs.TaskSetup, state.Events[1].Type)
  1173  	assert.Equal(t, structs.TaskDownloadingArtifacts, state.Events[2].Type)
  1174  	assert.Equal(t, structs.TaskStarted, state.Events[3].Type)
  1175  	assert.Equal(t, structs.TaskTerminated, state.Events[4].Type)
  1176  
  1177  	// Check that both files exist.
  1178  	_, err := os.Stat(filepath.Join(conf.TaskDir.Dir, f1))
  1179  	require.NoErrorf(t, err, "%v not downloaded", f1)
  1180  
  1181  	_, err = os.Stat(filepath.Join(conf.TaskDir.Dir, f2))
  1182  	require.NoErrorf(t, err, "%v not downloaded", f2)
  1183  }
  1184  
  1185  // TestTaskRunner_Download_Retries asserts that failed artifact downloads are
  1186  // retried according to the task's restart policy.
  1187  func TestTaskRunner_Download_Retries(t *testing.T) {
  1188  	t.Parallel()
  1189  
  1190  	// Create an allocation that has a task with bad artifacts.
  1191  	alloc := mock.BatchAlloc()
  1192  	task := alloc.Job.TaskGroups[0].Tasks[0]
  1193  	artifact := structs.TaskArtifact{
  1194  		GetterSource: "http://127.0.0.1:0/foo/bar/baz",
  1195  	}
  1196  	task.Artifacts = []*structs.TaskArtifact{&artifact}
  1197  
  1198  	// Make the restart policy retry once
  1199  	alloc.Job.TaskGroups[0].RestartPolicy = &structs.RestartPolicy{
  1200  		Attempts: 1,
  1201  		Interval: 10 * time.Minute,
  1202  		Delay:    1 * time.Second,
  1203  		Mode:     structs.RestartPolicyModeFail,
  1204  	}
  1205  
  1206  	tr, _, cleanup := runTestTaskRunner(t, alloc, task.Name)
  1207  	defer cleanup()
  1208  
  1209  	select {
  1210  	case <-tr.WaitCh():
  1211  	case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
  1212  		require.Fail(t, "timed out waiting for task to exit")
  1213  	}
  1214  
  1215  	state := tr.TaskState()
  1216  	require.Equal(t, structs.TaskStateDead, state.State)
  1217  	require.True(t, state.Failed)
  1218  	require.Len(t, state.Events, 8, pretty.Sprint(state.Events))
  1219  	require.Equal(t, structs.TaskReceived, state.Events[0].Type)
  1220  	require.Equal(t, structs.TaskSetup, state.Events[1].Type)
  1221  	require.Equal(t, structs.TaskDownloadingArtifacts, state.Events[2].Type)
  1222  	require.Equal(t, structs.TaskArtifactDownloadFailed, state.Events[3].Type)
  1223  	require.Equal(t, structs.TaskRestarting, state.Events[4].Type)
  1224  	require.Equal(t, structs.TaskDownloadingArtifacts, state.Events[5].Type)
  1225  	require.Equal(t, structs.TaskArtifactDownloadFailed, state.Events[6].Type)
  1226  	require.Equal(t, structs.TaskNotRestarting, state.Events[7].Type)
  1227  }
  1228  
  1229  // TestTaskRunner_DriverNetwork asserts that a driver's network is properly
  1230  // used in services and checks.
  1231  func TestTaskRunner_DriverNetwork(t *testing.T) {
  1232  	t.Parallel()
  1233  
  1234  	alloc := mock.Alloc()
  1235  	task := alloc.Job.TaskGroups[0].Tasks[0]
  1236  	task.Driver = "mock_driver"
  1237  	task.Config = map[string]interface{}{
  1238  		"run_for":         "100s",
  1239  		"driver_ip":       "10.1.2.3",
  1240  		"driver_port_map": "http:80",
  1241  	}
  1242  
  1243  	// Create services and checks with custom address modes to exercise
  1244  	// address detection logic
  1245  	task.Services = []*structs.Service{
  1246  		{
  1247  			Name:        "host-service",
  1248  			PortLabel:   "http",
  1249  			AddressMode: "host",
  1250  			Checks: []*structs.ServiceCheck{
  1251  				{
  1252  					Name:        "driver-check",
  1253  					Type:        "tcp",
  1254  					PortLabel:   "1234",
  1255  					AddressMode: "driver",
  1256  				},
  1257  			},
  1258  		},
  1259  		{
  1260  			Name:        "driver-service",
  1261  			PortLabel:   "5678",
  1262  			AddressMode: "driver",
  1263  			Checks: []*structs.ServiceCheck{
  1264  				{
  1265  					Name:      "host-check",
  1266  					Type:      "tcp",
  1267  					PortLabel: "http",
  1268  				},
  1269  				{
  1270  					Name:        "driver-label-check",
  1271  					Type:        "tcp",
  1272  					PortLabel:   "http",
  1273  					AddressMode: "driver",
  1274  				},
  1275  			},
  1276  		},
  1277  	}
  1278  
  1279  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
  1280  	defer cleanup()
  1281  
  1282  	// Use a mock agent to test for services
  1283  	consulAgent := agentconsul.NewMockAgent()
  1284  	consulClient := agentconsul.NewServiceClient(consulAgent, conf.Logger, true)
  1285  	defer consulClient.Shutdown()
  1286  	go consulClient.Run()
  1287  
  1288  	conf.Consul = consulClient
  1289  
  1290  	tr, err := NewTaskRunner(conf)
  1291  	require.NoError(t, err)
  1292  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
  1293  	go tr.Run()
  1294  
  1295  	// Wait for the task to start
  1296  	testWaitForTaskToStart(t, tr)
  1297  
  1298  	testutil.WaitForResult(func() (bool, error) {
  1299  		services, _ := consulAgent.Services()
  1300  		if n := len(services); n != 2 {
  1301  			return false, fmt.Errorf("expected 2 services, but found %d", n)
  1302  		}
  1303  		for _, s := range services {
  1304  			switch s.Service {
  1305  			case "host-service":
  1306  				if expected := "192.168.0.100"; s.Address != expected {
  1307  					return false, fmt.Errorf("expected host-service to have IP=%s but found %s",
  1308  						expected, s.Address)
  1309  				}
  1310  			case "driver-service":
  1311  				if expected := "10.1.2.3"; s.Address != expected {
  1312  					return false, fmt.Errorf("expected driver-service to have IP=%s but found %s",
  1313  						expected, s.Address)
  1314  				}
  1315  				if expected := 5678; s.Port != expected {
  1316  					return false, fmt.Errorf("expected driver-service to have port=%d but found %d",
  1317  						expected, s.Port)
  1318  				}
  1319  			default:
  1320  				return false, fmt.Errorf("unexpected service: %q", s.Service)
  1321  			}
  1322  
  1323  		}
  1324  
  1325  		checks := consulAgent.CheckRegs()
  1326  		if n := len(checks); n != 3 {
  1327  			return false, fmt.Errorf("expected 3 checks, but found %d", n)
  1328  		}
  1329  		for _, check := range checks {
  1330  			switch check.Name {
  1331  			case "driver-check":
  1332  				if expected := "10.1.2.3:1234"; check.TCP != expected {
  1333  					return false, fmt.Errorf("expected driver-check to have address %q but found %q", expected, check.TCP)
  1334  				}
  1335  			case "driver-label-check":
  1336  				if expected := "10.1.2.3:80"; check.TCP != expected {
  1337  					return false, fmt.Errorf("expected driver-label-check to have address %q but found %q", expected, check.TCP)
  1338  				}
  1339  			case "host-check":
  1340  				if expected := "192.168.0.100:"; !strings.HasPrefix(check.TCP, expected) {
  1341  					return false, fmt.Errorf("expected host-check to have address start with %q but found %q", expected, check.TCP)
  1342  				}
  1343  			default:
  1344  				return false, fmt.Errorf("unexpected check: %q", check.Name)
  1345  			}
  1346  		}
  1347  
  1348  		return true, nil
  1349  	}, func(err error) {
  1350  		services, _ := consulAgent.Services()
  1351  		for _, s := range services {
  1352  			t.Logf(pretty.Sprint("Service: ", s))
  1353  		}
  1354  		for _, c := range consulAgent.CheckRegs() {
  1355  			t.Logf(pretty.Sprint("Check:   ", c))
  1356  		}
  1357  		require.NoError(t, err)
  1358  	})
  1359  }
  1360  
  1361  // TestTaskRunner_RestartSignalTask_NotRunning asserts resilience to failures
  1362  // when a restart or signal is triggered and the task is not running.
  1363  func TestTaskRunner_RestartSignalTask_NotRunning(t *testing.T) {
  1364  	t.Parallel()
  1365  
  1366  	alloc := mock.BatchAlloc()
  1367  	task := alloc.Job.TaskGroups[0].Tasks[0]
  1368  	task.Driver = "mock_driver"
  1369  	task.Config = map[string]interface{}{
  1370  		"run_for": "0s",
  1371  	}
  1372  
  1373  	// Use vault to block the start
  1374  	task.Vault = &structs.Vault{Policies: []string{"default"}}
  1375  
  1376  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
  1377  	defer cleanup()
  1378  
  1379  	// Control when we get a Vault token
  1380  	waitCh := make(chan struct{}, 1)
  1381  	defer close(waitCh)
  1382  	handler := func(*structs.Allocation, []string) (map[string]string, error) {
  1383  		<-waitCh
  1384  		return map[string]string{task.Name: "1234"}, nil
  1385  	}
  1386  	vaultClient := conf.Vault.(*vaultclient.MockVaultClient)
  1387  	vaultClient.DeriveTokenFn = handler
  1388  
  1389  	tr, err := NewTaskRunner(conf)
  1390  	require.NoError(t, err)
  1391  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
  1392  	go tr.Run()
  1393  
  1394  	select {
  1395  	case <-tr.WaitCh():
  1396  		require.Fail(t, "unexpected exit")
  1397  	case <-time.After(1 * time.Second):
  1398  	}
  1399  
  1400  	// Send a signal and restart
  1401  	err = tr.Signal(structs.NewTaskEvent("don't panic"), "QUIT")
  1402  	require.EqualError(t, err, ErrTaskNotRunning.Error())
  1403  
  1404  	// Send a restart
  1405  	err = tr.Restart(context.Background(), structs.NewTaskEvent("don't panic"), false)
  1406  	require.EqualError(t, err, ErrTaskNotRunning.Error())
  1407  
  1408  	// Unblock and let it finish
  1409  	waitCh <- struct{}{}
  1410  
  1411  	select {
  1412  	case <-tr.WaitCh():
  1413  	case <-time.After(10 * time.Second):
  1414  		require.Fail(t, "timed out waiting for task to complete")
  1415  	}
  1416  
  1417  	// Assert the task ran and never restarted
  1418  	state := tr.TaskState()
  1419  	require.Equal(t, structs.TaskStateDead, state.State)
  1420  	require.False(t, state.Failed)
  1421  	require.Len(t, state.Events, 4, pretty.Sprint(state.Events))
  1422  	require.Equal(t, structs.TaskReceived, state.Events[0].Type)
  1423  	require.Equal(t, structs.TaskSetup, state.Events[1].Type)
  1424  	require.Equal(t, structs.TaskStarted, state.Events[2].Type)
  1425  	require.Equal(t, structs.TaskTerminated, state.Events[3].Type)
  1426  }
  1427  
  1428  // TestTaskRunner_Run_RecoverableStartError asserts tasks are restarted if they
  1429  // return a recoverable error from StartTask.
  1430  func TestTaskRunner_Run_RecoverableStartError(t *testing.T) {
  1431  	t.Parallel()
  1432  
  1433  	alloc := mock.BatchAlloc()
  1434  	task := alloc.Job.TaskGroups[0].Tasks[0]
  1435  	task.Config = map[string]interface{}{
  1436  		"start_error":             "driver failure",
  1437  		"start_error_recoverable": true,
  1438  	}
  1439  
  1440  	// Make the restart policy retry once
  1441  	alloc.Job.TaskGroups[0].RestartPolicy = &structs.RestartPolicy{
  1442  		Attempts: 1,
  1443  		Interval: 10 * time.Minute,
  1444  		Delay:    0,
  1445  		Mode:     structs.RestartPolicyModeFail,
  1446  	}
  1447  
  1448  	tr, _, cleanup := runTestTaskRunner(t, alloc, task.Name)
  1449  	defer cleanup()
  1450  
  1451  	select {
  1452  	case <-tr.WaitCh():
  1453  	case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
  1454  		require.Fail(t, "timed out waiting for task to exit")
  1455  	}
  1456  
  1457  	state := tr.TaskState()
  1458  	require.Equal(t, structs.TaskStateDead, state.State)
  1459  	require.True(t, state.Failed)
  1460  	require.Len(t, state.Events, 6, pretty.Sprint(state.Events))
  1461  	require.Equal(t, structs.TaskReceived, state.Events[0].Type)
  1462  	require.Equal(t, structs.TaskSetup, state.Events[1].Type)
  1463  	require.Equal(t, structs.TaskDriverFailure, state.Events[2].Type)
  1464  	require.Equal(t, structs.TaskRestarting, state.Events[3].Type)
  1465  	require.Equal(t, structs.TaskDriverFailure, state.Events[4].Type)
  1466  	require.Equal(t, structs.TaskNotRestarting, state.Events[5].Type)
  1467  }
  1468  
  1469  // TestTaskRunner_Template_Artifact asserts that tasks can use artifacts as templates.
  1470  func TestTaskRunner_Template_Artifact(t *testing.T) {
  1471  	t.Parallel()
  1472  
  1473  	ts := httptest.NewServer(http.FileServer(http.Dir(".")))
  1474  	defer ts.Close()
  1475  
  1476  	alloc := mock.BatchAlloc()
  1477  	task := alloc.Job.TaskGroups[0].Tasks[0]
  1478  	f1 := "task_runner.go"
  1479  	f2 := "test"
  1480  	task.Artifacts = []*structs.TaskArtifact{
  1481  		{GetterSource: fmt.Sprintf("%s/%s", ts.URL, f1)},
  1482  	}
  1483  	task.Templates = []*structs.Template{
  1484  		{
  1485  			SourcePath: f1,
  1486  			DestPath:   "local/test",
  1487  			ChangeMode: structs.TemplateChangeModeNoop,
  1488  		},
  1489  	}
  1490  
  1491  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
  1492  	defer cleanup()
  1493  
  1494  	tr, err := NewTaskRunner(conf)
  1495  	require.NoError(t, err)
  1496  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
  1497  	go tr.Run()
  1498  
  1499  	// Wait for task to run and exit
  1500  	select {
  1501  	case <-tr.WaitCh():
  1502  	case <-time.After(15 * time.Second * time.Duration(testutil.TestMultiplier())):
  1503  		require.Fail(t, "timed out waiting for task runner to exit")
  1504  	}
  1505  
  1506  	state := tr.TaskState()
  1507  	require.Equal(t, structs.TaskStateDead, state.State)
  1508  	require.True(t, state.Successful())
  1509  	require.False(t, state.Failed)
  1510  
  1511  	artifactsDownloaded := false
  1512  	for _, e := range state.Events {
  1513  		if e.Type == structs.TaskDownloadingArtifacts {
  1514  			artifactsDownloaded = true
  1515  		}
  1516  	}
  1517  	assert.True(t, artifactsDownloaded, "expected artifacts downloaded events")
  1518  
  1519  	// Check that both files exist.
  1520  	_, err = os.Stat(filepath.Join(conf.TaskDir.Dir, f1))
  1521  	require.NoErrorf(t, err, "%v not downloaded", f1)
  1522  
  1523  	_, err = os.Stat(filepath.Join(conf.TaskDir.LocalDir, f2))
  1524  	require.NoErrorf(t, err, "%v not rendered", f2)
  1525  }
  1526  
  1527  // TestTaskRunner_Template_NewVaultToken asserts that a new vault token is
  1528  // created when rendering template and that it is revoked on alloc completion
  1529  func TestTaskRunner_Template_NewVaultToken(t *testing.T) {
  1530  	t.Parallel()
  1531  
  1532  	alloc := mock.BatchAlloc()
  1533  	task := alloc.Job.TaskGroups[0].Tasks[0]
  1534  	task.Templates = []*structs.Template{
  1535  		{
  1536  			EmbeddedTmpl: `{{key "foo"}}`,
  1537  			DestPath:     "local/test",
  1538  			ChangeMode:   structs.TemplateChangeModeNoop,
  1539  		},
  1540  	}
  1541  	task.Vault = &structs.Vault{Policies: []string{"default"}}
  1542  
  1543  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
  1544  	defer cleanup()
  1545  
  1546  	tr, err := NewTaskRunner(conf)
  1547  	require.NoError(t, err)
  1548  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
  1549  	go tr.Run()
  1550  
  1551  	// Wait for a Vault token
  1552  	var token string
  1553  	testutil.WaitForResult(func() (bool, error) {
  1554  		token = tr.getVaultToken()
  1555  
  1556  		if token == "" {
  1557  			return false, fmt.Errorf("No Vault token")
  1558  		}
  1559  
  1560  		return true, nil
  1561  	}, func(err error) {
  1562  		require.NoError(t, err)
  1563  	})
  1564  
  1565  	vault := conf.Vault.(*vaultclient.MockVaultClient)
  1566  	renewalCh, ok := vault.RenewTokens()[token]
  1567  	require.True(t, ok, "no renewal channel for token")
  1568  
  1569  	renewalCh <- fmt.Errorf("Test killing")
  1570  	close(renewalCh)
  1571  
  1572  	var token2 string
  1573  	testutil.WaitForResult(func() (bool, error) {
  1574  		token2 = tr.getVaultToken()
  1575  
  1576  		if token2 == "" {
  1577  			return false, fmt.Errorf("No Vault token")
  1578  		}
  1579  
  1580  		if token2 == token {
  1581  			return false, fmt.Errorf("token wasn't recreated")
  1582  		}
  1583  
  1584  		return true, nil
  1585  	}, func(err error) {
  1586  		require.NoError(t, err)
  1587  	})
  1588  
  1589  	// Check the token was revoked
  1590  	testutil.WaitForResult(func() (bool, error) {
  1591  		if len(vault.StoppedTokens()) != 1 {
  1592  			return false, fmt.Errorf("Expected a stopped token: %v", vault.StoppedTokens())
  1593  		}
  1594  
  1595  		if a := vault.StoppedTokens()[0]; a != token {
  1596  			return false, fmt.Errorf("got stopped token %q; want %q", a, token)
  1597  		}
  1598  
  1599  		return true, nil
  1600  	}, func(err error) {
  1601  		require.NoError(t, err)
  1602  	})
  1603  
  1604  }
  1605  
  1606  // TestTaskRunner_VaultManager_Restart asserts that the alloc is restarted when the alloc
  1607  // derived vault token expires, when task is configured with Restart change mode
  1608  func TestTaskRunner_VaultManager_Restart(t *testing.T) {
  1609  	t.Parallel()
  1610  
  1611  	alloc := mock.BatchAlloc()
  1612  	task := alloc.Job.TaskGroups[0].Tasks[0]
  1613  	task.Config = map[string]interface{}{
  1614  		"run_for": "10s",
  1615  	}
  1616  	task.Vault = &structs.Vault{
  1617  		Policies:   []string{"default"},
  1618  		ChangeMode: structs.VaultChangeModeRestart,
  1619  	}
  1620  
  1621  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
  1622  	defer cleanup()
  1623  
  1624  	tr, err := NewTaskRunner(conf)
  1625  	require.NoError(t, err)
  1626  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
  1627  	go tr.Run()
  1628  
  1629  	testWaitForTaskToStart(t, tr)
  1630  
  1631  	tr.vaultTokenLock.Lock()
  1632  	token := tr.vaultToken
  1633  	tr.vaultTokenLock.Unlock()
  1634  
  1635  	require.NotEmpty(t, token)
  1636  
  1637  	vault := conf.Vault.(*vaultclient.MockVaultClient)
  1638  	renewalCh, ok := vault.RenewTokens()[token]
  1639  	require.True(t, ok, "no renewal channel for token")
  1640  
  1641  	renewalCh <- fmt.Errorf("Test killing")
  1642  	close(renewalCh)
  1643  
  1644  	testutil.WaitForResult(func() (bool, error) {
  1645  		state := tr.TaskState()
  1646  
  1647  		if len(state.Events) == 0 {
  1648  			return false, fmt.Errorf("no events yet")
  1649  		}
  1650  
  1651  		foundRestartSignal, foundRestarting := false, false
  1652  		for _, e := range state.Events {
  1653  			switch e.Type {
  1654  			case structs.TaskRestartSignal:
  1655  				foundRestartSignal = true
  1656  			case structs.TaskRestarting:
  1657  				foundRestarting = true
  1658  			}
  1659  		}
  1660  
  1661  		if !foundRestartSignal {
  1662  			return false, fmt.Errorf("no restart signal event yet: %#v", state.Events)
  1663  		}
  1664  
  1665  		if !foundRestarting {
  1666  			return false, fmt.Errorf("no restarting event yet: %#v", state.Events)
  1667  		}
  1668  
  1669  		lastEvent := state.Events[len(state.Events)-1]
  1670  		if lastEvent.Type != structs.TaskStarted {
  1671  			return false, fmt.Errorf("expected last event to be task starting but was %#v", lastEvent)
  1672  		}
  1673  		return true, nil
  1674  	}, func(err error) {
  1675  		require.NoError(t, err)
  1676  	})
  1677  }
  1678  
  1679  // TestTaskRunner_VaultManager_Signal asserts that the alloc is signalled when the alloc
  1680  // derived vault token expires, when task is configured with signal change mode
  1681  func TestTaskRunner_VaultManager_Signal(t *testing.T) {
  1682  	t.Parallel()
  1683  
  1684  	alloc := mock.BatchAlloc()
  1685  	task := alloc.Job.TaskGroups[0].Tasks[0]
  1686  	task.Config = map[string]interface{}{
  1687  		"run_for": "10s",
  1688  	}
  1689  	task.Vault = &structs.Vault{
  1690  		Policies:     []string{"default"},
  1691  		ChangeMode:   structs.VaultChangeModeSignal,
  1692  		ChangeSignal: "SIGUSR1",
  1693  	}
  1694  
  1695  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
  1696  	defer cleanup()
  1697  
  1698  	tr, err := NewTaskRunner(conf)
  1699  	require.NoError(t, err)
  1700  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
  1701  	go tr.Run()
  1702  
  1703  	testWaitForTaskToStart(t, tr)
  1704  
  1705  	tr.vaultTokenLock.Lock()
  1706  	token := tr.vaultToken
  1707  	tr.vaultTokenLock.Unlock()
  1708  
  1709  	require.NotEmpty(t, token)
  1710  
  1711  	vault := conf.Vault.(*vaultclient.MockVaultClient)
  1712  	renewalCh, ok := vault.RenewTokens()[token]
  1713  	require.True(t, ok, "no renewal channel for token")
  1714  
  1715  	renewalCh <- fmt.Errorf("Test killing")
  1716  	close(renewalCh)
  1717  
  1718  	testutil.WaitForResult(func() (bool, error) {
  1719  		state := tr.TaskState()
  1720  
  1721  		if len(state.Events) == 0 {
  1722  			return false, fmt.Errorf("no events yet")
  1723  		}
  1724  
  1725  		foundSignaling := false
  1726  		for _, e := range state.Events {
  1727  			if e.Type == structs.TaskSignaling {
  1728  				foundSignaling = true
  1729  			}
  1730  		}
  1731  
  1732  		if !foundSignaling {
  1733  			return false, fmt.Errorf("no signaling event yet: %#v", state.Events)
  1734  		}
  1735  
  1736  		return true, nil
  1737  	}, func(err error) {
  1738  		require.NoError(t, err)
  1739  	})
  1740  
  1741  }
  1742  
  1743  // TestTaskRunner_UnregisterConsul_Retries asserts a task is unregistered from
  1744  // Consul when waiting to be retried.
  1745  func TestTaskRunner_UnregisterConsul_Retries(t *testing.T) {
  1746  	t.Parallel()
  1747  
  1748  	alloc := mock.Alloc()
  1749  	// Make the restart policy try one ctx.update
  1750  	alloc.Job.TaskGroups[0].RestartPolicy = &structs.RestartPolicy{
  1751  		Attempts: 1,
  1752  		Interval: 10 * time.Minute,
  1753  		Delay:    time.Nanosecond,
  1754  		Mode:     structs.RestartPolicyModeFail,
  1755  	}
  1756  	task := alloc.Job.TaskGroups[0].Tasks[0]
  1757  	task.Driver = "mock_driver"
  1758  	task.Config = map[string]interface{}{
  1759  		"exit_code": "1",
  1760  		"run_for":   "1ns",
  1761  	}
  1762  
  1763  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
  1764  	defer cleanup()
  1765  
  1766  	tr, err := NewTaskRunner(conf)
  1767  	require.NoError(t, err)
  1768  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
  1769  	tr.Run()
  1770  
  1771  	state := tr.TaskState()
  1772  	require.Equal(t, structs.TaskStateDead, state.State)
  1773  
  1774  	consul := conf.Consul.(*consulapi.MockConsulServiceClient)
  1775  	consulOps := consul.GetOps()
  1776  	require.Len(t, consulOps, 6)
  1777  
  1778  	// Initial add
  1779  	require.Equal(t, "add", consulOps[0].Op)
  1780  
  1781  	// Removing canary and non-canary entries on first exit
  1782  	require.Equal(t, "remove", consulOps[1].Op)
  1783  	require.Equal(t, "remove", consulOps[2].Op)
  1784  
  1785  	// Second add on retry
  1786  	require.Equal(t, "add", consulOps[3].Op)
  1787  
  1788  	// Removing canary and non-canary entries on retry
  1789  	require.Equal(t, "remove", consulOps[4].Op)
  1790  	require.Equal(t, "remove", consulOps[5].Op)
  1791  }
  1792  
  1793  // testWaitForTaskToStart waits for the task to be running or fails the test
  1794  func testWaitForTaskToStart(t *testing.T, tr *TaskRunner) {
  1795  	testutil.WaitForResult(func() (bool, error) {
  1796  		ts := tr.TaskState()
  1797  		return ts.State == structs.TaskStateRunning, fmt.Errorf("%v", ts.State)
  1798  	}, func(err error) {
  1799  		require.NoError(t, err)
  1800  	})
  1801  }