github.com/anth0d/nomad@v0.0.0-20221214183521-ae3a0a2cad06/client/allocrunner/taskrunner/task_runner_test.go (about)

     1  package taskrunner
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io/ioutil"
     8  	"net/http"
     9  	"net/http/httptest"
    10  	"os"
    11  	"path/filepath"
    12  	"strings"
    13  	"testing"
    14  	"time"
    15  
    16  	"github.com/golang/snappy"
    17  	"github.com/hashicorp/nomad/ci"
    18  	"github.com/hashicorp/nomad/client/allocdir"
    19  	"github.com/hashicorp/nomad/client/allocrunner/interfaces"
    20  	"github.com/hashicorp/nomad/client/allocrunner/taskrunner/getter"
    21  	"github.com/hashicorp/nomad/client/config"
    22  	consulapi "github.com/hashicorp/nomad/client/consul"
    23  	"github.com/hashicorp/nomad/client/devicemanager"
    24  	"github.com/hashicorp/nomad/client/lib/cgutil"
    25  	"github.com/hashicorp/nomad/client/pluginmanager/drivermanager"
    26  	regMock "github.com/hashicorp/nomad/client/serviceregistration/mock"
    27  	"github.com/hashicorp/nomad/client/serviceregistration/wrapper"
    28  	cstate "github.com/hashicorp/nomad/client/state"
    29  	ctestutil "github.com/hashicorp/nomad/client/testutil"
    30  	"github.com/hashicorp/nomad/client/vaultclient"
    31  	agentconsul "github.com/hashicorp/nomad/command/agent/consul"
    32  	mockdriver "github.com/hashicorp/nomad/drivers/mock"
    33  	"github.com/hashicorp/nomad/drivers/rawexec"
    34  	"github.com/hashicorp/nomad/helper/pointer"
    35  	"github.com/hashicorp/nomad/helper/testlog"
    36  	"github.com/hashicorp/nomad/helper/uuid"
    37  	"github.com/hashicorp/nomad/nomad/mock"
    38  	"github.com/hashicorp/nomad/nomad/structs"
    39  	"github.com/hashicorp/nomad/plugins/device"
    40  	"github.com/hashicorp/nomad/plugins/drivers"
    41  	"github.com/hashicorp/nomad/testutil"
    42  	"github.com/kr/pretty"
    43  	"github.com/stretchr/testify/assert"
    44  	"github.com/stretchr/testify/require"
    45  )
    46  
    47  type MockTaskStateUpdater struct {
    48  	ch chan struct{}
    49  }
    50  
    51  func NewMockTaskStateUpdater() *MockTaskStateUpdater {
    52  	return &MockTaskStateUpdater{
    53  		ch: make(chan struct{}, 1),
    54  	}
    55  }
    56  
    57  func (m *MockTaskStateUpdater) TaskStateUpdated() {
    58  	select {
    59  	case m.ch <- struct{}{}:
    60  	default:
    61  	}
    62  }
    63  
    64  // testTaskRunnerConfig returns a taskrunner.Config for the given alloc+task
    65  // plus a cleanup func.
    66  func testTaskRunnerConfig(t *testing.T, alloc *structs.Allocation, taskName string) (*Config, func()) {
    67  	logger := testlog.HCLogger(t)
    68  	clientConf, cleanup := config.TestClientConfig(t)
    69  
    70  	// Find the task
    71  	var thisTask *structs.Task
    72  	for _, tg := range alloc.Job.TaskGroups {
    73  		for _, task := range tg.Tasks {
    74  			if task.Name == taskName {
    75  				if thisTask != nil {
    76  					cleanup()
    77  					t.Fatalf("multiple tasks named %q; cannot use this helper", taskName)
    78  				}
    79  				thisTask = task
    80  			}
    81  		}
    82  	}
    83  	if thisTask == nil {
    84  		cleanup()
    85  		t.Fatalf("could not find task %q", taskName)
    86  	}
    87  
    88  	// Create the alloc dir + task dir
    89  	allocDir := allocdir.NewAllocDir(logger, clientConf.AllocDir, alloc.ID)
    90  	if err := allocDir.Build(); err != nil {
    91  		cleanup()
    92  		t.Fatalf("error building alloc dir: %v", err)
    93  	}
    94  	taskDir := allocDir.NewTaskDir(taskName)
    95  
    96  	// Compute the name of the v2 cgroup in case we need it in creation, configuration, and cleanup
    97  	cgroup := filepath.Join(cgutil.CgroupRoot, "testing.slice", cgutil.CgroupScope(alloc.ID, taskName))
    98  
    99  	// Create the cgroup if we are in v2 mode
   100  	if cgutil.UseV2 {
   101  		if err := os.MkdirAll(cgroup, 0755); err != nil {
   102  			t.Fatalf("failed to setup v2 cgroup for test: %v:", err)
   103  		}
   104  	}
   105  
   106  	trCleanup := func() {
   107  		if err := allocDir.Destroy(); err != nil {
   108  			t.Logf("error destroying alloc dir: %v", err)
   109  		}
   110  
   111  		// Cleanup the cgroup if we are in v2 mode
   112  		if cgutil.UseV2 {
   113  			_ = os.RemoveAll(cgroup)
   114  		}
   115  
   116  		cleanup()
   117  	}
   118  
   119  	shutdownDelayCtx, shutdownDelayCancelFn := context.WithCancel(context.Background())
   120  
   121  	// Create a closed channel to mock TaskCoordinator.startConditionForTask.
   122  	// Closed channel indicates this task is not blocked on prestart hooks.
   123  	closedCh := make(chan struct{})
   124  	close(closedCh)
   125  
   126  	// Set up the Nomad and Consul registration providers along with the wrapper.
   127  	consulRegMock := regMock.NewServiceRegistrationHandler(logger)
   128  	nomadRegMock := regMock.NewServiceRegistrationHandler(logger)
   129  	wrapperMock := wrapper.NewHandlerWrapper(logger, consulRegMock, nomadRegMock)
   130  
   131  	conf := &Config{
   132  		Alloc:                 alloc,
   133  		ClientConfig:          clientConf,
   134  		Task:                  thisTask,
   135  		TaskDir:               taskDir,
   136  		Logger:                clientConf.Logger,
   137  		Consul:                consulRegMock,
   138  		ConsulSI:              consulapi.NewMockServiceIdentitiesClient(),
   139  		Vault:                 vaultclient.NewMockVaultClient(),
   140  		StateDB:               cstate.NoopDB{},
   141  		StateUpdater:          NewMockTaskStateUpdater(),
   142  		DeviceManager:         devicemanager.NoopMockManager(),
   143  		DriverManager:         drivermanager.TestDriverManager(t),
   144  		ServersContactedCh:    make(chan struct{}),
   145  		StartConditionMetCh:   closedCh,
   146  		ShutdownDelayCtx:      shutdownDelayCtx,
   147  		ShutdownDelayCancelFn: shutdownDelayCancelFn,
   148  		ServiceRegWrapper:     wrapperMock,
   149  		Getter:                getter.TestSandbox(t),
   150  	}
   151  
   152  	// Set the cgroup path getter if we are in v2 mode
   153  	if cgutil.UseV2 {
   154  		conf.CpusetCgroupPathGetter = func(context.Context) (string, error) {
   155  			return filepath.Join(cgutil.CgroupRoot, "testing.slice", alloc.ID, thisTask.Name), nil
   156  		}
   157  	}
   158  
   159  	return conf, trCleanup
   160  }
   161  
   162  // runTestTaskRunner runs a TaskRunner and returns its configuration as well as
   163  // a cleanup function that ensures the runner is stopped and cleaned up. Tests
   164  // which need to change the Config *must* use testTaskRunnerConfig instead.
   165  func runTestTaskRunner(t *testing.T, alloc *structs.Allocation, taskName string) (*TaskRunner, *Config, func()) {
   166  	config, cleanup := testTaskRunnerConfig(t, alloc, taskName)
   167  
   168  	tr, err := NewTaskRunner(config)
   169  	require.NoError(t, err)
   170  	go tr.Run()
   171  
   172  	return tr, config, func() {
   173  		tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
   174  		cleanup()
   175  	}
   176  }
   177  
   178  func TestTaskRunner_BuildTaskConfig_CPU_Memory(t *testing.T) {
   179  	ci.Parallel(t)
   180  
   181  	cases := []struct {
   182  		name                  string
   183  		cpu                   int64
   184  		memoryMB              int64
   185  		memoryMaxMB           int64
   186  		expectedLinuxMemoryMB int64
   187  	}{
   188  		{
   189  			name:                  "plain no max",
   190  			cpu:                   100,
   191  			memoryMB:              100,
   192  			memoryMaxMB:           0,
   193  			expectedLinuxMemoryMB: 100,
   194  		},
   195  		{
   196  			name:                  "plain with max=reserve",
   197  			cpu:                   100,
   198  			memoryMB:              100,
   199  			memoryMaxMB:           100,
   200  			expectedLinuxMemoryMB: 100,
   201  		},
   202  		{
   203  			name:                  "plain with max>reserve",
   204  			cpu:                   100,
   205  			memoryMB:              100,
   206  			memoryMaxMB:           200,
   207  			expectedLinuxMemoryMB: 200,
   208  		},
   209  	}
   210  
   211  	for _, c := range cases {
   212  		t.Run(c.name, func(t *testing.T) {
   213  			alloc := mock.BatchAlloc()
   214  			alloc.Job.TaskGroups[0].Count = 1
   215  			task := alloc.Job.TaskGroups[0].Tasks[0]
   216  			task.Driver = "mock_driver"
   217  			task.Config = map[string]interface{}{
   218  				"run_for": "2s",
   219  			}
   220  			res := alloc.AllocatedResources.Tasks[task.Name]
   221  			res.Cpu.CpuShares = c.cpu
   222  			res.Memory.MemoryMB = c.memoryMB
   223  			res.Memory.MemoryMaxMB = c.memoryMaxMB
   224  
   225  			conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
   226  			conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between task runners
   227  			defer cleanup()
   228  
   229  			// Run the first TaskRunner
   230  			tr, err := NewTaskRunner(conf)
   231  			require.NoError(t, err)
   232  
   233  			tc := tr.buildTaskConfig()
   234  			require.Equal(t, c.cpu, tc.Resources.LinuxResources.CPUShares)
   235  			require.Equal(t, c.expectedLinuxMemoryMB*1024*1024, tc.Resources.LinuxResources.MemoryLimitBytes)
   236  
   237  			require.Equal(t, c.cpu, tc.Resources.NomadResources.Cpu.CpuShares)
   238  			require.Equal(t, c.memoryMB, tc.Resources.NomadResources.Memory.MemoryMB)
   239  			require.Equal(t, c.memoryMaxMB, tc.Resources.NomadResources.Memory.MemoryMaxMB)
   240  		})
   241  	}
   242  }
   243  
   244  // TestTaskRunner_Stop_ExitCode asserts that the exit code is captured on a task, even if it's stopped
   245  func TestTaskRunner_Stop_ExitCode(t *testing.T) {
   246  	ctestutil.ExecCompatible(t)
   247  	ci.Parallel(t)
   248  
   249  	alloc := mock.BatchAlloc()
   250  	alloc.Job.TaskGroups[0].Count = 1
   251  	task := alloc.Job.TaskGroups[0].Tasks[0]
   252  	task.KillSignal = "SIGTERM"
   253  	task.Driver = "raw_exec"
   254  	task.Config = map[string]interface{}{
   255  		"command": "/bin/sleep",
   256  		"args":    []string{"1000"},
   257  	}
   258  	task.Env = map[string]string{
   259  		"NOMAD_PARENT_CGROUP": "nomad.slice",
   260  		"NOMAD_ALLOC_ID":      alloc.ID,
   261  		"NOMAD_TASK_NAME":     task.Name,
   262  	}
   263  
   264  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
   265  	defer cleanup()
   266  
   267  	// Run the first TaskRunner
   268  	tr, err := NewTaskRunner(conf)
   269  	require.NoError(t, err)
   270  	go tr.Run()
   271  
   272  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
   273  
   274  	// Wait for it to be running
   275  	testWaitForTaskToStart(t, tr)
   276  
   277  	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
   278  	defer cancel()
   279  
   280  	err = tr.Kill(ctx, structs.NewTaskEvent("shutdown"))
   281  	require.NoError(t, err)
   282  
   283  	var exitEvent *structs.TaskEvent
   284  	state := tr.TaskState()
   285  	for _, e := range state.Events {
   286  		if e.Type == structs.TaskTerminated {
   287  			exitEvent = e
   288  			break
   289  		}
   290  	}
   291  	require.NotNilf(t, exitEvent, "exit event not found: %v", state.Events)
   292  
   293  	require.Equal(t, 143, exitEvent.ExitCode)
   294  	require.Equal(t, 15, exitEvent.Signal)
   295  
   296  }
   297  
   298  // TestTaskRunner_Restore_Running asserts restoring a running task does not
   299  // rerun the task.
   300  func TestTaskRunner_Restore_Running(t *testing.T) {
   301  	ci.Parallel(t)
   302  	require := require.New(t)
   303  
   304  	alloc := mock.BatchAlloc()
   305  	alloc.Job.TaskGroups[0].Count = 1
   306  	task := alloc.Job.TaskGroups[0].Tasks[0]
   307  	task.Driver = "mock_driver"
   308  	task.Config = map[string]interface{}{
   309  		"run_for": "2s",
   310  	}
   311  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
   312  	conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between task runners
   313  	defer cleanup()
   314  
   315  	// Run the first TaskRunner
   316  	origTR, err := NewTaskRunner(conf)
   317  	require.NoError(err)
   318  	go origTR.Run()
   319  	defer origTR.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
   320  
   321  	// Wait for it to be running
   322  	testWaitForTaskToStart(t, origTR)
   323  
   324  	// Cause TR to exit without shutting down task
   325  	origTR.Shutdown()
   326  
   327  	// Start a new TaskRunner and make sure it does not rerun the task
   328  	newTR, err := NewTaskRunner(conf)
   329  	require.NoError(err)
   330  
   331  	// Do the Restore
   332  	require.NoError(newTR.Restore())
   333  
   334  	go newTR.Run()
   335  	defer newTR.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
   336  
   337  	// Wait for new task runner to exit when the process does
   338  	testWaitForTaskToDie(t, newTR)
   339  
   340  	// Assert that the process was only started once
   341  	started := 0
   342  	state := newTR.TaskState()
   343  	require.Equal(structs.TaskStateDead, state.State)
   344  	for _, ev := range state.Events {
   345  		if ev.Type == structs.TaskStarted {
   346  			started++
   347  		}
   348  	}
   349  	assert.Equal(t, 1, started)
   350  }
   351  
   352  // TestTaskRunner_Restore_Dead asserts that restoring a dead task will place it
   353  // back in the correct state. If the task was waiting for an alloc restart it
   354  // must be able to be restarted after restore, otherwise a restart must fail.
   355  func TestTaskRunner_Restore_Dead(t *testing.T) {
   356  	ci.Parallel(t)
   357  
   358  	alloc := mock.BatchAlloc()
   359  	alloc.Job.TaskGroups[0].Count = 1
   360  	task := alloc.Job.TaskGroups[0].Tasks[0]
   361  	task.Driver = "mock_driver"
   362  	task.Config = map[string]interface{}{
   363  		"run_for": "2s",
   364  	}
   365  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
   366  	conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between task runners
   367  	defer cleanup()
   368  
   369  	// Run the first TaskRunner
   370  	origTR, err := NewTaskRunner(conf)
   371  	require.NoError(t, err)
   372  	go origTR.Run()
   373  	defer origTR.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
   374  
   375  	// Wait for it to be dead
   376  	testWaitForTaskToDie(t, origTR)
   377  
   378  	// Cause TR to exit without shutting down task
   379  	origTR.Shutdown()
   380  
   381  	// Start a new TaskRunner and do the Restore
   382  	newTR, err := NewTaskRunner(conf)
   383  	require.NoError(t, err)
   384  	require.NoError(t, newTR.Restore())
   385  
   386  	go newTR.Run()
   387  	defer newTR.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
   388  
   389  	// Verify that the TaskRunner is still active since it was recovered after
   390  	// a forced shutdown.
   391  	select {
   392  	case <-newTR.WaitCh():
   393  		require.Fail(t, "WaitCh is not blocking")
   394  	default:
   395  	}
   396  
   397  	// Verify that we can restart task.
   398  	// Retry a few times as the newTR.Run() may not have started yet.
   399  	testutil.WaitForResult(func() (bool, error) {
   400  		ev := &structs.TaskEvent{Type: structs.TaskRestartSignal}
   401  		err = newTR.ForceRestart(context.Background(), ev, false)
   402  		return err == nil, err
   403  	}, func(err error) {
   404  		require.NoError(t, err)
   405  	})
   406  	testWaitForTaskToStart(t, newTR)
   407  
   408  	// Kill task to verify that it's restored as dead and not able to restart.
   409  	newTR.Kill(context.Background(), nil)
   410  	testutil.WaitForResult(func() (bool, error) {
   411  		select {
   412  		case <-newTR.WaitCh():
   413  			return true, nil
   414  		default:
   415  			return false, fmt.Errorf("task still running")
   416  		}
   417  	}, func(err error) {
   418  		require.NoError(t, err)
   419  	})
   420  
   421  	newTR2, err := NewTaskRunner(conf)
   422  	require.NoError(t, err)
   423  	require.NoError(t, newTR2.Restore())
   424  
   425  	go newTR2.Run()
   426  	defer newTR2.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
   427  
   428  	ev := &structs.TaskEvent{Type: structs.TaskRestartSignal}
   429  	err = newTR2.ForceRestart(context.Background(), ev, false)
   430  	require.Equal(t, err, ErrTaskNotRunning)
   431  }
   432  
   433  // setupRestoreFailureTest starts a service, shuts down the task runner, and
   434  // kills the task before restarting a new TaskRunner. The new TaskRunner is
   435  // returned once it is running and waiting in pending along with a cleanup
   436  // func.
   437  func setupRestoreFailureTest(t *testing.T, alloc *structs.Allocation) (*TaskRunner, *Config, func()) {
   438  	task := alloc.Job.TaskGroups[0].Tasks[0]
   439  	task.Driver = "raw_exec"
   440  	task.Config = map[string]interface{}{
   441  		"command": "sleep",
   442  		"args":    []string{"30"},
   443  	}
   444  	task.Env = map[string]string{
   445  		"NOMAD_PARENT_CGROUP": "nomad.slice",
   446  		"NOMAD_ALLOC_ID":      alloc.ID,
   447  		"NOMAD_TASK_NAME":     task.Name,
   448  	}
   449  	conf, cleanup1 := testTaskRunnerConfig(t, alloc, task.Name)
   450  	conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between runs
   451  
   452  	// Run the first TaskRunner
   453  	origTR, err := NewTaskRunner(conf)
   454  	require.NoError(t, err)
   455  	go origTR.Run()
   456  	cleanup2 := func() {
   457  		origTR.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
   458  		cleanup1()
   459  	}
   460  
   461  	// Wait for it to be running
   462  	testWaitForTaskToStart(t, origTR)
   463  
   464  	handle := origTR.getDriverHandle()
   465  	require.NotNil(t, handle)
   466  	taskID := handle.taskID
   467  
   468  	// Cause TR to exit without shutting down task
   469  	origTR.Shutdown()
   470  
   471  	// Get the driver
   472  	driverPlugin, err := conf.DriverManager.Dispense(rawexec.PluginID.Name)
   473  	require.NoError(t, err)
   474  	rawexecDriver := driverPlugin.(*rawexec.Driver)
   475  
   476  	// Assert the task is still running despite TR having exited
   477  	taskStatus, err := rawexecDriver.InspectTask(taskID)
   478  	require.NoError(t, err)
   479  	require.Equal(t, drivers.TaskStateRunning, taskStatus.State)
   480  
   481  	// Kill the task so it fails to recover when restore is called
   482  	require.NoError(t, rawexecDriver.DestroyTask(taskID, true))
   483  	_, err = rawexecDriver.InspectTask(taskID)
   484  	require.EqualError(t, err, drivers.ErrTaskNotFound.Error())
   485  
   486  	// Create a new TaskRunner and Restore the task
   487  	conf.ServersContactedCh = make(chan struct{})
   488  	newTR, err := NewTaskRunner(conf)
   489  	require.NoError(t, err)
   490  
   491  	// Assert the TR will wait on servers because reattachment failed
   492  	require.NoError(t, newTR.Restore())
   493  	require.True(t, newTR.waitOnServers)
   494  
   495  	// Start new TR
   496  	go newTR.Run()
   497  	cleanup3 := func() {
   498  		newTR.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
   499  		cleanup2()
   500  		cleanup1()
   501  	}
   502  
   503  	// Assert task has not been restarted
   504  	_, err = rawexecDriver.InspectTask(taskID)
   505  	require.EqualError(t, err, drivers.ErrTaskNotFound.Error())
   506  	ts := newTR.TaskState()
   507  	require.Equal(t, structs.TaskStatePending, ts.State)
   508  
   509  	return newTR, conf, cleanup3
   510  }
   511  
   512  // TestTaskRunner_Restore_Restart asserts restoring a dead task blocks until
   513  // MarkAlive is called. #1795
   514  func TestTaskRunner_Restore_Restart(t *testing.T) {
   515  	ci.Parallel(t)
   516  
   517  	newTR, conf, cleanup := setupRestoreFailureTest(t, mock.Alloc())
   518  	defer cleanup()
   519  
   520  	// Fake contacting the server by closing the chan
   521  	close(conf.ServersContactedCh)
   522  
   523  	testutil.WaitForResult(func() (bool, error) {
   524  		ts := newTR.TaskState().State
   525  		return ts == structs.TaskStateRunning, fmt.Errorf("expected task to be running but found %q", ts)
   526  	}, func(err error) {
   527  		require.NoError(t, err)
   528  	})
   529  }
   530  
   531  // TestTaskRunner_Restore_Kill asserts restoring a dead task blocks until
   532  // the task is killed. #1795
   533  func TestTaskRunner_Restore_Kill(t *testing.T) {
   534  	ci.Parallel(t)
   535  
   536  	newTR, _, cleanup := setupRestoreFailureTest(t, mock.Alloc())
   537  	defer cleanup()
   538  
   539  	// Sending the task a terminal update shouldn't kill it or unblock it
   540  	alloc := newTR.Alloc().Copy()
   541  	alloc.DesiredStatus = structs.AllocDesiredStatusStop
   542  	newTR.Update(alloc)
   543  
   544  	require.Equal(t, structs.TaskStatePending, newTR.TaskState().State)
   545  
   546  	// AllocRunner will immediately kill tasks after sending a terminal
   547  	// update.
   548  	newTR.Kill(context.Background(), structs.NewTaskEvent(structs.TaskKilling))
   549  
   550  	select {
   551  	case <-newTR.WaitCh():
   552  		// It died as expected!
   553  	case <-time.After(10 * time.Second):
   554  		require.Fail(t, "timeout waiting for task to die")
   555  	}
   556  }
   557  
   558  // TestTaskRunner_Restore_Update asserts restoring a dead task blocks until
   559  // Update is called. #1795
   560  func TestTaskRunner_Restore_Update(t *testing.T) {
   561  	ci.Parallel(t)
   562  
   563  	newTR, conf, cleanup := setupRestoreFailureTest(t, mock.Alloc())
   564  	defer cleanup()
   565  
   566  	// Fake Client.runAllocs behavior by calling Update then closing chan
   567  	alloc := newTR.Alloc().Copy()
   568  	newTR.Update(alloc)
   569  
   570  	// Update alone should not unblock the test
   571  	require.Equal(t, structs.TaskStatePending, newTR.TaskState().State)
   572  
   573  	// Fake Client.runAllocs behavior of closing chan after Update
   574  	close(conf.ServersContactedCh)
   575  
   576  	testutil.WaitForResult(func() (bool, error) {
   577  		ts := newTR.TaskState().State
   578  		return ts == structs.TaskStateRunning, fmt.Errorf("expected task to be running but found %q", ts)
   579  	}, func(err error) {
   580  		require.NoError(t, err)
   581  	})
   582  }
   583  
   584  // TestTaskRunner_Restore_System asserts restoring a dead system task does not
   585  // block.
   586  func TestTaskRunner_Restore_System(t *testing.T) {
   587  	ci.Parallel(t)
   588  
   589  	alloc := mock.Alloc()
   590  	alloc.Job.Type = structs.JobTypeSystem
   591  	task := alloc.Job.TaskGroups[0].Tasks[0]
   592  	task.Driver = "raw_exec"
   593  	task.Config = map[string]interface{}{
   594  		"command": "sleep",
   595  		"args":    []string{"30"},
   596  	}
   597  	task.Env = map[string]string{
   598  		"NOMAD_PARENT_CGROUP": "nomad.slice",
   599  		"NOMAD_ALLOC_ID":      alloc.ID,
   600  		"NOMAD_TASK_NAME":     task.Name,
   601  	}
   602  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
   603  	defer cleanup()
   604  	conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between runs
   605  
   606  	// Run the first TaskRunner
   607  	origTR, err := NewTaskRunner(conf)
   608  	require.NoError(t, err)
   609  	go origTR.Run()
   610  	defer origTR.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
   611  
   612  	// Wait for it to be running
   613  	testWaitForTaskToStart(t, origTR)
   614  
   615  	handle := origTR.getDriverHandle()
   616  	require.NotNil(t, handle)
   617  	taskID := handle.taskID
   618  
   619  	// Cause TR to exit without shutting down task
   620  	origTR.Shutdown()
   621  
   622  	// Get the driver
   623  	driverPlugin, err := conf.DriverManager.Dispense(rawexec.PluginID.Name)
   624  	require.NoError(t, err)
   625  	rawexecDriver := driverPlugin.(*rawexec.Driver)
   626  
   627  	// Assert the task is still running despite TR having exited
   628  	taskStatus, err := rawexecDriver.InspectTask(taskID)
   629  	require.NoError(t, err)
   630  	require.Equal(t, drivers.TaskStateRunning, taskStatus.State)
   631  
   632  	// Kill the task so it fails to recover when restore is called
   633  	require.NoError(t, rawexecDriver.DestroyTask(taskID, true))
   634  	_, err = rawexecDriver.InspectTask(taskID)
   635  	require.EqualError(t, err, drivers.ErrTaskNotFound.Error())
   636  
   637  	// Create a new TaskRunner and Restore the task
   638  	conf.ServersContactedCh = make(chan struct{})
   639  	newTR, err := NewTaskRunner(conf)
   640  	require.NoError(t, err)
   641  
   642  	// Assert the TR will not wait on servers even though reattachment
   643  	// failed because it is a system task.
   644  	require.NoError(t, newTR.Restore())
   645  	require.False(t, newTR.waitOnServers)
   646  
   647  	// Nothing should have closed the chan
   648  	select {
   649  	case <-conf.ServersContactedCh:
   650  		require.Fail(t, "serversContactedCh was closed but should not have been")
   651  	default:
   652  	}
   653  
   654  	testutil.WaitForResult(func() (bool, error) {
   655  		ts := newTR.TaskState().State
   656  		return ts == structs.TaskStateRunning, fmt.Errorf("expected task to be running but found %q", ts)
   657  	}, func(err error) {
   658  		require.NoError(t, err)
   659  	})
   660  }
   661  
   662  // TestTaskRunner_TaskEnv_Interpolated asserts driver configurations are
   663  // interpolated.
   664  func TestTaskRunner_TaskEnv_Interpolated(t *testing.T) {
   665  	ci.Parallel(t)
   666  	require := require.New(t)
   667  
   668  	alloc := mock.BatchAlloc()
   669  	alloc.Job.TaskGroups[0].Meta = map[string]string{
   670  		"common_user": "somebody",
   671  	}
   672  	task := alloc.Job.TaskGroups[0].Tasks[0]
   673  	task.Meta = map[string]string{
   674  		"foo": "bar",
   675  	}
   676  
   677  	// Use interpolation from both node attributes and meta vars
   678  	task.Config = map[string]interface{}{
   679  		"run_for":       "1ms",
   680  		"stdout_string": `${node.region} ${NOMAD_META_foo} ${NOMAD_META_common_user}`,
   681  	}
   682  
   683  	tr, conf, cleanup := runTestTaskRunner(t, alloc, task.Name)
   684  	defer cleanup()
   685  
   686  	// Wait for task to complete
   687  	testWaitForTaskToDie(t, tr)
   688  
   689  	// Get the mock driver plugin
   690  	driverPlugin, err := conf.DriverManager.Dispense(mockdriver.PluginID.Name)
   691  	require.NoError(err)
   692  	mockDriver := driverPlugin.(*mockdriver.Driver)
   693  
   694  	// Assert its config has been properly interpolated
   695  	driverCfg, mockCfg := mockDriver.GetTaskConfig()
   696  	require.NotNil(driverCfg)
   697  	require.NotNil(mockCfg)
   698  	assert.Equal(t, "global bar somebody", mockCfg.StdoutString)
   699  }
   700  
   701  // TestTaskRunner_TaskEnv_None asserts raw_exec uses host paths and env vars.
   702  func TestTaskRunner_TaskEnv_None(t *testing.T) {
   703  	ci.Parallel(t)
   704  	require := require.New(t)
   705  
   706  	alloc := mock.BatchAlloc()
   707  	task := alloc.Job.TaskGroups[0].Tasks[0]
   708  	task.Driver = "raw_exec"
   709  	task.Config = map[string]interface{}{
   710  		"command": "sh",
   711  		"args": []string{"-c", "echo $NOMAD_ALLOC_DIR; " +
   712  			"echo $NOMAD_TASK_DIR; " +
   713  			"echo $NOMAD_SECRETS_DIR; " +
   714  			"echo $PATH",
   715  		},
   716  	}
   717  	task.Env = map[string]string{
   718  		"NOMAD_PARENT_CGROUP": "nomad.slice",
   719  		"NOMAD_ALLOC_ID":      alloc.ID,
   720  		"NOMAD_TASK_NAME":     task.Name,
   721  	}
   722  	tr, conf, cleanup := runTestTaskRunner(t, alloc, task.Name)
   723  	defer cleanup()
   724  
   725  	// Expect host paths
   726  	root := filepath.Join(conf.ClientConfig.AllocDir, alloc.ID)
   727  	taskDir := filepath.Join(root, task.Name)
   728  	exp := fmt.Sprintf(`%s/alloc
   729  %s/local
   730  %s/secrets
   731  %s
   732  `, root, taskDir, taskDir, os.Getenv("PATH"))
   733  
   734  	// Wait for task to exit and kill the task runner to run the stop hooks.
   735  	testWaitForTaskToDie(t, tr)
   736  	tr.Kill(context.Background(), structs.NewTaskEvent("kill"))
   737  	select {
   738  	case <-tr.WaitCh():
   739  	case <-time.After(15 * time.Second):
   740  		require.Fail("timeout waiting for task to exit")
   741  	}
   742  
   743  	// Read stdout
   744  	p := filepath.Join(conf.TaskDir.LogDir, task.Name+".stdout.0")
   745  	stdout, err := ioutil.ReadFile(p)
   746  	require.NoError(err)
   747  	require.Equalf(exp, string(stdout), "expected: %s\n\nactual: %s\n", exp, stdout)
   748  }
   749  
   750  // Test that devices get sent to the driver
   751  func TestTaskRunner_DevicePropogation(t *testing.T) {
   752  	ci.Parallel(t)
   753  	require := require.New(t)
   754  
   755  	// Create a mock alloc that has a gpu
   756  	alloc := mock.BatchAlloc()
   757  	alloc.Job.TaskGroups[0].Count = 1
   758  	task := alloc.Job.TaskGroups[0].Tasks[0]
   759  	task.Driver = "mock_driver"
   760  	task.Config = map[string]interface{}{
   761  		"run_for": "100ms",
   762  	}
   763  	tRes := alloc.AllocatedResources.Tasks[task.Name]
   764  	tRes.Devices = append(tRes.Devices, &structs.AllocatedDeviceResource{Type: "mock"})
   765  
   766  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
   767  	conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between task runners
   768  	defer cleanup()
   769  
   770  	// Setup the devicemanager
   771  	dm, ok := conf.DeviceManager.(*devicemanager.MockManager)
   772  	require.True(ok)
   773  
   774  	dm.ReserveF = func(d *structs.AllocatedDeviceResource) (*device.ContainerReservation, error) {
   775  		res := &device.ContainerReservation{
   776  			Envs: map[string]string{
   777  				"ABC": "123",
   778  			},
   779  			Mounts: []*device.Mount{
   780  				{
   781  					ReadOnly: true,
   782  					TaskPath: "foo",
   783  					HostPath: "bar",
   784  				},
   785  			},
   786  			Devices: []*device.DeviceSpec{
   787  				{
   788  					TaskPath:    "foo",
   789  					HostPath:    "bar",
   790  					CgroupPerms: "123",
   791  				},
   792  			},
   793  		}
   794  		return res, nil
   795  	}
   796  
   797  	// Run the TaskRunner
   798  	tr, err := NewTaskRunner(conf)
   799  	require.NoError(err)
   800  	go tr.Run()
   801  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
   802  
   803  	// Wait for task to complete
   804  	testWaitForTaskToDie(t, tr)
   805  
   806  	// Get the mock driver plugin
   807  	driverPlugin, err := conf.DriverManager.Dispense(mockdriver.PluginID.Name)
   808  	require.NoError(err)
   809  	mockDriver := driverPlugin.(*mockdriver.Driver)
   810  
   811  	// Assert its config has been properly interpolated
   812  	driverCfg, _ := mockDriver.GetTaskConfig()
   813  	require.NotNil(driverCfg)
   814  	require.Len(driverCfg.Devices, 1)
   815  	require.Equal(driverCfg.Devices[0].Permissions, "123")
   816  	require.Len(driverCfg.Mounts, 1)
   817  	require.Equal(driverCfg.Mounts[0].TaskPath, "foo")
   818  	require.Contains(driverCfg.Env, "ABC")
   819  }
   820  
   821  // mockEnvHook is a test hook that sets an env var and done=true. It fails if
   822  // it's called more than once.
   823  type mockEnvHook struct {
   824  	called int
   825  }
   826  
   827  func (*mockEnvHook) Name() string {
   828  	return "mock_env_hook"
   829  }
   830  
   831  func (h *mockEnvHook) Prestart(ctx context.Context, req *interfaces.TaskPrestartRequest, resp *interfaces.TaskPrestartResponse) error {
   832  	h.called++
   833  
   834  	resp.Done = true
   835  	resp.Env = map[string]string{
   836  		"mock_hook": "1",
   837  	}
   838  
   839  	return nil
   840  }
   841  
   842  // TestTaskRunner_Restore_HookEnv asserts that re-running prestart hooks with
   843  // hook environments set restores the environment without re-running done
   844  // hooks.
   845  func TestTaskRunner_Restore_HookEnv(t *testing.T) {
   846  	ci.Parallel(t)
   847  	require := require.New(t)
   848  
   849  	alloc := mock.BatchAlloc()
   850  	task := alloc.Job.TaskGroups[0].Tasks[0]
   851  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
   852  	conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between prestart calls
   853  	defer cleanup()
   854  
   855  	tr, err := NewTaskRunner(conf)
   856  	require.NoError(err)
   857  
   858  	// Override the default hooks to only run the mock hook
   859  	mockHook := &mockEnvHook{}
   860  	tr.runnerHooks = []interfaces.TaskHook{mockHook}
   861  
   862  	// Manually run prestart hooks
   863  	require.NoError(tr.prestart())
   864  
   865  	// Assert env was called
   866  	require.Equal(1, mockHook.called)
   867  
   868  	// Re-running prestart hooks should *not* call done mock hook
   869  	require.NoError(tr.prestart())
   870  
   871  	// Assert env was called
   872  	require.Equal(1, mockHook.called)
   873  
   874  	// Assert the env is still set
   875  	env := tr.envBuilder.Build().All()
   876  	require.Contains(env, "mock_hook")
   877  	require.Equal("1", env["mock_hook"])
   878  }
   879  
   880  // This test asserts that we can recover from an "external" plugin exiting by
   881  // retrieving a new instance of the driver and recovering the task.
   882  func TestTaskRunner_RecoverFromDriverExiting(t *testing.T) {
   883  	ci.Parallel(t)
   884  	require := require.New(t)
   885  
   886  	// Create an allocation using the mock driver that exits simulating the
   887  	// driver crashing. We can then test that the task runner recovers from this
   888  	alloc := mock.BatchAlloc()
   889  	task := alloc.Job.TaskGroups[0].Tasks[0]
   890  	task.Driver = "mock_driver"
   891  	task.Config = map[string]interface{}{
   892  		"plugin_exit_after": "1s",
   893  		"run_for":           "5s",
   894  	}
   895  
   896  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
   897  	conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between prestart calls
   898  	defer cleanup()
   899  
   900  	tr, err := NewTaskRunner(conf)
   901  	require.NoError(err)
   902  
   903  	start := time.Now()
   904  	go tr.Run()
   905  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
   906  
   907  	// Wait for the task to be running
   908  	testWaitForTaskToStart(t, tr)
   909  
   910  	// Get the task ID
   911  	tr.stateLock.RLock()
   912  	l := tr.localState.TaskHandle
   913  	require.NotNil(l)
   914  	require.NotNil(l.Config)
   915  	require.NotEmpty(l.Config.ID)
   916  	id := l.Config.ID
   917  	tr.stateLock.RUnlock()
   918  
   919  	// Get the mock driver plugin
   920  	driverPlugin, err := conf.DriverManager.Dispense(mockdriver.PluginID.Name)
   921  	require.NoError(err)
   922  	mockDriver := driverPlugin.(*mockdriver.Driver)
   923  
   924  	// Wait for the task to start
   925  	testutil.WaitForResult(func() (bool, error) {
   926  		// Get the handle and check that it was recovered
   927  		handle := mockDriver.GetHandle(id)
   928  		if handle == nil {
   929  			return false, fmt.Errorf("nil handle")
   930  		}
   931  		if !handle.Recovered {
   932  			return false, fmt.Errorf("handle not recovered")
   933  		}
   934  		return true, nil
   935  	}, func(err error) {
   936  		t.Fatal(err.Error())
   937  	})
   938  
   939  	// Wait for task to complete
   940  	select {
   941  	case <-tr.WaitCh():
   942  	case <-time.After(10 * time.Second):
   943  	}
   944  
   945  	// Ensure that we actually let the task complete
   946  	require.True(time.Now().Sub(start) > 5*time.Second)
   947  
   948  	// Check it finished successfully
   949  	state := tr.TaskState()
   950  	require.True(state.Successful())
   951  }
   952  
   953  // TestTaskRunner_ShutdownDelay asserts services are removed from Consul
   954  // ${shutdown_delay} seconds before killing the process.
   955  func TestTaskRunner_ShutdownDelay(t *testing.T) {
   956  	ci.Parallel(t)
   957  
   958  	alloc := mock.Alloc()
   959  	task := alloc.Job.TaskGroups[0].Tasks[0]
   960  	task.Services[0].Tags = []string{"tag1"}
   961  	task.Services = task.Services[:1] // only need 1 for this test
   962  	task.Driver = "mock_driver"
   963  	task.Config = map[string]interface{}{
   964  		"run_for": "1000s",
   965  	}
   966  
   967  	// No shutdown escape hatch for this delay, so don't set it too high
   968  	task.ShutdownDelay = 1000 * time.Duration(testutil.TestMultiplier()) * time.Millisecond
   969  
   970  	tr, conf, cleanup := runTestTaskRunner(t, alloc, task.Name)
   971  	defer cleanup()
   972  
   973  	mockConsul := conf.Consul.(*regMock.ServiceRegistrationHandler)
   974  
   975  	// Wait for the task to start
   976  	testWaitForTaskToStart(t, tr)
   977  
   978  	testutil.WaitForResult(func() (bool, error) {
   979  		ops := mockConsul.GetOps()
   980  		if n := len(ops); n != 1 {
   981  			return false, fmt.Errorf("expected 1 consul operation. Found %d", n)
   982  		}
   983  		return ops[0].Op == "add", fmt.Errorf("consul operation was not a registration: %#v", ops[0])
   984  	}, func(err error) {
   985  		t.Fatalf("err: %v", err)
   986  	})
   987  
   988  	// Asynchronously kill task
   989  	killSent := time.Now()
   990  	killed := make(chan struct{})
   991  	go func() {
   992  		defer close(killed)
   993  		assert.NoError(t, tr.Kill(context.Background(), structs.NewTaskEvent("test")))
   994  	}()
   995  
   996  	// Wait for *1* de-registration calls (all [non-]canary variants removed).
   997  
   998  WAIT:
   999  	for {
  1000  		ops := mockConsul.GetOps()
  1001  		switch n := len(ops); n {
  1002  		case 1:
  1003  			// Waiting for single de-registration call.
  1004  		case 2:
  1005  			require.Equalf(t, "remove", ops[1].Op, "expected deregistration but found: %#v", ops[1])
  1006  			break WAIT
  1007  		default:
  1008  			// ?!
  1009  			t.Fatalf("unexpected number of consul operations: %d\n%s", n, pretty.Sprint(ops))
  1010  
  1011  		}
  1012  
  1013  		select {
  1014  		case <-killed:
  1015  			t.Fatal("killed while service still registered")
  1016  		case <-time.After(10 * time.Millisecond):
  1017  		}
  1018  	}
  1019  
  1020  	// Wait for actual exit
  1021  	select {
  1022  	case <-tr.WaitCh():
  1023  	case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
  1024  		t.Fatalf("timeout")
  1025  	}
  1026  
  1027  	<-killed
  1028  	killDur := time.Now().Sub(killSent)
  1029  	if killDur < task.ShutdownDelay {
  1030  		t.Fatalf("task killed before shutdown_delay (killed_after: %s; shutdown_delay: %s",
  1031  			killDur, task.ShutdownDelay,
  1032  		)
  1033  	}
  1034  }
  1035  
  1036  // TestTaskRunner_NoShutdownDelay asserts services are removed from
  1037  // Consul and tasks are killed without waiting for ${shutdown_delay}
  1038  // when the alloc has the NoShutdownDelay transition flag set.
  1039  func TestTaskRunner_NoShutdownDelay(t *testing.T) {
  1040  	ci.Parallel(t)
  1041  
  1042  	// don't set this too high so that we don't block the test runner
  1043  	// on shutting down the agent if the test fails
  1044  	maxTestDuration := time.Duration(testutil.TestMultiplier()*10) * time.Second
  1045  	maxTimeToFailDuration := time.Duration(testutil.TestMultiplier()) * time.Second
  1046  
  1047  	alloc := mock.Alloc()
  1048  	alloc.DesiredTransition = structs.DesiredTransition{NoShutdownDelay: pointer.Of(true)}
  1049  	task := alloc.Job.TaskGroups[0].Tasks[0]
  1050  	task.Services[0].Tags = []string{"tag1"}
  1051  	task.Services = task.Services[:1] // only need 1 for this test
  1052  	task.Driver = "mock_driver"
  1053  	task.Config = map[string]interface{}{
  1054  		"run_for": "1000s",
  1055  	}
  1056  	task.ShutdownDelay = maxTestDuration
  1057  
  1058  	tr, conf, cleanup := runTestTaskRunner(t, alloc, task.Name)
  1059  	defer cleanup()
  1060  
  1061  	mockConsul := conf.Consul.(*regMock.ServiceRegistrationHandler)
  1062  
  1063  	testWaitForTaskToStart(t, tr)
  1064  
  1065  	testutil.WaitForResult(func() (bool, error) {
  1066  		ops := mockConsul.GetOps()
  1067  		if n := len(ops); n != 1 {
  1068  			return false, fmt.Errorf("expected 1 consul operation. Found %d", n)
  1069  		}
  1070  		return ops[0].Op == "add", fmt.Errorf("consul operation was not a registration: %#v", ops[0])
  1071  	}, func(err error) {
  1072  		t.Fatalf("err: %v", err)
  1073  	})
  1074  
  1075  	testCtx, cancel := context.WithTimeout(context.Background(), maxTimeToFailDuration)
  1076  	defer cancel()
  1077  
  1078  	killed := make(chan error)
  1079  	go func() {
  1080  		tr.shutdownDelayCancel()
  1081  		err := tr.Kill(testCtx, structs.NewTaskEvent("test"))
  1082  		killed <- err
  1083  	}()
  1084  
  1085  	// Wait for first de-registration call. Note that unlike
  1086  	// TestTaskRunner_ShutdownDelay, we're racing with task exit
  1087  	// and can't assert that we only get the first deregistration op
  1088  	// (from serviceHook.PreKill).
  1089  	testutil.WaitForResult(func() (bool, error) {
  1090  		ops := mockConsul.GetOps()
  1091  		if n := len(ops); n < 2 {
  1092  			return false, fmt.Errorf("expected at least 2 consul operations.")
  1093  		}
  1094  		return ops[1].Op == "remove", fmt.Errorf(
  1095  			"consul operation was not a deregistration: %#v", ops[1])
  1096  	}, func(err error) {
  1097  		t.Fatalf("err: %v", err)
  1098  	})
  1099  
  1100  	// Wait for the task to exit
  1101  	select {
  1102  	case <-tr.WaitCh():
  1103  	case <-time.After(maxTimeToFailDuration):
  1104  		t.Fatalf("task kill did not ignore shutdown delay")
  1105  		return
  1106  	}
  1107  
  1108  	err := <-killed
  1109  	require.NoError(t, err, "killing task returned unexpected error")
  1110  }
  1111  
  1112  // TestTaskRunner_Dispatch_Payload asserts that a dispatch job runs and the
  1113  // payload was written to disk.
  1114  func TestTaskRunner_Dispatch_Payload(t *testing.T) {
  1115  	ci.Parallel(t)
  1116  
  1117  	alloc := mock.BatchAlloc()
  1118  	task := alloc.Job.TaskGroups[0].Tasks[0]
  1119  	task.Driver = "mock_driver"
  1120  	task.Config = map[string]interface{}{
  1121  		"run_for": "1s",
  1122  	}
  1123  
  1124  	fileName := "test"
  1125  	task.DispatchPayload = &structs.DispatchPayloadConfig{
  1126  		File: fileName,
  1127  	}
  1128  	alloc.Job.ParameterizedJob = &structs.ParameterizedJobConfig{}
  1129  
  1130  	// Add a payload (they're snappy encoded bytes)
  1131  	expected := []byte("hello world")
  1132  	compressed := snappy.Encode(nil, expected)
  1133  	alloc.Job.Payload = compressed
  1134  
  1135  	tr, _, cleanup := runTestTaskRunner(t, alloc, task.Name)
  1136  	defer cleanup()
  1137  
  1138  	// Wait for it to finish
  1139  	testutil.WaitForResult(func() (bool, error) {
  1140  		ts := tr.TaskState()
  1141  		return ts.State == structs.TaskStateDead, fmt.Errorf("%v", ts.State)
  1142  	}, func(err error) {
  1143  		require.NoError(t, err)
  1144  	})
  1145  
  1146  	// Should have exited successfully
  1147  	ts := tr.TaskState()
  1148  	require.False(t, ts.Failed)
  1149  	require.Zero(t, ts.Restarts)
  1150  
  1151  	// Check that the file was written to disk properly
  1152  	payloadPath := filepath.Join(tr.taskDir.LocalDir, fileName)
  1153  	data, err := ioutil.ReadFile(payloadPath)
  1154  	require.NoError(t, err)
  1155  	require.Equal(t, expected, data)
  1156  }
  1157  
  1158  // TestTaskRunner_SignalFailure asserts that signal errors are properly
  1159  // propagated from the driver to TaskRunner.
  1160  func TestTaskRunner_SignalFailure(t *testing.T) {
  1161  	ci.Parallel(t)
  1162  
  1163  	alloc := mock.Alloc()
  1164  	task := alloc.Job.TaskGroups[0].Tasks[0]
  1165  	task.Driver = "mock_driver"
  1166  	errMsg := "test forcing failure"
  1167  	task.Config = map[string]interface{}{
  1168  		"run_for":      "10m",
  1169  		"signal_error": errMsg,
  1170  	}
  1171  
  1172  	tr, _, cleanup := runTestTaskRunner(t, alloc, task.Name)
  1173  	defer cleanup()
  1174  
  1175  	testWaitForTaskToStart(t, tr)
  1176  
  1177  	require.EqualError(t, tr.Signal(&structs.TaskEvent{}, "SIGINT"), errMsg)
  1178  }
  1179  
  1180  // TestTaskRunner_RestartTask asserts that restarting a task works and emits a
  1181  // Restarting event.
  1182  func TestTaskRunner_RestartTask(t *testing.T) {
  1183  	ci.Parallel(t)
  1184  
  1185  	alloc := mock.Alloc()
  1186  	task := alloc.Job.TaskGroups[0].Tasks[0]
  1187  	task.Driver = "mock_driver"
  1188  	task.Config = map[string]interface{}{
  1189  		"run_for": "10m",
  1190  	}
  1191  
  1192  	tr, _, cleanup := runTestTaskRunner(t, alloc, task.Name)
  1193  	defer cleanup()
  1194  
  1195  	testWaitForTaskToStart(t, tr)
  1196  
  1197  	// Restart task. Send a RestartSignal event like check watcher. Restart
  1198  	// handler emits the Restarting event.
  1199  	event := structs.NewTaskEvent(structs.TaskRestartSignal).SetRestartReason("test")
  1200  	const fail = false
  1201  	tr.Restart(context.Background(), event.Copy(), fail)
  1202  
  1203  	// Wait for it to restart and be running again
  1204  	testutil.WaitForResult(func() (bool, error) {
  1205  		ts := tr.TaskState()
  1206  		if ts.Restarts != 1 {
  1207  			return false, fmt.Errorf("expected 1 restart but found %d\nevents: %s",
  1208  				ts.Restarts, pretty.Sprint(ts.Events))
  1209  		}
  1210  		if ts.State != structs.TaskStateRunning {
  1211  			return false, fmt.Errorf("expected running but received %s", ts.State)
  1212  		}
  1213  		return true, nil
  1214  	}, func(err error) {
  1215  		require.NoError(t, err)
  1216  	})
  1217  
  1218  	// Assert the expected Restarting event was emitted
  1219  	found := false
  1220  	events := tr.TaskState().Events
  1221  	for _, e := range events {
  1222  		if e.Type == structs.TaskRestartSignal {
  1223  			found = true
  1224  			require.Equal(t, event.Time, e.Time)
  1225  			require.Equal(t, event.RestartReason, e.RestartReason)
  1226  			require.Contains(t, e.DisplayMessage, event.RestartReason)
  1227  		}
  1228  	}
  1229  	require.True(t, found, "restarting task event not found", pretty.Sprint(events))
  1230  }
  1231  
  1232  // TestTaskRunner_CheckWatcher_Restart asserts that when enabled an unhealthy
  1233  // Consul check will cause a task to restart following restart policy rules.
  1234  func TestTaskRunner_CheckWatcher_Restart(t *testing.T) {
  1235  	ci.Parallel(t)
  1236  
  1237  	alloc := mock.Alloc()
  1238  
  1239  	// Make the restart policy fail within this test
  1240  	tg := alloc.Job.TaskGroups[0]
  1241  	tg.RestartPolicy.Attempts = 2
  1242  	tg.RestartPolicy.Interval = 1 * time.Minute
  1243  	tg.RestartPolicy.Delay = 10 * time.Millisecond
  1244  	tg.RestartPolicy.Mode = structs.RestartPolicyModeFail
  1245  
  1246  	task := tg.Tasks[0]
  1247  	task.Driver = "mock_driver"
  1248  	task.Config = map[string]interface{}{
  1249  		"run_for": "10m",
  1250  	}
  1251  
  1252  	// Make the task register a check that fails
  1253  	task.Services[0].Checks[0] = &structs.ServiceCheck{
  1254  		Name:     "test-restarts",
  1255  		Type:     structs.ServiceCheckTCP,
  1256  		Interval: 50 * time.Millisecond,
  1257  		CheckRestart: &structs.CheckRestart{
  1258  			Limit: 2,
  1259  			Grace: 100 * time.Millisecond,
  1260  		},
  1261  	}
  1262  	task.Services[0].Provider = structs.ServiceProviderConsul
  1263  
  1264  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
  1265  	defer cleanup()
  1266  
  1267  	// Replace mock Consul ServiceClient, with the real ServiceClient
  1268  	// backed by a mock consul whose checks are always unhealthy.
  1269  	consulAgent := agentconsul.NewMockAgent(agentconsul.Features{
  1270  		Enterprise: false,
  1271  		Namespaces: false,
  1272  	})
  1273  	consulAgent.SetStatus("critical")
  1274  	namespacesClient := agentconsul.NewNamespacesClient(agentconsul.NewMockNamespaces(nil), consulAgent)
  1275  	consulClient := agentconsul.NewServiceClient(consulAgent, namespacesClient, conf.Logger, true)
  1276  	go consulClient.Run()
  1277  	defer consulClient.Shutdown()
  1278  
  1279  	conf.Consul = consulClient
  1280  	conf.ServiceRegWrapper = wrapper.NewHandlerWrapper(conf.Logger, consulClient, nil)
  1281  
  1282  	tr, err := NewTaskRunner(conf)
  1283  	require.NoError(t, err)
  1284  
  1285  	expectedEvents := []string{
  1286  		"Received",
  1287  		"Task Setup",
  1288  		"Started",
  1289  		"Restart Signaled",
  1290  		"Terminated",
  1291  		"Restarting",
  1292  		"Started",
  1293  		"Restart Signaled",
  1294  		"Terminated",
  1295  		"Restarting",
  1296  		"Started",
  1297  		"Restart Signaled",
  1298  		"Terminated",
  1299  		"Not Restarting",
  1300  	}
  1301  
  1302  	// Bump maxEvents so task events aren't dropped
  1303  	tr.maxEvents = 100
  1304  
  1305  	go tr.Run()
  1306  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
  1307  
  1308  	// Wait until the task exits. Don't simply wait for it to run as it may
  1309  	// get restarted and terminated before the test is able to observe it
  1310  	// running.
  1311  	testWaitForTaskToDie(t, tr)
  1312  
  1313  	state := tr.TaskState()
  1314  	actualEvents := make([]string, len(state.Events))
  1315  	for i, e := range state.Events {
  1316  		actualEvents[i] = string(e.Type)
  1317  	}
  1318  	require.Equal(t, actualEvents, expectedEvents)
  1319  	require.Equal(t, structs.TaskStateDead, state.State)
  1320  	require.True(t, state.Failed, pretty.Sprint(state))
  1321  }
  1322  
  1323  type mockEnvoyBootstrapHook struct {
  1324  	// nothing
  1325  }
  1326  
  1327  func (_ *mockEnvoyBootstrapHook) Name() string {
  1328  	return "mock_envoy_bootstrap"
  1329  }
  1330  
  1331  func (_ *mockEnvoyBootstrapHook) Prestart(_ context.Context, _ *interfaces.TaskPrestartRequest, resp *interfaces.TaskPrestartResponse) error {
  1332  	resp.Done = true
  1333  	return nil
  1334  }
  1335  
  1336  // The envoy bootstrap hook tries to connect to consul and run the envoy
  1337  // bootstrap command, so turn it off when testing connect jobs that are not
  1338  // using envoy.
  1339  func useMockEnvoyBootstrapHook(tr *TaskRunner) {
  1340  	mock := new(mockEnvoyBootstrapHook)
  1341  	for i, hook := range tr.runnerHooks {
  1342  		if _, ok := hook.(*envoyBootstrapHook); ok {
  1343  			tr.runnerHooks[i] = mock
  1344  		}
  1345  	}
  1346  }
  1347  
  1348  // TestTaskRunner_BlockForSIDSToken asserts tasks do not start until a Consul
  1349  // Service Identity token is derived.
  1350  func TestTaskRunner_BlockForSIDSToken(t *testing.T) {
  1351  	ci.Parallel(t)
  1352  	r := require.New(t)
  1353  
  1354  	alloc := mock.BatchConnectAlloc()
  1355  	task := alloc.Job.TaskGroups[0].Tasks[0]
  1356  	task.Config = map[string]interface{}{
  1357  		"run_for": "0s",
  1358  	}
  1359  
  1360  	trConfig, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
  1361  	defer cleanup()
  1362  
  1363  	// set a consul token on the Nomad client's consul config, because that is
  1364  	// what gates the action of requesting SI token(s)
  1365  	trConfig.ClientConfig.ConsulConfig.Token = uuid.Generate()
  1366  
  1367  	// control when we get a Consul SI token
  1368  	token := uuid.Generate()
  1369  	waitCh := make(chan struct{})
  1370  	deriveFn := func(*structs.Allocation, []string) (map[string]string, error) {
  1371  		<-waitCh
  1372  		return map[string]string{task.Name: token}, nil
  1373  	}
  1374  	siClient := trConfig.ConsulSI.(*consulapi.MockServiceIdentitiesClient)
  1375  	siClient.DeriveTokenFn = deriveFn
  1376  
  1377  	// start the task runner
  1378  	tr, err := NewTaskRunner(trConfig)
  1379  	r.NoError(err)
  1380  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
  1381  	useMockEnvoyBootstrapHook(tr) // mock the envoy bootstrap hook
  1382  
  1383  	go tr.Run()
  1384  
  1385  	// assert task runner blocks on SI token
  1386  	select {
  1387  	case <-tr.WaitCh():
  1388  		r.Fail("task_runner exited before si unblocked")
  1389  	case <-time.After(100 * time.Millisecond):
  1390  	}
  1391  
  1392  	// assert task state is still pending
  1393  	r.Equal(structs.TaskStatePending, tr.TaskState().State)
  1394  
  1395  	// unblock service identity token
  1396  	close(waitCh)
  1397  
  1398  	// task runner should exit now that it has been unblocked and it is a batch
  1399  	// job with a zero sleep time
  1400  	testWaitForTaskToDie(t, tr)
  1401  
  1402  	// assert task exited successfully
  1403  	finalState := tr.TaskState()
  1404  	r.Equal(structs.TaskStateDead, finalState.State)
  1405  	r.False(finalState.Failed)
  1406  
  1407  	// assert the token is on disk
  1408  	tokenPath := filepath.Join(trConfig.TaskDir.SecretsDir, sidsTokenFile)
  1409  	data, err := ioutil.ReadFile(tokenPath)
  1410  	r.NoError(err)
  1411  	r.Equal(token, string(data))
  1412  }
  1413  
  1414  func TestTaskRunner_DeriveSIToken_Retry(t *testing.T) {
  1415  	ci.Parallel(t)
  1416  	r := require.New(t)
  1417  
  1418  	alloc := mock.BatchConnectAlloc()
  1419  	task := alloc.Job.TaskGroups[0].Tasks[0]
  1420  	task.Config = map[string]interface{}{
  1421  		"run_for": "0s",
  1422  	}
  1423  
  1424  	trConfig, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
  1425  	defer cleanup()
  1426  
  1427  	// set a consul token on the Nomad client's consul config, because that is
  1428  	// what gates the action of requesting SI token(s)
  1429  	trConfig.ClientConfig.ConsulConfig.Token = uuid.Generate()
  1430  
  1431  	// control when we get a Consul SI token (recoverable failure on first call)
  1432  	token := uuid.Generate()
  1433  	deriveCount := 0
  1434  	deriveFn := func(*structs.Allocation, []string) (map[string]string, error) {
  1435  		if deriveCount > 0 {
  1436  
  1437  			return map[string]string{task.Name: token}, nil
  1438  		}
  1439  		deriveCount++
  1440  		return nil, structs.NewRecoverableError(errors.New("try again later"), true)
  1441  	}
  1442  	siClient := trConfig.ConsulSI.(*consulapi.MockServiceIdentitiesClient)
  1443  	siClient.DeriveTokenFn = deriveFn
  1444  
  1445  	// start the task runner
  1446  	tr, err := NewTaskRunner(trConfig)
  1447  	r.NoError(err)
  1448  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
  1449  	useMockEnvoyBootstrapHook(tr) // mock the envoy bootstrap
  1450  	go tr.Run()
  1451  
  1452  	// assert task runner blocks on SI token
  1453  	testWaitForTaskToDie(t, tr)
  1454  
  1455  	// assert task exited successfully
  1456  	finalState := tr.TaskState()
  1457  	r.Equal(structs.TaskStateDead, finalState.State)
  1458  	r.False(finalState.Failed)
  1459  
  1460  	// assert the token is on disk
  1461  	tokenPath := filepath.Join(trConfig.TaskDir.SecretsDir, sidsTokenFile)
  1462  	data, err := ioutil.ReadFile(tokenPath)
  1463  	r.NoError(err)
  1464  	r.Equal(token, string(data))
  1465  }
  1466  
  1467  // TestTaskRunner_DeriveSIToken_Unrecoverable asserts that an unrecoverable error
  1468  // from deriving a service identity token will fail a task.
  1469  func TestTaskRunner_DeriveSIToken_Unrecoverable(t *testing.T) {
  1470  	ci.Parallel(t)
  1471  	r := require.New(t)
  1472  
  1473  	alloc := mock.BatchConnectAlloc()
  1474  	tg := alloc.Job.TaskGroups[0]
  1475  	tg.RestartPolicy.Attempts = 0
  1476  	tg.RestartPolicy.Interval = 0
  1477  	tg.RestartPolicy.Delay = 0
  1478  	tg.RestartPolicy.Mode = structs.RestartPolicyModeFail
  1479  	task := tg.Tasks[0]
  1480  	task.Config = map[string]interface{}{
  1481  		"run_for": "0s",
  1482  	}
  1483  
  1484  	trConfig, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
  1485  	defer cleanup()
  1486  
  1487  	// set a consul token on the Nomad client's consul config, because that is
  1488  	// what gates the action of requesting SI token(s)
  1489  	trConfig.ClientConfig.ConsulConfig.Token = uuid.Generate()
  1490  
  1491  	// SI token derivation suffers a non-retryable error
  1492  	siClient := trConfig.ConsulSI.(*consulapi.MockServiceIdentitiesClient)
  1493  	siClient.SetDeriveTokenError(alloc.ID, []string{task.Name}, errors.New("non-recoverable"))
  1494  
  1495  	tr, err := NewTaskRunner(trConfig)
  1496  	r.NoError(err)
  1497  
  1498  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
  1499  	useMockEnvoyBootstrapHook(tr) // mock the envoy bootstrap hook
  1500  	go tr.Run()
  1501  
  1502  	// Wait for the task to die
  1503  	select {
  1504  	case <-tr.WaitCh():
  1505  	case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
  1506  		require.Fail(t, "timed out waiting for task runner to fail")
  1507  	}
  1508  
  1509  	// assert we have died and failed
  1510  	finalState := tr.TaskState()
  1511  	r.Equal(structs.TaskStateDead, finalState.State)
  1512  	r.True(finalState.Failed)
  1513  	r.Equal(5, len(finalState.Events))
  1514  	/*
  1515  	 + event: Task received by client
  1516  	 + event: Building Task Directory
  1517  	 + event: consul: failed to derive SI token: non-recoverable
  1518  	 + event: consul_sids: context canceled
  1519  	 + event: Policy allows no restarts
  1520  	*/
  1521  	r.Equal("true", finalState.Events[2].Details["fails_task"])
  1522  }
  1523  
  1524  // TestTaskRunner_BlockForVaultToken asserts tasks do not start until a vault token
  1525  // is derived.
  1526  func TestTaskRunner_BlockForVaultToken(t *testing.T) {
  1527  	ci.Parallel(t)
  1528  
  1529  	alloc := mock.BatchAlloc()
  1530  	task := alloc.Job.TaskGroups[0].Tasks[0]
  1531  	task.Config = map[string]interface{}{
  1532  		"run_for": "0s",
  1533  	}
  1534  	task.Vault = &structs.Vault{Policies: []string{"default"}}
  1535  
  1536  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
  1537  	defer cleanup()
  1538  
  1539  	// Control when we get a Vault token
  1540  	token := "1234"
  1541  	waitCh := make(chan struct{})
  1542  	handler := func(*structs.Allocation, []string) (map[string]string, error) {
  1543  		<-waitCh
  1544  		return map[string]string{task.Name: token}, nil
  1545  	}
  1546  	vaultClient := conf.Vault.(*vaultclient.MockVaultClient)
  1547  	vaultClient.DeriveTokenFn = handler
  1548  
  1549  	tr, err := NewTaskRunner(conf)
  1550  	require.NoError(t, err)
  1551  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
  1552  	go tr.Run()
  1553  
  1554  	// Assert TR blocks on vault token (does *not* exit)
  1555  	select {
  1556  	case <-tr.WaitCh():
  1557  		require.Fail(t, "tr exited before vault unblocked")
  1558  	case <-time.After(1 * time.Second):
  1559  	}
  1560  
  1561  	// Assert task state is still Pending
  1562  	require.Equal(t, structs.TaskStatePending, tr.TaskState().State)
  1563  
  1564  	// Unblock vault token
  1565  	close(waitCh)
  1566  
  1567  	// TR should exit now that it's unblocked by vault as its a batch job
  1568  	// with 0 sleeping.
  1569  	testWaitForTaskToDie(t, tr)
  1570  
  1571  	// Assert task exited successfully
  1572  	finalState := tr.TaskState()
  1573  	require.Equal(t, structs.TaskStateDead, finalState.State)
  1574  	require.False(t, finalState.Failed)
  1575  
  1576  	// Check that the token is on disk
  1577  	tokenPath := filepath.Join(conf.TaskDir.SecretsDir, vaultTokenFile)
  1578  	data, err := ioutil.ReadFile(tokenPath)
  1579  	require.NoError(t, err)
  1580  	require.Equal(t, token, string(data))
  1581  
  1582  	// Kill task runner to trigger stop hooks
  1583  	tr.Kill(context.Background(), structs.NewTaskEvent("kill"))
  1584  	select {
  1585  	case <-tr.WaitCh():
  1586  	case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
  1587  		require.Fail(t, "timed out waiting for task runner to exit")
  1588  	}
  1589  
  1590  	// Check the token was revoked
  1591  	testutil.WaitForResult(func() (bool, error) {
  1592  		if len(vaultClient.StoppedTokens()) != 1 {
  1593  			return false, fmt.Errorf("Expected a stopped token %q but found: %v", token, vaultClient.StoppedTokens())
  1594  		}
  1595  
  1596  		if a := vaultClient.StoppedTokens()[0]; a != token {
  1597  			return false, fmt.Errorf("got stopped token %q; want %q", a, token)
  1598  		}
  1599  		return true, nil
  1600  	}, func(err error) {
  1601  		require.Fail(t, err.Error())
  1602  	})
  1603  }
  1604  
  1605  // TestTaskRunner_DeriveToken_Retry asserts that if a recoverable error is
  1606  // returned when deriving a vault token a task will continue to block while
  1607  // it's retried.
  1608  func TestTaskRunner_DeriveToken_Retry(t *testing.T) {
  1609  	ci.Parallel(t)
  1610  	alloc := mock.BatchAlloc()
  1611  	task := alloc.Job.TaskGroups[0].Tasks[0]
  1612  	task.Vault = &structs.Vault{Policies: []string{"default"}}
  1613  
  1614  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
  1615  	defer cleanup()
  1616  
  1617  	// Fail on the first attempt to derive a vault token
  1618  	token := "1234"
  1619  	count := 0
  1620  	handler := func(*structs.Allocation, []string) (map[string]string, error) {
  1621  		if count > 0 {
  1622  			return map[string]string{task.Name: token}, nil
  1623  		}
  1624  
  1625  		count++
  1626  		return nil, structs.NewRecoverableError(fmt.Errorf("Want a retry"), true)
  1627  	}
  1628  	vaultClient := conf.Vault.(*vaultclient.MockVaultClient)
  1629  	vaultClient.DeriveTokenFn = handler
  1630  
  1631  	tr, err := NewTaskRunner(conf)
  1632  	require.NoError(t, err)
  1633  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
  1634  	go tr.Run()
  1635  
  1636  	// Wait for TR to die and check its state
  1637  	testWaitForTaskToDie(t, tr)
  1638  
  1639  	state := tr.TaskState()
  1640  	require.Equal(t, structs.TaskStateDead, state.State)
  1641  	require.False(t, state.Failed)
  1642  
  1643  	// Kill task runner to trigger stop hooks
  1644  	tr.Kill(context.Background(), structs.NewTaskEvent("kill"))
  1645  	select {
  1646  	case <-tr.WaitCh():
  1647  	case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
  1648  		require.Fail(t, "timed out waiting for task runner to exit")
  1649  	}
  1650  
  1651  	require.Equal(t, 1, count)
  1652  
  1653  	// Check that the token is on disk
  1654  	tokenPath := filepath.Join(conf.TaskDir.SecretsDir, vaultTokenFile)
  1655  	data, err := ioutil.ReadFile(tokenPath)
  1656  	require.NoError(t, err)
  1657  	require.Equal(t, token, string(data))
  1658  
  1659  	// Check the token was revoked
  1660  	testutil.WaitForResult(func() (bool, error) {
  1661  		if len(vaultClient.StoppedTokens()) != 1 {
  1662  			return false, fmt.Errorf("Expected a stopped token: %v", vaultClient.StoppedTokens())
  1663  		}
  1664  
  1665  		if a := vaultClient.StoppedTokens()[0]; a != token {
  1666  			return false, fmt.Errorf("got stopped token %q; want %q", a, token)
  1667  		}
  1668  		return true, nil
  1669  	}, func(err error) {
  1670  		require.Fail(t, err.Error())
  1671  	})
  1672  }
  1673  
  1674  // TestTaskRunner_DeriveToken_Unrecoverable asserts that an unrecoverable error
  1675  // from deriving a vault token will fail a task.
  1676  func TestTaskRunner_DeriveToken_Unrecoverable(t *testing.T) {
  1677  	ci.Parallel(t)
  1678  
  1679  	// Use a batch job with no restarts
  1680  	alloc := mock.BatchAlloc()
  1681  	tg := alloc.Job.TaskGroups[0]
  1682  	tg.RestartPolicy.Attempts = 0
  1683  	tg.RestartPolicy.Interval = 0
  1684  	tg.RestartPolicy.Delay = 0
  1685  	tg.RestartPolicy.Mode = structs.RestartPolicyModeFail
  1686  	task := tg.Tasks[0]
  1687  	task.Config = map[string]interface{}{
  1688  		"run_for": "0s",
  1689  	}
  1690  	task.Vault = &structs.Vault{Policies: []string{"default"}}
  1691  
  1692  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
  1693  	defer cleanup()
  1694  
  1695  	// Error the token derivation
  1696  	vaultClient := conf.Vault.(*vaultclient.MockVaultClient)
  1697  	vaultClient.SetDeriveTokenError(alloc.ID, []string{task.Name}, fmt.Errorf("Non recoverable"))
  1698  
  1699  	tr, err := NewTaskRunner(conf)
  1700  	require.NoError(t, err)
  1701  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
  1702  	go tr.Run()
  1703  
  1704  	// Wait for the task to die
  1705  	select {
  1706  	case <-tr.WaitCh():
  1707  	case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
  1708  		require.Fail(t, "timed out waiting for task runner to fail")
  1709  	}
  1710  
  1711  	// Task should be dead and last event should have failed task
  1712  	state := tr.TaskState()
  1713  	require.Equal(t, structs.TaskStateDead, state.State)
  1714  	require.True(t, state.Failed)
  1715  	require.Len(t, state.Events, 3)
  1716  	require.True(t, state.Events[2].FailsTask)
  1717  }
  1718  
  1719  // TestTaskRunner_Download_RawExec asserts that downloaded artifacts may be
  1720  // executed in a driver without filesystem isolation.
  1721  func TestTaskRunner_Download_RawExec(t *testing.T) {
  1722  	ci.Parallel(t)
  1723  
  1724  	ts := httptest.NewServer(http.FileServer(http.Dir(filepath.Dir("."))))
  1725  	defer ts.Close()
  1726  
  1727  	// Create a task that downloads a script and executes it.
  1728  	alloc := mock.BatchAlloc()
  1729  	alloc.Job.TaskGroups[0].RestartPolicy = &structs.RestartPolicy{}
  1730  
  1731  	task := alloc.Job.TaskGroups[0].Tasks[0]
  1732  	task.RestartPolicy = &structs.RestartPolicy{}
  1733  	task.Driver = "raw_exec"
  1734  	task.Config = map[string]interface{}{
  1735  		"command": "noop.sh",
  1736  	}
  1737  	task.Env = map[string]string{
  1738  		"NOMAD_PARENT_CGROUP": "nomad.slice",
  1739  		"NOMAD_ALLOC_ID":      alloc.ID,
  1740  		"NOMAD_TASK_NAME":     task.Name,
  1741  	}
  1742  	task.Artifacts = []*structs.TaskArtifact{
  1743  		{
  1744  			GetterSource: fmt.Sprintf("%s/testdata/noop.sh", ts.URL),
  1745  			GetterMode:   "file",
  1746  			RelativeDest: "noop.sh",
  1747  		},
  1748  	}
  1749  
  1750  	tr, _, cleanup := runTestTaskRunner(t, alloc, task.Name)
  1751  	defer cleanup()
  1752  
  1753  	// Wait for task to run and exit
  1754  	testWaitForTaskToDie(t, tr)
  1755  
  1756  	state := tr.TaskState()
  1757  	require.Equal(t, structs.TaskStateDead, state.State)
  1758  	require.False(t, state.Failed)
  1759  }
  1760  
  1761  // TestTaskRunner_Download_List asserts that multiple artificats are downloaded
  1762  // before a task is run.
  1763  func TestTaskRunner_Download_List(t *testing.T) {
  1764  	ci.Parallel(t)
  1765  	ts := httptest.NewServer(http.FileServer(http.Dir(filepath.Dir("."))))
  1766  	defer ts.Close()
  1767  
  1768  	// Create an allocation that has a task with a list of artifacts.
  1769  	alloc := mock.BatchAlloc()
  1770  	task := alloc.Job.TaskGroups[0].Tasks[0]
  1771  	f1 := "task_runner_test.go"
  1772  	f2 := "task_runner.go"
  1773  	artifact1 := structs.TaskArtifact{
  1774  		GetterSource: fmt.Sprintf("%s/%s", ts.URL, f1),
  1775  	}
  1776  	artifact2 := structs.TaskArtifact{
  1777  		GetterSource: fmt.Sprintf("%s/%s", ts.URL, f2),
  1778  	}
  1779  	task.Artifacts = []*structs.TaskArtifact{&artifact1, &artifact2}
  1780  
  1781  	tr, conf, cleanup := runTestTaskRunner(t, alloc, task.Name)
  1782  	defer cleanup()
  1783  
  1784  	// Wait for task to run and exit
  1785  	testWaitForTaskToDie(t, tr)
  1786  
  1787  	state := tr.TaskState()
  1788  	require.Equal(t, structs.TaskStateDead, state.State)
  1789  	require.False(t, state.Failed)
  1790  
  1791  	require.Len(t, state.Events, 5)
  1792  	assert.Equal(t, structs.TaskReceived, state.Events[0].Type)
  1793  	assert.Equal(t, structs.TaskSetup, state.Events[1].Type)
  1794  	assert.Equal(t, structs.TaskDownloadingArtifacts, state.Events[2].Type)
  1795  	assert.Equal(t, structs.TaskStarted, state.Events[3].Type)
  1796  	assert.Equal(t, structs.TaskTerminated, state.Events[4].Type)
  1797  
  1798  	// Check that both files exist.
  1799  	_, err := os.Stat(filepath.Join(conf.TaskDir.Dir, f1))
  1800  	require.NoErrorf(t, err, "%v not downloaded", f1)
  1801  
  1802  	_, err = os.Stat(filepath.Join(conf.TaskDir.Dir, f2))
  1803  	require.NoErrorf(t, err, "%v not downloaded", f2)
  1804  }
  1805  
  1806  // TestTaskRunner_Download_Retries asserts that failed artifact downloads are
  1807  // retried according to the task's restart policy.
  1808  func TestTaskRunner_Download_Retries(t *testing.T) {
  1809  	ci.Parallel(t)
  1810  
  1811  	// Create an allocation that has a task with bad artifacts.
  1812  	alloc := mock.BatchAlloc()
  1813  	task := alloc.Job.TaskGroups[0].Tasks[0]
  1814  	artifact := structs.TaskArtifact{
  1815  		GetterSource: "http://127.0.0.1:0/foo/bar/baz",
  1816  	}
  1817  	task.Artifacts = []*structs.TaskArtifact{&artifact}
  1818  
  1819  	// Make the restart policy retry once
  1820  	rp := &structs.RestartPolicy{
  1821  		Attempts: 1,
  1822  		Interval: 10 * time.Minute,
  1823  		Delay:    1 * time.Second,
  1824  		Mode:     structs.RestartPolicyModeFail,
  1825  	}
  1826  	alloc.Job.TaskGroups[0].RestartPolicy = rp
  1827  	alloc.Job.TaskGroups[0].Tasks[0].RestartPolicy = rp
  1828  
  1829  	tr, _, cleanup := runTestTaskRunner(t, alloc, task.Name)
  1830  	defer cleanup()
  1831  
  1832  	testWaitForTaskToDie(t, tr)
  1833  
  1834  	state := tr.TaskState()
  1835  	require.Equal(t, structs.TaskStateDead, state.State)
  1836  	require.True(t, state.Failed)
  1837  	require.Len(t, state.Events, 8, pretty.Sprint(state.Events))
  1838  	require.Equal(t, structs.TaskReceived, state.Events[0].Type)
  1839  	require.Equal(t, structs.TaskSetup, state.Events[1].Type)
  1840  	require.Equal(t, structs.TaskDownloadingArtifacts, state.Events[2].Type)
  1841  	require.Equal(t, structs.TaskArtifactDownloadFailed, state.Events[3].Type)
  1842  	require.Equal(t, structs.TaskRestarting, state.Events[4].Type)
  1843  	require.Equal(t, structs.TaskDownloadingArtifacts, state.Events[5].Type)
  1844  	require.Equal(t, structs.TaskArtifactDownloadFailed, state.Events[6].Type)
  1845  	require.Equal(t, structs.TaskNotRestarting, state.Events[7].Type)
  1846  }
  1847  
  1848  // TestTaskRunner_DriverNetwork asserts that a driver's network is properly
  1849  // used in services and checks.
  1850  func TestTaskRunner_DriverNetwork(t *testing.T) {
  1851  	ci.Parallel(t)
  1852  
  1853  	alloc := mock.Alloc()
  1854  	task := alloc.Job.TaskGroups[0].Tasks[0]
  1855  	task.Driver = "mock_driver"
  1856  	task.Config = map[string]interface{}{
  1857  		"run_for":         "100s",
  1858  		"driver_ip":       "10.1.2.3",
  1859  		"driver_port_map": "http:80",
  1860  	}
  1861  
  1862  	// Create services and checks with custom address modes to exercise
  1863  	// address detection logic
  1864  	task.Services = []*structs.Service{
  1865  		{
  1866  			Name:        "host-service",
  1867  			PortLabel:   "http",
  1868  			AddressMode: "host",
  1869  			Provider:    structs.ServiceProviderConsul,
  1870  			Checks: []*structs.ServiceCheck{
  1871  				{
  1872  					Name:        "driver-check",
  1873  					Type:        "tcp",
  1874  					PortLabel:   "1234",
  1875  					AddressMode: "driver",
  1876  				},
  1877  			},
  1878  		},
  1879  		{
  1880  			Name:        "driver-service",
  1881  			PortLabel:   "5678",
  1882  			AddressMode: "driver",
  1883  			Provider:    structs.ServiceProviderConsul,
  1884  			Checks: []*structs.ServiceCheck{
  1885  				{
  1886  					Name:      "host-check",
  1887  					Type:      "tcp",
  1888  					PortLabel: "http",
  1889  				},
  1890  				{
  1891  					Name:        "driver-label-check",
  1892  					Type:        "tcp",
  1893  					PortLabel:   "http",
  1894  					AddressMode: "driver",
  1895  				},
  1896  			},
  1897  		},
  1898  	}
  1899  
  1900  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
  1901  	defer cleanup()
  1902  
  1903  	// Use a mock agent to test for services
  1904  	consulAgent := agentconsul.NewMockAgent(agentconsul.Features{
  1905  		Enterprise: false,
  1906  		Namespaces: false,
  1907  	})
  1908  	namespacesClient := agentconsul.NewNamespacesClient(agentconsul.NewMockNamespaces(nil), consulAgent)
  1909  	consulClient := agentconsul.NewServiceClient(consulAgent, namespacesClient, conf.Logger, true)
  1910  	defer consulClient.Shutdown()
  1911  	go consulClient.Run()
  1912  
  1913  	conf.Consul = consulClient
  1914  	conf.ServiceRegWrapper = wrapper.NewHandlerWrapper(conf.Logger, consulClient, nil)
  1915  
  1916  	tr, err := NewTaskRunner(conf)
  1917  	require.NoError(t, err)
  1918  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
  1919  	go tr.Run()
  1920  
  1921  	// Wait for the task to start
  1922  	testWaitForTaskToStart(t, tr)
  1923  
  1924  	testutil.WaitForResult(func() (bool, error) {
  1925  		services, _ := consulAgent.ServicesWithFilterOpts("", nil)
  1926  		if n := len(services); n != 2 {
  1927  			return false, fmt.Errorf("expected 2 services, but found %d", n)
  1928  		}
  1929  		for _, s := range services {
  1930  			switch s.Service {
  1931  			case "host-service":
  1932  				if expected := "192.168.0.100"; s.Address != expected {
  1933  					return false, fmt.Errorf("expected host-service to have IP=%s but found %s",
  1934  						expected, s.Address)
  1935  				}
  1936  			case "driver-service":
  1937  				if expected := "10.1.2.3"; s.Address != expected {
  1938  					return false, fmt.Errorf("expected driver-service to have IP=%s but found %s",
  1939  						expected, s.Address)
  1940  				}
  1941  				if expected := 5678; s.Port != expected {
  1942  					return false, fmt.Errorf("expected driver-service to have port=%d but found %d",
  1943  						expected, s.Port)
  1944  				}
  1945  			default:
  1946  				return false, fmt.Errorf("unexpected service: %q", s.Service)
  1947  			}
  1948  
  1949  		}
  1950  
  1951  		checks := consulAgent.CheckRegs()
  1952  		if n := len(checks); n != 3 {
  1953  			return false, fmt.Errorf("expected 3 checks, but found %d", n)
  1954  		}
  1955  		for _, check := range checks {
  1956  			switch check.Name {
  1957  			case "driver-check":
  1958  				if expected := "10.1.2.3:1234"; check.TCP != expected {
  1959  					return false, fmt.Errorf("expected driver-check to have address %q but found %q", expected, check.TCP)
  1960  				}
  1961  			case "driver-label-check":
  1962  				if expected := "10.1.2.3:80"; check.TCP != expected {
  1963  					return false, fmt.Errorf("expected driver-label-check to have address %q but found %q", expected, check.TCP)
  1964  				}
  1965  			case "host-check":
  1966  				if expected := "192.168.0.100:"; !strings.HasPrefix(check.TCP, expected) {
  1967  					return false, fmt.Errorf("expected host-check to have address start with %q but found %q", expected, check.TCP)
  1968  				}
  1969  			default:
  1970  				return false, fmt.Errorf("unexpected check: %q", check.Name)
  1971  			}
  1972  		}
  1973  
  1974  		return true, nil
  1975  	}, func(err error) {
  1976  		services, _ := consulAgent.ServicesWithFilterOpts("", nil)
  1977  		for _, s := range services {
  1978  			t.Logf(pretty.Sprint("Service: ", s))
  1979  		}
  1980  		for _, c := range consulAgent.CheckRegs() {
  1981  			t.Logf(pretty.Sprint("Check:   ", c))
  1982  		}
  1983  		require.NoError(t, err)
  1984  	})
  1985  }
  1986  
