github.com/anth0d/nomad@v0.0.0-20221214183521-ae3a0a2cad06/client/allocrunner/taskrunner/sids_hook_test.go (about)

     1  //go:build !windows
     2  // +build !windows
     3  
     4  // todo(shoenig): Once Connect is supported on Windows, we'll need to make this
     5  //  set of tests work there too.
     6  
     7  package taskrunner
     8  
     9  import (
    10  	"context"
    11  	"io/ioutil"
    12  	"os"
    13  	"path/filepath"
    14  	"testing"
    15  	"time"
    16  
    17  	"github.com/hashicorp/nomad/ci"
    18  	"github.com/hashicorp/nomad/client/allocrunner/interfaces"
    19  	consulapi "github.com/hashicorp/nomad/client/consul"
    20  	"github.com/hashicorp/nomad/helper"
    21  	"github.com/hashicorp/nomad/helper/testlog"
    22  	"github.com/hashicorp/nomad/helper/uuid"
    23  	"github.com/hashicorp/nomad/nomad/mock"
    24  	"github.com/hashicorp/nomad/nomad/structs"
    25  	"github.com/stretchr/testify/require"
    26  	"golang.org/x/sys/unix"
    27  )
    28  
// Compile-time assertion that *sidsHook satisfies interfaces.TaskPrestartHook.
var _ interfaces.TaskPrestartHook = (*sidsHook)(nil)
    30  
    31  func sidecar(task string) (string, structs.TaskKind) {
    32  	name := structs.ConnectProxyPrefix + "-" + task
    33  	kind := structs.TaskKind(structs.ConnectProxyPrefix + ":" + task)
    34  	return name, kind
    35  }
    36  
    37  func TestSIDSHook_recoverToken(t *testing.T) {
    38  	ci.Parallel(t)
    39  	r := require.New(t)
    40  
    41  	secrets := t.TempDir()
    42  
    43  	taskName, taskKind := sidecar("foo")
    44  	h := newSIDSHook(sidsHookConfig{
    45  		task: &structs.Task{
    46  			Name: taskName,
    47  			Kind: taskKind,
    48  		},
    49  		logger: testlog.HCLogger(t),
    50  	})
    51  
    52  	expected := uuid.Generate()
    53  	err := h.writeToken(secrets, expected)
    54  	r.NoError(err)
    55  
    56  	token, err := h.recoverToken(secrets)
    57  	r.NoError(err)
    58  	r.Equal(expected, token)
    59  }
    60  
    61  func TestSIDSHook_recoverToken_empty(t *testing.T) {
    62  	ci.Parallel(t)
    63  	r := require.New(t)
    64  
    65  	secrets := t.TempDir()
    66  
    67  	taskName, taskKind := sidecar("foo")
    68  	h := newSIDSHook(sidsHookConfig{
    69  		task: &structs.Task{
    70  			Name: taskName,
    71  			Kind: taskKind,
    72  		},
    73  		logger: testlog.HCLogger(t),
    74  	})
    75  
    76  	token, err := h.recoverToken(secrets)
    77  	r.NoError(err)
    78  	r.Empty(token)
    79  }
    80  
    81  func TestSIDSHook_recoverToken_unReadable(t *testing.T) {
    82  	ci.Parallel(t)
    83  	// This test fails when running as root because the test case for checking
    84  	// the error condition when the file is unreadable fails (root can read the
    85  	// file even though the permissions are set to 0200).
    86  	if unix.Geteuid() == 0 {
    87  		t.Skip("test only works as non-root")
    88  	}
    89  
    90  	r := require.New(t)
    91  
    92  	secrets := t.TempDir()
    93  
    94  	err := os.Chmod(secrets, 0000)
    95  	r.NoError(err)
    96  
    97  	taskName, taskKind := sidecar("foo")
    98  	h := newSIDSHook(sidsHookConfig{
    99  		task: &structs.Task{
   100  			Name: taskName,
   101  			Kind: taskKind,
   102  		},
   103  		logger: testlog.HCLogger(t),
   104  	})
   105  
   106  	_, err = h.recoverToken(secrets)
   107  	r.Error(err)
   108  }
   109  
   110  func TestSIDSHook_writeToken(t *testing.T) {
   111  	ci.Parallel(t)
   112  	r := require.New(t)
   113  
   114  	secrets := t.TempDir()
   115  
   116  	id := uuid.Generate()
   117  	h := new(sidsHook)
   118  	err := h.writeToken(secrets, id)
   119  	r.NoError(err)
   120  
   121  	content, err := ioutil.ReadFile(filepath.Join(secrets, sidsTokenFile))
   122  	r.NoError(err)
   123  	r.Equal(id, string(content))
   124  }
   125  
   126  func TestSIDSHook_writeToken_unWritable(t *testing.T) {
   127  	ci.Parallel(t)
   128  	// This test fails when running as root because the test case for checking
   129  	// the error condition when the file is unreadable fails (root can read the
   130  	// file even though the permissions are set to 0200).
   131  	if unix.Geteuid() == 0 {
   132  		t.Skip("test only works as non-root")
   133  	}
   134  
   135  	r := require.New(t)
   136  
   137  	secrets := t.TempDir()
   138  
   139  	err := os.Chmod(secrets, 0000)
   140  	r.NoError(err)
   141  
   142  	id := uuid.Generate()
   143  	h := new(sidsHook)
   144  	err = h.writeToken(secrets, id)
   145  	r.Error(err)
   146  }
   147  
   148  func Test_SIDSHook_writeToken_nonExistent(t *testing.T) {
   149  	ci.Parallel(t)
   150  	r := require.New(t)
   151  
   152  	base := t.TempDir()
   153  	secrets := filepath.Join(base, "does/not/exist")
   154  
   155  	id := uuid.Generate()
   156  	h := new(sidsHook)
   157  	err := h.writeToken(secrets, id)
   158  	r.Error(err)
   159  }
   160  
   161  func TestSIDSHook_deriveSIToken(t *testing.T) {
   162  	ci.Parallel(t)
   163  	r := require.New(t)
   164  
   165  	taskName, taskKind := sidecar("task1")
   166  	h := newSIDSHook(sidsHookConfig{
   167  		alloc: &structs.Allocation{ID: "a1"},
   168  		task: &structs.Task{
   169  			Name: taskName,
   170  			Kind: taskKind,
   171  		},
   172  		logger:     testlog.HCLogger(t),
   173  		sidsClient: consulapi.NewMockServiceIdentitiesClient(),
   174  	})
   175  
   176  	ctx := context.Background()
   177  	token, err := h.deriveSIToken(ctx)
   178  	r.NoError(err)
   179  	r.True(helper.IsUUID(token), "token: %q", token)
   180  }
   181  
   182  func TestSIDSHook_deriveSIToken_timeout(t *testing.T) {
   183  	ci.Parallel(t)
   184  	r := require.New(t)
   185  
   186  	siClient := consulapi.NewMockServiceIdentitiesClient()
   187  	siClient.DeriveTokenFn = func(allocation *structs.Allocation, strings []string) (m map[string]string, err error) {
   188  		select {
   189  		// block forever, hopefully triggering a timeout in the caller
   190  		}
   191  	}
   192  
   193  	taskName, taskKind := sidecar("task1")
   194  	h := newSIDSHook(sidsHookConfig{
   195  		alloc: &structs.Allocation{ID: "a1"},
   196  		task: &structs.Task{
   197  			Name: taskName,
   198  			Kind: taskKind,
   199  		},
   200  		logger:     testlog.HCLogger(t),
   201  		sidsClient: siClient,
   202  	})
   203  
   204  	// set the timeout to a really small value for testing
   205  	h.derivationTimeout = time.Duration(1 * time.Millisecond)
   206  
   207  	ctx := context.Background()
   208  	_, err := h.deriveSIToken(ctx)
   209  	r.EqualError(err, "context deadline exceeded")
   210  }
   211  
   212  func TestSIDSHook_computeBackoff(t *testing.T) {
   213  	ci.Parallel(t)
   214  
   215  	try := func(i int, exp time.Duration) {
   216  		result := computeBackoff(i)
   217  		require.Equal(t, exp, result)
   218  	}
   219  
   220  	try(0, time.Duration(0))
   221  	try(1, 100*time.Millisecond)
   222  	try(2, 10*time.Second)
   223  	try(3, 15*time.Second)
   224  	try(4, 20*time.Second)
   225  	try(5, 25*time.Second)
   226  }
   227  
   228  func TestSIDSHook_backoff(t *testing.T) {
   229  	ci.Parallel(t)
   230  	r := require.New(t)
   231  
   232  	ctx := context.Background()
   233  	stop := !backoff(ctx, 0)
   234  	r.False(stop)
   235  }
   236  
   237  func TestSIDSHook_backoffKilled(t *testing.T) {
   238  	ci.Parallel(t)
   239  	r := require.New(t)
   240  
   241  	ctx, cancel := context.WithTimeout(context.Background(), 1)
   242  	defer cancel()
   243  
   244  	stop := !backoff(ctx, 1000)
   245  	r.True(stop)
   246  }
   247  
   248  func TestTaskRunner_DeriveSIToken_UnWritableTokenFile(t *testing.T) {
   249  	ci.Parallel(t)
   250  	// Normally this test would live in test_runner_test.go, but since it requires
   251  	// root and the check for root doesn't like Windows, we put this file in here
   252  	// for now.
   253  
   254  	// This test fails when running as root because the test case for checking
   255  	// the error condition when the file is unreadable fails (root can read the
   256  	// file even though the permissions are set to 0200).
   257  	if unix.Geteuid() == 0 {
   258  		t.Skip("test only works as non-root")
   259  	}
   260  
   261  	r := require.New(t)
   262  
   263  	alloc := mock.BatchConnectAlloc()
   264  	task := alloc.Job.TaskGroups[0].Tasks[0]
   265  	task.Config = map[string]interface{}{
   266  		"run_for": "0s",
   267  	}
   268  
   269  	trConfig, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
   270  	defer cleanup()
   271  
   272  	// make the si_token file un-writable, triggering a failure after a
   273  	// successful token derivation
   274  	secrets := t.TempDir()
   275  	trConfig.TaskDir.SecretsDir = secrets
   276  	err := ioutil.WriteFile(filepath.Join(secrets, sidsTokenFile), nil, 0400)
   277  	r.NoError(err)
   278  
   279  	// set a consul token for the nomad client, which is what triggers the
   280  	// SIDS hook to be applied
   281  	trConfig.ClientConfig.ConsulConfig.Token = uuid.Generate()
   282  
   283  	// derive token works just fine
   284  	deriveFn := func(*structs.Allocation, []string) (map[string]string, error) {
   285  		return map[string]string{task.Name: uuid.Generate()}, nil
   286  	}
   287  	siClient := trConfig.ConsulSI.(*consulapi.MockServiceIdentitiesClient)
   288  	siClient.DeriveTokenFn = deriveFn
   289  
   290  	// start the task runner
   291  	tr, err := NewTaskRunner(trConfig)
   292  	r.NoError(err)
   293  	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
   294  	useMockEnvoyBootstrapHook(tr) // mock the envoy bootstrap
   295  
   296  	go tr.Run()
   297  
   298  	// wait for task runner to finish running
   299  	testWaitForTaskToDie(t, tr)
   300  
   301  	// assert task exited un-successfully
   302  	finalState := tr.TaskState()
   303  	r.Equal(structs.TaskStateDead, finalState.State)
   304  	r.True(finalState.Failed) // should have failed to write SI token
   305  	r.Contains(finalState.Events[2].DisplayMessage, "failed to write SI token")
   306  
   307  	// assert the token is *not* on disk, as secrets dir was un-writable
   308  	tokenPath := filepath.Join(trConfig.TaskDir.SecretsDir, sidsTokenFile)
   309  	token, err := ioutil.ReadFile(tokenPath)
   310  	r.NoError(err)
   311  	r.Empty(token)
   312  }