github.com/anth0d/nomad@v0.0.0-20221214183521-ae3a0a2cad06/client/serviceregistration/watcher_test.go

package serviceregistration

import (
	"context"
	"fmt"
	"sync"
	"testing"
	"time"

	"github.com/hashicorp/nomad/ci"
	"github.com/hashicorp/nomad/helper/testlog"
	"github.com/hashicorp/nomad/nomad/structs"
	"github.com/hashicorp/nomad/testutil"
	"github.com/shoenig/test/must"
)

// restartRecord is used by a fakeWorkloadRestarter to record when restarts occur
// due to a watched check.
type restartRecord struct {
	timestamp time.Time
	source    string
	reason    string
	failure   bool
}

// fakeWorkloadRestarter is a test implementation of TaskRestarter.
type fakeWorkloadRestarter struct {
	// restarts is a slice of all of the restarts triggered by the checkWatcher
	restarts []restartRecord

	// need the checkWatcher to re-Watch restarted tasks like TaskRunner
	watcher *UniversalCheckWatcher

	// check to re-Watch on restarts
	check     *structs.ServiceCheck
	allocID   string
	taskName  string
	checkName string

	lock sync.Mutex
}
// newFakeWorkloadRestarter creates a new fakeWorkloadRestarter.
func newFakeWorkloadRestarter(w *UniversalCheckWatcher, allocID, taskName, checkName string, c *structs.ServiceCheck) *fakeWorkloadRestarter {
	return &fakeWorkloadRestarter{
		watcher:   w,
		check:     c,
		allocID:   allocID,
		taskName:  taskName,
		checkName: checkName,
	}
}

// Restart implements part of the TaskRestarter interface needed for check
// watching and is normally fulfilled by a TaskRunner.
//
// Each restart is recorded in the restarts field, and the check is re-Watched.
func (c *fakeWorkloadRestarter) Restart(_ context.Context, event *structs.TaskEvent, failure bool) error {
	c.lock.Lock()
	defer c.lock.Unlock()

	restart := restartRecord{
		timestamp: time.Now(),
		source:    event.Type,
		reason:    event.DisplayMessage,
		failure:   failure,
	}
	c.restarts = append(c.restarts, restart)

	// Re-Watch the check just like TaskRunner
	c.watcher.Watch(c.allocID, c.taskName, c.checkName, c.check, c)
	return nil
}

// String is useful for debugging.
func (c *fakeWorkloadRestarter) String() string {
	c.lock.Lock()
	defer c.lock.Unlock()

	s := fmt.Sprintf("%s %s %s restarts:\n", c.allocID, c.taskName, c.checkName)
	for _, r := range c.restarts {
		s += fmt.Sprintf("%s - %s: %s (failure: %t)\n", r.timestamp, r.source, r.reason, r.failure)
	}
	return s
}

// GetRestarts returns a copy of the recorded restarts in a thread-safe way.
func (c *fakeWorkloadRestarter) GetRestarts() []restartRecord {
	c.lock.Lock()
	defer c.lock.Unlock()

	o := make([]restartRecord, len(c.restarts))
	copy(o, c.restarts)
	return o
}

// response is a check status that fakeCheckStatusGetter starts returning at a
// given time.
type response struct {
	at     time.Time
	id     string
	status string
}

// fakeCheckStatusGetter is a mock implementation of CheckStatusGetter
type fakeCheckStatusGetter struct {
	lock      sync.Mutex
	responses map[string][]response
}

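// Get implements CheckStatusGetter. For each check it returns the status of
// the newest recorded response whose timestamp is not in the future.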
func (g *fakeCheckStatusGetter) Get() (map[string]string, error) {
	g.lock.Lock()
	defer g.lock.Unlock()

	now := time.Now()
	result := make(map[string]string)
	// use the newest response at or before now; responses are in
	// chronological order, so later entries overwrite earlier ones
	for k, vs := range g.responses {
		for _, v := range vs {
			if v.at.After(now) {
				break
			}
			result[k] = v.status
		}
	}

	return result, nil
}

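// add records a status response for checkID that becomes visible to Get at
// time at. Responses for a check must be added in chronological order, since
// Get stops scanning at the first response that is still in the future.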
func (g *fakeCheckStatusGetter) add(checkID, status string, at time.Time) {
	g.lock.Lock()
	defer g.lock.Unlock()
	if g.responses == nil {
		g.responses = make(map[string][]response)
	}
	g.responses[checkID] = append(g.responses[checkID], response{at, checkID, status})
}

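// testCheck returns a ServiceCheck with a 100ms interval and timeout and
// restarting enabled (limit 3, 100ms grace).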
func testCheck() *structs.ServiceCheck {
	return &structs.ServiceCheck{
		Name:     "testcheck",
		Interval: 100 * time.Millisecond,
		Timeout:  100 * time.Millisecond,
		CheckRestart: &structs.CheckRestart{
			Limit:          3,
			Grace:          100 * time.Millisecond,
			IgnoreWarnings: false,
		},
	}
}

// testWatcherSetup creates a fakeCheckStatusGetter and a real
// UniversalCheckWatcher with a test logger and a faster poll frequency.
func testWatcherSetup(t *testing.T) (*fakeCheckStatusGetter, *UniversalCheckWatcher) {
	logger := testlog.HCLogger(t)
	getter := new(fakeCheckStatusGetter)
	cw := NewCheckWatcher(logger, getter)
	cw.pollFrequency = 10 * time.Millisecond
	return getter, cw
}

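// before returns a timestamp 10 seconds in the past, so responses recorded
// with it are visible to the getter from the start of a test.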
func before() time.Time {
	return time.Now().Add(-10 * time.Second)
}

// TestCheckWatcher_SkipUnwatched asserts unwatched checks are ignored.
func TestCheckWatcher_SkipUnwatched(t *testing.T) {
	ci.Parallel(t)

	// Create a check with restarting disabled
	check := testCheck()
	check.CheckRestart = nil

	logger := testlog.HCLogger(t)
	getter := new(fakeCheckStatusGetter)

	cw := NewCheckWatcher(logger, getter)
	restarter1 := newFakeWorkloadRestarter(cw, "testalloc1", "testtask1", "testcheck1", check)
	cw.Watch("testalloc1", "testtask1", "testcheck1", check, restarter1)

	// Check should have been dropped as it's not watched
	enqueued := len(cw.checkUpdateCh)
	must.Zero(t, enqueued, must.Sprintf("expected 0 checks to be enqueued for watching but found %d", enqueued))
}

// TestCheckWatcher_Healthy asserts healthy tasks are not restarted.
func TestCheckWatcher_Healthy(t *testing.T) {
	ci.Parallel(t)

	now := before()
	getter, cw := testWatcherSetup(t)

	// Make both checks healthy from the beginning
	getter.add("testcheck1", "passing", now)
	getter.add("testcheck2", "passing", now)

	check1 := testCheck()
	restarter1 := newFakeWorkloadRestarter(cw, "testalloc1", "testtask1", "testcheck1", check1)
	cw.Watch("testalloc1", "testtask1", "testcheck1", check1, restarter1)

	check2 := testCheck()
	check2.CheckRestart.Limit = 1
	check2.CheckRestart.Grace = 0
	restarter2 := newFakeWorkloadRestarter(cw, "testalloc2", "testtask2", "testcheck2", check2)
	cw.Watch("testalloc2", "testtask2", "testcheck2", check2, restarter2)

	// Run
	ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
	defer cancel()
	cw.Run(ctx)

	// Ensure restart was never called
	must.SliceEmpty(t, restarter1.restarts, must.Sprint("expected check 1 to not be restarted"))
	must.SliceEmpty(t, restarter2.restarts, must.Sprint("expected check 2 to not be restarted"))
}

// TestCheckWatcher_Unhealthy asserts unhealthy tasks are restarted exactly once.
func TestCheckWatcher_Unhealthy(t *testing.T) {
	ci.Parallel(t)

	now := before()
	getter, cw := testWatcherSetup(t)

	// Check has always been failing
	getter.add("testcheck1", "critical", now)

	check1 := testCheck()
	restarter1 := newFakeWorkloadRestarter(cw, "testalloc1", "testtask1", "testcheck1", check1)
	cw.Watch("testalloc1", "testtask1", "testcheck1", check1, restarter1)

	// Run
	ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
	defer cancel()
	cw.Run(ctx)

	// Ensure restart was called exactly once
	must.Len(t, 1, restarter1.restarts, must.Sprint("expected check to be restarted once"))
}

// TestCheckWatcher_HealthyWarning asserts checks in warning with
// ignore_warnings=true do not restart tasks.
func TestCheckWatcher_HealthyWarning(t *testing.T) {
	ci.Parallel(t)

	now := before()
	getter, cw := testWatcherSetup(t)

	// Check is always in warning but that's ok
	getter.add("testcheck1", "warning", now)

	check1 := testCheck()
	check1.CheckRestart.Limit = 1
	check1.CheckRestart.Grace = 0
	check1.CheckRestart.IgnoreWarnings = true
	restarter1 := newFakeWorkloadRestarter(cw, "testalloc1", "testtask1", "testcheck1", check1)
	cw.Watch("testalloc1", "testtask1", "testcheck1", check1, restarter1)

	// Run
	ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
	defer cancel()
	cw.Run(ctx)

	// Ensure restart was never called on check 1
	must.SliceEmpty(t, restarter1.restarts, must.Sprint("expected check 1 to not be restarted"))
}

// TestCheckWatcher_Flapping asserts checks that flap from healthy to unhealthy
// before the unhealthy limit is reached do not restart tasks.
func TestCheckWatcher_Flapping(t *testing.T) {
	ci.Parallel(t)

	getter, cw := testWatcherSetup(t)

	check1 := testCheck()
	check1.CheckRestart.Grace = 0
	restarter1 := newFakeWorkloadRestarter(cw, "testalloc1", "testtask1", "testcheck1", check1)
	cw.Watch("testalloc1", "testtask1", "testcheck1", check1, restarter1)

	// Check flaps and is never failing for the full 200ms needed to restart
	now := time.Now()
	getter.add("testcheck1", "passing", now)
	getter.add("testcheck1", "critical", now.Add(100*time.Millisecond))
	getter.add("testcheck1", "passing", now.Add(250*time.Millisecond))
	getter.add("testcheck1", "critical", now.Add(300*time.Millisecond))
	getter.add("testcheck1", "passing", now.Add(450*time.Millisecond))

	ctx, cancel := context.WithTimeout(context.Background(), 600*time.Millisecond)
	defer cancel()
	cw.Run(ctx)

	// Ensure restart was never called on check 1
	must.SliceEmpty(t, restarter1.restarts, must.Sprint("expected check 1 to not be restarted"))
}

// TestCheckWatcher_Unwatch asserts unwatching checks prevents restarts.
func TestCheckWatcher_Unwatch(t *testing.T) {
	ci.Parallel(t)

	now := before()
	getter, cw := testWatcherSetup(t)

	// Always failing
	getter.add("testcheck1", "critical", now)

	// Unwatch immediately
	check1 := testCheck()
	check1.CheckRestart.Limit = 1
	check1.CheckRestart.Grace = 100 * time.Millisecond
	restarter1 := newFakeWorkloadRestarter(cw, "testalloc1", "testtask1", "testcheck1", check1)
	cw.Watch("testalloc1", "testtask1", "testcheck1", check1, restarter1)
	cw.Unwatch("testcheck1")

	ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond)
	defer cancel()
	cw.Run(ctx)

	// Ensure restart was never called on check 1
	must.SliceEmpty(t, restarter1.restarts, must.Sprint("expected check 1 to not be restarted"))
}

// TestCheckWatcher_MultipleChecks asserts that when there are multiple checks
// for a single task, all checks should be removed when any of them restarts
// the task, to avoid multiple restarts.
func TestCheckWatcher_MultipleChecks(t *testing.T) {
	ci.Parallel(t)

	getter, cw := testWatcherSetup(t)

	// checks 1 and 2 are critical and then pass; check 3 is always passing.
	// There should be only 1 net restart.
	now := time.Now()
	getter.add("testcheck1", "critical", before())
	getter.add("testcheck1", "passing", now.Add(150*time.Millisecond))
	getter.add("testcheck2", "critical", before())
	getter.add("testcheck2", "passing", now.Add(150*time.Millisecond))
	getter.add("testcheck3", "passing", time.Time{})

	check1 := testCheck()
	check1.Name = "testcheck1"
	check1.CheckRestart.Limit = 1
	restarter1 := newFakeWorkloadRestarter(cw, "testalloc1", "testtask1", "testcheck1", check1)
	cw.Watch("testalloc1", "testtask1", "testcheck1", check1, restarter1)

	check2 := testCheck()
	check2.Name = "testcheck2"
	check2.CheckRestart.Limit = 1
	restarter2 := newFakeWorkloadRestarter(cw, "testalloc1", "testtask1", "testcheck2", check2)
	cw.Watch("testalloc1", "testtask1", "testcheck2", check2, restarter2)

	check3 := testCheck()
	check3.Name = "testcheck3"
	check3.CheckRestart.Limit = 1
	restarter3 := newFakeWorkloadRestarter(cw, "testalloc1", "testtask1", "testcheck3", check3)
	cw.Watch("testalloc1", "testtask1", "testcheck3", check3, restarter3)

	// Run
	ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
	defer cancel()
	cw.Run(ctx)

	// Ensure that restart was only called once on check 1 or 2. Since
	// checks are in a map it's random which check triggers the restart
	// first.
	if n := len(restarter1.restarts) + len(restarter2.restarts); n != 1 {
		t.Errorf("expected check 1 & 2 to be restarted 1 time but found %d\ncheck 1:\n%s\ncheck 2:%s",
			n, restarter1, restarter2)
	}

	if n := len(restarter3.restarts); n != 0 {
		t.Errorf("expected check 3 to not be restarted but found %d:\n%s", n, restarter3)
	}
}

// TestCheckWatcher_Deadlock asserts that check watcher will not deadlock when
// attempting to restart a task even if its update queue is full.
// https://github.com/hashicorp/nomad/issues/5395
func TestCheckWatcher_Deadlock(t *testing.T) {
	ci.Parallel(t)

	getter, cw := testWatcherSetup(t)

	// If TR.Restart blocks, restarting cap(checkUpdateCh)+1 checks causes
	// a deadlock due to checkWatcher.Run being blocked in
	// checkRestart.apply and unable to process updates from the chan!
	n := cap(cw.checkUpdateCh) + 1
	checks := make([]*structs.ServiceCheck, n)
	restarters := make([]*fakeWorkloadRestarter, n)
	for i := 0; i < n; i++ {
		c := testCheck()
		r := newFakeWorkloadRestarter(cw,
			fmt.Sprintf("alloc%d", i),
			fmt.Sprintf("task%d", i),
			fmt.Sprintf("check%d", i),
			c,
		)
		checks[i] = c
		restarters[i] = r
	}

	// Run
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	go cw.Run(ctx)

	// Watch
	for _, r := range restarters {
		cw.Watch(r.allocID, r.taskName, r.checkName, r.check, r)
	}

	// Make them all fail
	for _, r := range restarters {
		getter.add(r.checkName, "critical", time.Time{})
	}

	// Ensure that restart was called exactly once on all checks
	testutil.WaitForResult(func() (bool, error) {
		for _, r := range restarters {
			if n := len(r.GetRestarts()); n != 1 {
				return false, fmt.Errorf("expected 1 restart but found %d", n)
			}
		}
		return true, nil
	}, func(err error) {
		must.NoError(t, err)
	})
}