github.com/adityamillind98/nomad@v0.11.8/command/agent/consul/check_watcher_test.go (about)

     1  package consul
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"sync"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/hashicorp/consul/api"
    11  	"github.com/hashicorp/nomad/helper/testlog"
    12  	"github.com/hashicorp/nomad/nomad/structs"
    13  	"github.com/hashicorp/nomad/testutil"
    14  	"github.com/stretchr/testify/require"
    15  )
    16  
    17  // checkRestartRecord is used by a testFakeCtx to record when restarts occur
    18  // due to a watched check.
    19  type checkRestartRecord struct {
    20  	timestamp time.Time
    21  	source    string
    22  	reason    string
    23  	failure   bool
    24  }
    25  
    26  // fakeCheckRestarter is a test implementation of TaskRestarter.
    27  type fakeCheckRestarter struct {
    28  	// restarts is a slice of all of the restarts triggered by the checkWatcher
    29  	restarts []checkRestartRecord
    30  
    31  	// need the checkWatcher to re-Watch restarted tasks like TaskRunner
    32  	watcher *checkWatcher
    33  
    34  	// check to re-Watch on restarts
    35  	check     *structs.ServiceCheck
    36  	allocID   string
    37  	taskName  string
    38  	checkName string
    39  
    40  	mu sync.Mutex
    41  }
    42  
    43  // newFakeCheckRestart creates a new TaskRestarter. It needs all of the
    44  // parameters checkWatcher.Watch expects.
    45  func newFakeCheckRestarter(w *checkWatcher, allocID, taskName, checkName string, c *structs.ServiceCheck) *fakeCheckRestarter {
    46  	return &fakeCheckRestarter{
    47  		watcher:   w,
    48  		check:     c,
    49  		allocID:   allocID,
    50  		taskName:  taskName,
    51  		checkName: checkName,
    52  	}
    53  }
    54  
    55  // Restart implements part of the TaskRestarter interface needed for check
    56  // watching and is normally fulfilled by a TaskRunner.
    57  //
    58  // Restarts are recorded in the []restarts field and re-Watch the check.
    59  //func (c *fakeCheckRestarter) Restart(source, reason string, failure bool) {
    60  func (c *fakeCheckRestarter) Restart(ctx context.Context, event *structs.TaskEvent, failure bool) error {
    61  	c.mu.Lock()
    62  	defer c.mu.Unlock()
    63  	restart := checkRestartRecord{
    64  		timestamp: time.Now(),
    65  		source:    event.Type,
    66  		reason:    event.DisplayMessage,
    67  		failure:   failure,
    68  	}
    69  	c.restarts = append(c.restarts, restart)
    70  
    71  	// Re-Watch the check just like TaskRunner
    72  	c.watcher.Watch(c.allocID, c.taskName, c.checkName, c.check, c)
    73  	return nil
    74  }
    75  
    76  // String for debugging
    77  func (c *fakeCheckRestarter) String() string {
    78  	c.mu.Lock()
    79  	defer c.mu.Unlock()
    80  
    81  	s := fmt.Sprintf("%s %s %s restarts:\n", c.allocID, c.taskName, c.checkName)
    82  	for _, r := range c.restarts {
    83  		s += fmt.Sprintf("%s - %s: %s (failure: %t)\n", r.timestamp, r.source, r.reason, r.failure)
    84  	}
    85  	return s
    86  }
    87  
    88  // GetRestarts for testing in a threadsafe way
    89  func (c *fakeCheckRestarter) GetRestarts() []checkRestartRecord {
    90  	c.mu.Lock()
    91  	defer c.mu.Unlock()
    92  
    93  	o := make([]checkRestartRecord, len(c.restarts))
    94  	copy(o, c.restarts)
    95  	return o
    96  }
    97  
    98  // checkResponse is a response returned by the fakeChecksAPI after the given
    99  // time.
   100  type checkResponse struct {
   101  	at     time.Time
   102  	id     string
   103  	status string
   104  }
   105  
   106  // fakeChecksAPI implements the Checks() method for testing Consul.
   107  type fakeChecksAPI struct {
   108  	// responses is a map of check ids to their status at a particular
   109  	// time. checkResponses must be in chronological order.
   110  	responses map[string][]checkResponse
   111  
   112  	mu sync.Mutex
   113  }
   114  
   115  func newFakeChecksAPI() *fakeChecksAPI {
   116  	return &fakeChecksAPI{responses: make(map[string][]checkResponse)}
   117  }
   118  
   119  // add a new check status to Consul at the given time.
   120  func (c *fakeChecksAPI) add(id, status string, at time.Time) {
   121  	c.mu.Lock()
   122  	c.responses[id] = append(c.responses[id], checkResponse{at, id, status})
   123  	c.mu.Unlock()
   124  }
   125  
   126  func (c *fakeChecksAPI) Checks() (map[string]*api.AgentCheck, error) {
   127  	c.mu.Lock()
   128  	defer c.mu.Unlock()
   129  	now := time.Now()
   130  	result := make(map[string]*api.AgentCheck, len(c.responses))
   131  
   132  	// Use the latest response for each check
   133  	for k, vs := range c.responses {
   134  		for _, v := range vs {
   135  			if v.at.After(now) {
   136  				break
   137  			}
   138  			result[k] = &api.AgentCheck{
   139  				CheckID: k,
   140  				Name:    k,
   141  				Status:  v.status,
   142  			}
   143  		}
   144  	}
   145  
   146  	return result, nil
   147  }
   148  
   149  // testWatcherSetup sets up a fakeChecksAPI and a real checkWatcher with a test
   150  // logger and faster poll frequency.
   151  func testWatcherSetup(t *testing.T) (*fakeChecksAPI, *checkWatcher) {
   152  	fakeAPI := newFakeChecksAPI()
   153  	cw := newCheckWatcher(testlog.HCLogger(t), fakeAPI)
   154  	cw.pollFreq = 10 * time.Millisecond
   155  	return fakeAPI, cw
   156  }
   157  
   158  func testCheck() *structs.ServiceCheck {
   159  	return &structs.ServiceCheck{
   160  		Name:     "testcheck",
   161  		Interval: 100 * time.Millisecond,
   162  		Timeout:  100 * time.Millisecond,
   163  		CheckRestart: &structs.CheckRestart{
   164  			Limit:          3,
   165  			Grace:          100 * time.Millisecond,
   166  			IgnoreWarnings: false,
   167  		},
   168  	}
   169  }
   170  
   171  // TestCheckWatcher_Skip asserts unwatched checks are ignored.
   172  func TestCheckWatcher_Skip(t *testing.T) {
   173  	t.Parallel()
   174  
   175  	// Create a check with restarting disabled
   176  	check := testCheck()
   177  	check.CheckRestart = nil
   178  
   179  	cw := newCheckWatcher(testlog.HCLogger(t), newFakeChecksAPI())
   180  	restarter1 := newFakeCheckRestarter(cw, "testalloc1", "testtask1", "testcheck1", check)
   181  	cw.Watch("testalloc1", "testtask1", "testcheck1", check, restarter1)
   182  
   183  	// Check should have been dropped as it's not watched
   184  	if n := len(cw.checkUpdateCh); n != 0 {
   185  		t.Fatalf("expected 0 checks to be enqueued for watching but found %d", n)
   186  	}
   187  }
   188  
   189  // TestCheckWatcher_Healthy asserts healthy tasks are not restarted.
   190  func TestCheckWatcher_Healthy(t *testing.T) {
   191  	t.Parallel()
   192  
   193  	fakeAPI, cw := testWatcherSetup(t)
   194  
   195  	check1 := testCheck()
   196  	restarter1 := newFakeCheckRestarter(cw, "testalloc1", "testtask1", "testcheck1", check1)
   197  	cw.Watch("testalloc1", "testtask1", "testcheck1", check1, restarter1)
   198  
   199  	check2 := testCheck()
   200  	check2.CheckRestart.Limit = 1
   201  	check2.CheckRestart.Grace = 0
   202  	restarter2 := newFakeCheckRestarter(cw, "testalloc2", "testtask2", "testcheck2", check2)
   203  	cw.Watch("testalloc2", "testtask2", "testcheck2", check2, restarter2)
   204  
   205  	// Make both checks healthy from the beginning
   206  	fakeAPI.add("testcheck1", "passing", time.Time{})
   207  	fakeAPI.add("testcheck2", "passing", time.Time{})
   208  
   209  	// Run
   210  	ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
   211  	defer cancel()
   212  	cw.Run(ctx)
   213  
   214  	// Ensure restart was never called
   215  	if n := len(restarter1.restarts); n > 0 {
   216  		t.Errorf("expected check 1 to not be restarted but found %d:\n%s", n, restarter1)
   217  	}
   218  	if n := len(restarter2.restarts); n > 0 {
   219  		t.Errorf("expected check 2 to not be restarted but found %d:\n%s", n, restarter2)
   220  	}
   221  }
   222  
   223  // TestCheckWatcher_Unhealthy asserts unhealthy tasks are restarted exactly once.
   224  func TestCheckWatcher_Unhealthy(t *testing.T) {
   225  	t.Parallel()
   226  
   227  	fakeAPI, cw := testWatcherSetup(t)
   228  
   229  	check1 := testCheck()
   230  	restarter1 := newFakeCheckRestarter(cw, "testalloc1", "testtask1", "testcheck1", check1)
   231  	cw.Watch("testalloc1", "testtask1", "testcheck1", check1, restarter1)
   232  
   233  	// Check has always been failing
   234  	fakeAPI.add("testcheck1", "critical", time.Time{})
   235  
   236  	// Run
   237  	ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
   238  	defer cancel()
   239  	cw.Run(ctx)
   240  
   241  	// Ensure restart was called exactly once
   242  	require.Len(t, restarter1.restarts, 1)
   243  }
   244  
   245  // TestCheckWatcher_HealthyWarning asserts checks in warning with
   246  // ignore_warnings=true do not restart tasks.
   247  func TestCheckWatcher_HealthyWarning(t *testing.T) {
   248  	t.Parallel()
   249  
   250  	fakeAPI, cw := testWatcherSetup(t)
   251  
   252  	check1 := testCheck()
   253  	check1.CheckRestart.Limit = 1
   254  	check1.CheckRestart.Grace = 0
   255  	check1.CheckRestart.IgnoreWarnings = true
   256  	restarter1 := newFakeCheckRestarter(cw, "testalloc1", "testtask1", "testcheck1", check1)
   257  	cw.Watch("testalloc1", "testtask1", "testcheck1", check1, restarter1)
   258  
   259  	// Check is always in warning but that's ok
   260  	fakeAPI.add("testcheck1", "warning", time.Time{})
   261  
   262  	// Run
   263  	ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
   264  	defer cancel()
   265  	cw.Run(ctx)
   266  
   267  	// Ensure restart was never called on check 1
   268  	if n := len(restarter1.restarts); n > 0 {
   269  		t.Errorf("expected check 1 to not be restarted but found %d", n)
   270  	}
   271  }
   272  
   273  // TestCheckWatcher_Flapping asserts checks that flap from healthy to unhealthy
   274  // before the unhealthy limit is reached do not restart tasks.
   275  func TestCheckWatcher_Flapping(t *testing.T) {
   276  	t.Parallel()
   277  
   278  	fakeAPI, cw := testWatcherSetup(t)
   279  
   280  	check1 := testCheck()
   281  	check1.CheckRestart.Grace = 0
   282  	restarter1 := newFakeCheckRestarter(cw, "testalloc1", "testtask1", "testcheck1", check1)
   283  	cw.Watch("testalloc1", "testtask1", "testcheck1", check1, restarter1)
   284  
   285  	// Check flaps and is never failing for the full 200ms needed to restart
   286  	now := time.Now()
   287  	fakeAPI.add("testcheck1", "passing", now)
   288  	fakeAPI.add("testcheck1", "critical", now.Add(100*time.Millisecond))
   289  	fakeAPI.add("testcheck1", "passing", now.Add(250*time.Millisecond))
   290  	fakeAPI.add("testcheck1", "critical", now.Add(300*time.Millisecond))
   291  	fakeAPI.add("testcheck1", "passing", now.Add(450*time.Millisecond))
   292  
   293  	ctx, cancel := context.WithTimeout(context.Background(), 600*time.Millisecond)
   294  	defer cancel()
   295  	cw.Run(ctx)
   296  
   297  	// Ensure restart was never called on check 1
   298  	if n := len(restarter1.restarts); n > 0 {
   299  		t.Errorf("expected check 1 to not be restarted but found %d\n%s", n, restarter1)
   300  	}
   301  }
   302  
   303  // TestCheckWatcher_Unwatch asserts unwatching checks prevents restarts.
   304  func TestCheckWatcher_Unwatch(t *testing.T) {
   305  	t.Parallel()
   306  
   307  	fakeAPI, cw := testWatcherSetup(t)
   308  
   309  	// Unwatch immediately
   310  	check1 := testCheck()
   311  	check1.CheckRestart.Limit = 1
   312  	check1.CheckRestart.Grace = 100 * time.Millisecond
   313  	restarter1 := newFakeCheckRestarter(cw, "testalloc1", "testtask1", "testcheck1", check1)
   314  	cw.Watch("testalloc1", "testtask1", "testcheck1", check1, restarter1)
   315  	cw.Unwatch("testcheck1")
   316  
   317  	// Always failing
   318  	fakeAPI.add("testcheck1", "critical", time.Time{})
   319  
   320  	ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond)
   321  	defer cancel()
   322  	cw.Run(ctx)
   323  
   324  	// Ensure restart was never called on check 1
   325  	if n := len(restarter1.restarts); n > 0 {
   326  		t.Errorf("expected check 1 to not be restarted but found %d\n%s", n, restarter1)
   327  	}
   328  }
   329  
   330  // TestCheckWatcher_MultipleChecks asserts that when there are multiple checks
   331  // for a single task, all checks should be removed when any of them restart the
   332  // task to avoid multiple restarts.
   333  func TestCheckWatcher_MultipleChecks(t *testing.T) {
   334  	t.Parallel()
   335  
   336  	fakeAPI, cw := testWatcherSetup(t)
   337  
   338  	check1 := testCheck()
   339  	check1.CheckRestart.Limit = 1
   340  	restarter1 := newFakeCheckRestarter(cw, "testalloc1", "testtask1", "testcheck1", check1)
   341  	cw.Watch("testalloc1", "testtask1", "testcheck1", check1, restarter1)
   342  
   343  	check2 := testCheck()
   344  	check2.CheckRestart.Limit = 1
   345  	restarter2 := newFakeCheckRestarter(cw, "testalloc1", "testtask1", "testcheck2", check2)
   346  	cw.Watch("testalloc1", "testtask1", "testcheck2", check2, restarter2)
   347  
   348  	check3 := testCheck()
   349  	check3.CheckRestart.Limit = 1
   350  	restarter3 := newFakeCheckRestarter(cw, "testalloc1", "testtask1", "testcheck3", check3)
   351  	cw.Watch("testalloc1", "testtask1", "testcheck3", check3, restarter3)
   352  
   353  	// check 2 & 3 fail long enough to cause 1 restart, but only 1 should restart
   354  	now := time.Now()
   355  	fakeAPI.add("testcheck1", "critical", now)
   356  	fakeAPI.add("testcheck1", "passing", now.Add(150*time.Millisecond))
   357  	fakeAPI.add("testcheck2", "critical", now)
   358  	fakeAPI.add("testcheck2", "passing", now.Add(150*time.Millisecond))
   359  	fakeAPI.add("testcheck3", "passing", time.Time{})
   360  
   361  	// Run
   362  	ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
   363  	defer cancel()
   364  	cw.Run(ctx)
   365  
   366  	// Ensure that restart was only called once on check 1 or 2. Since
   367  	// checks are in a map it's random which check triggers the restart
   368  	// first.
   369  	if n := len(restarter1.restarts) + len(restarter2.restarts); n != 1 {
   370  		t.Errorf("expected check 1 & 2 to be restarted 1 time but found %d\ncheck 1:\n%s\ncheck 2:%s",
   371  			n, restarter1, restarter2)
   372  	}
   373  
   374  	if n := len(restarter3.restarts); n != 0 {
   375  		t.Errorf("expected check 3 to not be restarted but found %d:\n%s", n, restarter3)
   376  	}
   377  }
   378  
   379  // TestCheckWatcher_Deadlock asserts that check watcher will not deadlock when
   380  // attempting to restart a task even if its update queue is full.
   381  // https://github.com/hashicorp/nomad/issues/5395
   382  func TestCheckWatcher_Deadlock(t *testing.T) {
   383  	t.Parallel()
   384  
   385  	fakeAPI, cw := testWatcherSetup(t)
   386  
   387  	// If TR.Restart blocks, restarting len(checkUpdateCh)+1 checks causes
   388  	// a deadlock due to checkWatcher.Run being blocked in
   389  	// checkRestart.apply and unable to process updates from the chan!
   390  	n := cap(cw.checkUpdateCh) + 1
   391  	checks := make([]*structs.ServiceCheck, n)
   392  	restarters := make([]*fakeCheckRestarter, n)
   393  	for i := 0; i < n; i++ {
   394  		c := testCheck()
   395  		r := newFakeCheckRestarter(cw,
   396  			fmt.Sprintf("alloc%d", i),
   397  			fmt.Sprintf("task%d", i),
   398  			fmt.Sprintf("check%d", i),
   399  			c,
   400  		)
   401  		checks[i] = c
   402  		restarters[i] = r
   403  	}
   404  
   405  	// Run
   406  	ctx, cancel := context.WithCancel(context.Background())
   407  	defer cancel()
   408  	go cw.Run(ctx)
   409  
   410  	// Watch
   411  	for _, r := range restarters {
   412  		cw.Watch(r.allocID, r.taskName, r.checkName, r.check, r)
   413  	}
   414  
   415  	// Make them all fail
   416  	for _, r := range restarters {
   417  		fakeAPI.add(r.checkName, "critical", time.Time{})
   418  	}
   419  
   420  	// Ensure that restart was called exactly once on all checks
   421  	testutil.WaitForResult(func() (bool, error) {
   422  		for _, r := range restarters {
   423  			if n := len(r.GetRestarts()); n != 1 {
   424  				return false, fmt.Errorf("expected 1 restart but found %d", n)
   425  			}
   426  		}
   427  		return true, nil
   428  	}, func(err error) {
   429  		require.NoError(t, err)
   430  	})
   431  }