// TestTaskRunner_RestartSignalTask_NotRunning asserts resilience to failures
// when a restart or signal is triggered and the task is not running.
//
// The task is held in the pending state by blocking Vault token derivation.
// While it is pending, Signal and Restart must both return ErrTaskNotRunning.
// Once derivation is unblocked, the task must run to completion without
// having been restarted.
func TestTaskRunner_RestartSignalTask_NotRunning(t *testing.T) {
	ci.Parallel(t)

	alloc := mock.BatchAlloc()
	task := alloc.Job.TaskGroups[0].Tasks[0]
	task.Driver = "mock_driver"
	task.Config = map[string]interface{}{
		"run_for": "0s",
	}

	// Use vault to block the start
	task.Vault = &structs.Vault{Policies: []string{"default"}}

	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
	defer cleanup()

	// Control when we get a Vault token: the derive handler blocks until the
	// test sends on waitCh. Closing waitCh on return also unblocks the handler
	// if it is still waiting.
	waitCh := make(chan struct{}, 1)
	defer close(waitCh)
	handler := func(*structs.Allocation, []string) (map[string]string, error) {
		<-waitCh
		return map[string]string{task.Name: "1234"}, nil
	}
	vaultClient := conf.Vault.(*vaultclient.MockVaultClient)
	vaultClient.DeriveTokenFn = handler

	tr, err := NewTaskRunner(conf)
	require.NoError(t, err)
	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
	go tr.Run()

	// The runner must not exit while token derivation is blocked; give it a
	// second to prove that.
	select {
	case <-tr.WaitCh():
		require.Fail(t, "unexpected exit")
	case <-time.After(1 * time.Second):
	}

	require.Equal(t, structs.TaskStatePending, tr.TaskState().State)

	// Send a signal
	err = tr.Signal(structs.NewTaskEvent("don't panic"), "QUIT")
	require.EqualError(t, err, ErrTaskNotRunning.Error())

	// Send a restart
	err = tr.Restart(context.Background(), structs.NewTaskEvent("don't panic"), false)
	require.EqualError(t, err, ErrTaskNotRunning.Error())

	// Unblock and let it finish
	waitCh <- struct{}{}
	testWaitForTaskToDie(t, tr)

	// Assert the task ran and never restarted: the event list is exactly
	// received -> setup -> started -> terminated, with no restart in between.
	state := tr.TaskState()
	require.Equal(t, structs.TaskStateDead, state.State)
	require.False(t, state.Failed)
	require.Len(t, state.Events, 4, pretty.Sprint(state.Events))
	require.Equal(t, structs.TaskReceived, state.Events[0].Type)
	require.Equal(t, structs.TaskSetup, state.Events[1].Type)
	require.Equal(t, structs.TaskStarted, state.Events[2].Type)
	require.Equal(t, structs.TaskTerminated, state.Events[3].Type)
}
  2050  
  2051  // TestTaskRunner_Run_RecoverableStartError asserts tasks are restarted if they
  2052  // return a recoverable error from StartTask.
  2053  func TestTaskRunner_Run_RecoverableStartError(t *testing.T) {
  2054  	ci.Parallel(t)
  2055  
  2056  	alloc := mock.BatchAlloc()
  2057  	task := alloc.Job.TaskGroups[0].Tasks[0]
  2058  	task.Config = map[string]interface{}{
  2059  		"start_error":             "driver failure",
  2060  		"start_error_recoverable": true,
  2061  	}
  2062  
  2063  	// Make the restart policy retry once
  2064  	rp := &structs.RestartPolicy{
  2065  		Attempts: 1,
  2066  		Interval: 10 * time.Minute,
  2067  		Delay:    0,
  2068  		Mode:     structs.RestartPolicyModeFail,
  2069  	}
  2070  	alloc.Job.TaskGroups[0].RestartPolicy = rp
  2071  	alloc.Job.TaskGroups[0].Tasks[0].RestartPolicy = rp
  2072  
  2073  	tr, _, cleanup := runTestTaskRunner(t, alloc, task.Name)
  2074  	defer cleanup()
  2075  
  2076  	testWaitForTaskToDie(t, tr)
  2077  
  2078  	state := tr.TaskState()
  2079  	require.Equal(t, structs.TaskStateDead, state.State)
  2080  	require.True(t, state.Failed)
  2081  	require.Len(t, state.Events, 6, pretty.Sprint(state.Events))
  2082  	require.Equal(t, structs.TaskReceived, state.Events[0].Type)
  2083  	require.Equal(t, structs.TaskSetup, state.Events[1].Type)
  2084  	require.Equal(t, structs.TaskDriverFailure, state.Events[2].Type)
  2085  	require.Equal(t, structs.TaskRestarting, state.Events[3].Type)
  2086  	require.Equal(t, structs.TaskDriverFailure, state.Events[4].Type)
  2087  	require.Equal(t, structs.TaskNotRestarting, state.Events[5].Type)
  2088  }
  2089  
  2090  // TestTaskRunner_Template_Artifact asserts that tasks can use artifacts as templates.
  2091  func TestTaskRunner_Template_Artifact(t *testing.T) {
  2092  	ci.Parallel(t)
  2093  
  2094  	ts := httptest.NewServer(http.FileServer(http.Dir(".")))
  2095  	defer ts.Close()
  2096  
  2097  	alloc := mock.BatchAlloc()
  2098  	task := alloc.Job.TaskGroups[0].Tasks[0]
  2099  	f1 := "task_runner.go"
  2100  	f2 := "test"
  2101  	task.Artifacts = []*structs.TaskArtifact{
  2102  		{GetterSource: fmt.Sprintf("%s/%s", ts.URL, f1)},
  2103  	}
  2104  	task.Templates = []*structs.Template{
  2105  		{
  2106  			SourcePath: f1,
  2107  			DestPath:   "local/test",
  2108  			ChangeMode: structs.TemplateChangeModeNoop,
  2109  		},
  2110  	}
  2111  
  2112  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
  2113  	defer cleanup()
  2114  
  2115  	tr, err := NewTaskRunner(conf)
  2116  	require.NoError(t, err)
  2117  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
  2118  	go tr.Run()
  2119  
  2120  	// Wait for task to run and exit
  2121  	testWaitForTaskToDie(t, tr)
  2122  
  2123  	state := tr.TaskState()
  2124  	require.Equal(t, structs.TaskStateDead, state.State)
  2125  	require.True(t, state.Successful())
  2126  	require.False(t, state.Failed)
  2127  
  2128  	artifactsDownloaded := false
  2129  	for _, e := range state.Events {
  2130  		if e.Type == structs.TaskDownloadingArtifacts {
  2131  			artifactsDownloaded = true
  2132  		}
  2133  	}
  2134  	assert.True(t, artifactsDownloaded, "expected artifacts downloaded events")
  2135  
  2136  	// Check that both files exist.
  2137  	_, err = os.Stat(filepath.Join(conf.TaskDir.Dir, f1))
  2138  	require.NoErrorf(t, err, "%v not downloaded", f1)
  2139  
  2140  	_, err = os.Stat(filepath.Join(conf.TaskDir.LocalDir, f2))
  2141  	require.NoErrorf(t, err, "%v not rendered", f2)
  2142  }
  2143  
  2144  // TestTaskRunner_Template_BlockingPreStart asserts that a template
  2145  // that fails to render in PreStart can gracefully be shutdown by
  2146  // either killCtx or shutdownCtx
  2147  func TestTaskRunner_Template_BlockingPreStart(t *testing.T) {
  2148  	ci.Parallel(t)
  2149  
  2150  	alloc := mock.BatchAlloc()
  2151  	task := alloc.Job.TaskGroups[0].Tasks[0]
  2152  	task.Templates = []*structs.Template{
  2153  		{
  2154  			EmbeddedTmpl: `{{ with secret "foo/secret" }}{{ .Data.certificate }}{{ end }}`,
  2155  			DestPath:     "local/test",
  2156  			ChangeMode:   structs.TemplateChangeModeNoop,
  2157  		},
  2158  	}
  2159  
  2160  	task.Vault = &structs.Vault{Policies: []string{"default"}}
  2161  
  2162  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
  2163  	defer cleanup()
  2164  
  2165  	tr, err := NewTaskRunner(conf)
  2166  	require.NoError(t, err)
  2167  	go tr.Run()
  2168  	defer tr.Shutdown()
  2169  
  2170  	testutil.WaitForResult(func() (bool, error) {
  2171  		ts := tr.TaskState()
  2172  
  2173  		if len(ts.Events) == 0 {
  2174  			return false, fmt.Errorf("no events yet")
  2175  		}
  2176  
  2177  		for _, e := range ts.Events {
  2178  			if e.Type == "Template" && strings.Contains(e.DisplayMessage, "vault.read(foo/secret)") {
  2179  				return true, nil
  2180  			}
  2181  		}
  2182  
  2183  		return false, fmt.Errorf("no missing vault secret template event yet: %#v", ts.Events)
  2184  
  2185  	}, func(err error) {
  2186  		require.NoError(t, err)
  2187  	})
  2188  
  2189  	shutdown := func() <-chan bool {
  2190  		finished := make(chan bool)
  2191  		go func() {
  2192  			tr.Shutdown()
  2193  			finished <- true
  2194  		}()
  2195  
  2196  		return finished
  2197  	}
  2198  
  2199  	select {
  2200  	case <-shutdown():
  2201  		// it shut down like it should have
  2202  	case <-time.After(10 * time.Second):
  2203  		require.Fail(t, "timeout shutting down task")
  2204  	}
  2205  }
  2206  
  2207  // TestTaskRunner_Template_NewVaultToken asserts that a new vault token is
  2208  // created when rendering template and that it is revoked on alloc completion
  2209  func TestTaskRunner_Template_NewVaultToken(t *testing.T) {
  2210  	ci.Parallel(t)
  2211  
  2212  	alloc := mock.BatchAlloc()
  2213  	task := alloc.Job.TaskGroups[0].Tasks[0]
  2214  	task.Templates = []*structs.Template{
  2215  		{
  2216  			EmbeddedTmpl: `{{key "foo"}}`,
  2217  			DestPath:     "local/test",
  2218  			ChangeMode:   structs.TemplateChangeModeNoop,
  2219  		},
  2220  	}
  2221  	task.Vault = &structs.Vault{Policies: []string{"default"}}
  2222  
  2223  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
  2224  	defer cleanup()
  2225  
  2226  	tr, err := NewTaskRunner(conf)
  2227  	require.NoError(t, err)
  2228  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
  2229  	go tr.Run()
  2230  
  2231  	// Wait for a Vault token
  2232  	var token string
  2233  	testutil.WaitForResult(func() (bool, error) {
  2234  		token = tr.getVaultToken()
  2235  
  2236  		if token == "" {
  2237  			return false, fmt.Errorf("No Vault token")
  2238  		}
  2239  
  2240  		return true, nil
  2241  	}, func(err error) {
  2242  		require.NoError(t, err)
  2243  	})
  2244  
  2245  	vault := conf.Vault.(*vaultclient.MockVaultClient)
  2246  	renewalCh, ok := vault.RenewTokens()[token]
  2247  	require.True(t, ok, "no renewal channel for token")
  2248  
  2249  	renewalCh <- fmt.Errorf("Test killing")
  2250  	close(renewalCh)
  2251  
  2252  	var token2 string
  2253  	testutil.WaitForResult(func() (bool, error) {
  2254  		token2 = tr.getVaultToken()
  2255  
  2256  		if token2 == "" {
  2257  			return false, fmt.Errorf("No Vault token")
  2258  		}
  2259  
  2260  		if token2 == token {
  2261  			return false, fmt.Errorf("token wasn't recreated")
  2262  		}
  2263  
  2264  		return true, nil
  2265  	}, func(err error) {
  2266  		require.NoError(t, err)
  2267  	})
  2268  
  2269  	// Check the token was revoked
  2270  	testutil.WaitForResult(func() (bool, error) {
  2271  		if len(vault.StoppedTokens()) != 1 {
  2272  			return false, fmt.Errorf("Expected a stopped token: %v", vault.StoppedTokens())
  2273  		}
  2274  
  2275  		if a := vault.StoppedTokens()[0]; a != token {
  2276  			return false, fmt.Errorf("got stopped token %q; want %q", a, token)
  2277  		}
  2278  
  2279  		return true, nil
  2280  	}, func(err error) {
  2281  		require.NoError(t, err)
  2282  	})
  2283  
  2284  }
  2285  
  2286  // TestTaskRunner_VaultManager_Restart asserts that the alloc is restarted when the alloc
  2287  // derived vault token expires, when task is configured with Restart change mode
  2288  func TestTaskRunner_VaultManager_Restart(t *testing.T) {
  2289  	ci.Parallel(t)
  2290  
  2291  	alloc := mock.BatchAlloc()
  2292  	task := alloc.Job.TaskGroups[0].Tasks[0]
  2293  	task.Config = map[string]interface{}{
  2294  		"run_for": "10s",
  2295  	}
  2296  	task.Vault = &structs.Vault{
  2297  		Policies:   []string{"default"},
  2298  		ChangeMode: structs.VaultChangeModeRestart,
  2299  	}
  2300  
  2301  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
  2302  	defer cleanup()
  2303  
  2304  	tr, err := NewTaskRunner(conf)
  2305  	require.NoError(t, err)
  2306  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
  2307  	go tr.Run()
  2308  
  2309  	testWaitForTaskToStart(t, tr)
  2310  
  2311  	tr.vaultTokenLock.Lock()
  2312  	token := tr.vaultToken
  2313  	tr.vaultTokenLock.Unlock()
  2314  
  2315  	require.NotEmpty(t, token)
  2316  
  2317  	vault := conf.Vault.(*vaultclient.MockVaultClient)
  2318  	renewalCh, ok := vault.RenewTokens()[token]
  2319  	require.True(t, ok, "no renewal channel for token")
  2320  
  2321  	renewalCh <- fmt.Errorf("Test killing")
  2322  	close(renewalCh)
  2323  
  2324  	testutil.WaitForResult(func() (bool, error) {
  2325  		state := tr.TaskState()
  2326  
  2327  		if len(state.Events) == 0 {
  2328  			return false, fmt.Errorf("no events yet")
  2329  		}
  2330  
  2331  		foundRestartSignal, foundRestarting := false, false
  2332  		for _, e := range state.Events {
  2333  			switch e.Type {
  2334  			case structs.TaskRestartSignal:
  2335  				foundRestartSignal = true
  2336  			case structs.TaskRestarting:
  2337  				foundRestarting = true
  2338  			}
  2339  		}
  2340  
  2341  		if !foundRestartSignal {
  2342  			return false, fmt.Errorf("no restart signal event yet: %#v", state.Events)
  2343  		}
  2344  
  2345  		if !foundRestarting {
  2346  			return false, fmt.Errorf("no restarting event yet: %#v", state.Events)
  2347  		}
  2348  
  2349  		lastEvent := state.Events[len(state.Events)-1]
  2350  		if lastEvent.Type != structs.TaskStarted {
  2351  			return false, fmt.Errorf("expected last event to be task starting but was %#v", lastEvent)
  2352  		}
  2353  		return true, nil
  2354  	}, func(err error) {
  2355  		require.NoError(t, err)
  2356  	})
  2357  }
  2358  
  2359  // TestTaskRunner_VaultManager_Signal asserts that the alloc is signalled when the alloc
  2360  // derived vault token expires, when task is configured with signal change mode
  2361  func TestTaskRunner_VaultManager_Signal(t *testing.T) {
  2362  	ci.Parallel(t)
  2363  
  2364  	alloc := mock.BatchAlloc()
  2365  	task := alloc.Job.TaskGroups[0].Tasks[0]
  2366  	task.Config = map[string]interface{}{
  2367  		"run_for": "10s",
  2368  	}
  2369  	task.Vault = &structs.Vault{
  2370  		Policies:     []string{"default"},
  2371  		ChangeMode:   structs.VaultChangeModeSignal,
  2372  		ChangeSignal: "SIGUSR1",
  2373  	}
  2374  
  2375  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
  2376  	defer cleanup()
  2377  
  2378  	tr, err := NewTaskRunner(conf)
  2379  	require.NoError(t, err)
  2380  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
  2381  	go tr.Run()
  2382  
  2383  	testWaitForTaskToStart(t, tr)
  2384  
  2385  	tr.vaultTokenLock.Lock()
  2386  	token := tr.vaultToken
  2387  	tr.vaultTokenLock.Unlock()
  2388  
  2389  	require.NotEmpty(t, token)
  2390  
  2391  	vault := conf.Vault.(*vaultclient.MockVaultClient)
  2392  	renewalCh, ok := vault.RenewTokens()[token]
  2393  	require.True(t, ok, "no renewal channel for token")
  2394  
  2395  	renewalCh <- fmt.Errorf("Test killing")
  2396  	close(renewalCh)
  2397  
  2398  	testutil.WaitForResult(func() (bool, error) {
  2399  		state := tr.TaskState()
  2400  
  2401  		if len(state.Events) == 0 {
  2402  			return false, fmt.Errorf("no events yet")
  2403  		}
  2404  
  2405  		foundSignaling := false
  2406  		for _, e := range state.Events {
  2407  			if e.Type == structs.TaskSignaling {
  2408  				foundSignaling = true
  2409  			}
  2410  		}
  2411  
  2412  		if !foundSignaling {
  2413  			return false, fmt.Errorf("no signaling event yet: %#v", state.Events)
  2414  		}
  2415  
  2416  		return true, nil
  2417  	}, func(err error) {
  2418  		require.NoError(t, err)
  2419  	})
  2420  
  2421  }
  2422  
  2423  // TestTaskRunner_UnregisterConsul_Retries asserts a task is unregistered from
  2424  // Consul when waiting to be retried.
  2425  func TestTaskRunner_UnregisterConsul_Retries(t *testing.T) {
  2426  	ci.Parallel(t)
  2427  
  2428  	alloc := mock.Alloc()
  2429  	// Make the restart policy try one ctx.update
  2430  	rp := &structs.RestartPolicy{
  2431  		Attempts: 1,
  2432  		Interval: 10 * time.Minute,
  2433  		Delay:    time.Nanosecond,
  2434  		Mode:     structs.RestartPolicyModeFail,
  2435  	}
  2436  	alloc.Job.TaskGroups[0].RestartPolicy = rp
  2437  	task := alloc.Job.TaskGroups[0].Tasks[0]
  2438  	task.RestartPolicy = rp
  2439  	task.Driver = "mock_driver"
  2440  	task.Config = map[string]interface{}{
  2441  		"exit_code": "1",
  2442  		"run_for":   "1ns",
  2443  	}
  2444  
  2445  	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
  2446  	defer cleanup()
  2447  
  2448  	tr, err := NewTaskRunner(conf)
  2449  	require.NoError(t, err)
  2450  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
  2451  	go tr.Run()
  2452  
  2453  	testWaitForTaskToDie(t, tr)
  2454  
  2455  	state := tr.TaskState()
  2456  	require.Equal(t, structs.TaskStateDead, state.State)
  2457  
  2458  	consul := conf.Consul.(*regMock.ServiceRegistrationHandler)
  2459  	consulOps := consul.GetOps()
  2460  	require.Len(t, consulOps, 4)
  2461  
  2462  	// Initial add
  2463  	require.Equal(t, "add", consulOps[0].Op)
  2464  
  2465  	// Removing entries on first exit
  2466  	require.Equal(t, "remove", consulOps[1].Op)
  2467  
  2468  	// Second add on retry
  2469  	require.Equal(t, "add", consulOps[2].Op)
  2470  
  2471  	// Removing entries on retry
  2472  	require.Equal(t, "remove", consulOps[3].Op)
  2473  }
  2474  
  2475  // testWaitForTaskToStart waits for the task to be running or fails the test
  2476  func testWaitForTaskToStart(t *testing.T, tr *TaskRunner) {
  2477  	testutil.WaitForResult(func() (bool, error) {
  2478  		ts := tr.TaskState()
  2479  		return ts.State == structs.TaskStateRunning, fmt.Errorf("expected task to be running, got %v", ts.State)
  2480  	}, func(err error) {
  2481  		require.NoError(t, err)
  2482  	})
  2483  }
  2484  
  2485  // testWaitForTaskToDie waits for the task to die or fails the test
  2486  func testWaitForTaskToDie(t *testing.T, tr *TaskRunner) {
  2487  	testutil.WaitForResult(func() (bool, error) {
  2488  		ts := tr.TaskState()
  2489  		return ts.State == structs.TaskStateDead, fmt.Errorf("expected task to be dead, got %v", ts.State)
  2490  	}, func(err error) {
  2491  		require.NoError(t, err)
  2492  	})
  2493  }
  2494  
  2495  // TestTaskRunner_BaseLabels tests that the base labels for the task metrics
  2496  // are set appropriately.
  2497  func TestTaskRunner_BaseLabels(t *testing.T) {
  2498  	ci.Parallel(t)
  2499  	require := require.New(t)
  2500  
  2501  	alloc := mock.BatchAlloc()
  2502  	alloc.Namespace = "not-default"
  2503  	task := alloc.Job.TaskGroups[0].Tasks[0]
  2504  	task.Driver = "raw_exec"
  2505  	task.Config = map[string]interface{}{
  2506  		"command": "whoami",
  2507  	}
  2508  
  2509  	config, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
  2510  	defer cleanup()
  2511  
  2512  	tr, err := NewTaskRunner(config)
  2513  	require.NoError(err)
  2514  
  2515  	labels := map[string]string{}
  2516  	for _, e := range tr.baseLabels {
  2517  		labels[e.Name] = e.Value
  2518  	}
  2519  	require.Equal(alloc.Job.Name, labels["job"])
  2520  	require.Equal(alloc.TaskGroup, labels["task_group"])
  2521  	require.Equal(task.Name, labels["task"])
  2522  	require.Equal(alloc.ID, labels["alloc_id"])
  2523  	require.Equal(alloc.Namespace, labels["namespace"])
  2524  }