github.com/juju/juju@v0.0.0-20240430160146-1752b71fcf00/worker/dbaccessor/tracker_test.go (about)

     1  // Copyright 2023 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package dbaccessor
     5  
     6  import (
     7  	"context"
     8  	"database/sql"
     9  	"fmt"
    10  	"sync/atomic"
    11  	"time"
    12  
    13  	"github.com/juju/collections/set"
    14  	"github.com/juju/errors"
    15  	jc "github.com/juju/testing/checkers"
    16  	"github.com/juju/worker/v3/workertest"
    17  	"go.uber.org/mock/gomock"
    18  	gc "gopkg.in/check.v1"
    19  
    20  	coredatabase "github.com/juju/juju/core/database"
    21  	"github.com/juju/juju/testing"
    22  )
    23  
    24  type trackedDBWorkerSuite struct {
    25  	dbBaseSuite
    26  }
    27  
    28  var _ = gc.Suite(&trackedDBWorkerSuite{})
    29  
    30  func (s *trackedDBWorkerSuite) TestWorkerStartup(c *gc.C) {
    31  	defer s.setupMocks(c).Finish()
    32  
    33  	s.expectAnyLogs()
    34  	s.expectClock()
    35  	s.expectTimer(0)
    36  
    37  	s.dbApp.EXPECT().Open(gomock.Any(), "controller").Return(s.DB(), nil)
    38  
    39  	w, err := NewTrackedDBWorker(context.Background(), s.dbApp, "controller", WithClock(s.clock), WithLogger(s.logger))
    40  	c.Assert(err, jc.ErrorIsNil)
    41  
    42  	defer workertest.DirtyKill(c, w)
    43  
    44  	workertest.CleanKill(c, w)
    45  }
    46  
    47  func (s *trackedDBWorkerSuite) TestWorkerReport(c *gc.C) {
    48  	defer s.setupMocks(c).Finish()
    49  
    50  	s.expectAnyLogs()
    51  	s.expectClock()
    52  	s.expectTimer(0)
    53  
    54  	s.dbApp.EXPECT().Open(gomock.Any(), "controller").Return(s.DB(), nil)
    55  
    56  	w, err := NewTrackedDBWorker(context.Background(), s.dbApp, "controller", WithClock(s.clock), WithLogger(s.logger))
    57  	c.Assert(err, jc.ErrorIsNil)
    58  
    59  	defer workertest.DirtyKill(c, w)
    60  
    61  	report := w.(interface{ Report() map[string]any }).Report()
    62  	c.Assert(report, MapHasKeys, []string{
    63  		"db-replacements",
    64  		"max-ping-duration",
    65  		"last-ping-attempts",
    66  		"last-ping-duration",
    67  	})
    68  
    69  	workertest.CleanKill(c, w)
    70  }
    71  
    72  func (s *trackedDBWorkerSuite) TestWorkerDBIsNotNil(c *gc.C) {
    73  	defer s.setupMocks(c).Finish()
    74  
    75  	s.expectAnyLogs()
    76  	s.expectClock()
    77  	s.expectTimer(0)
    78  
    79  	s.dbApp.EXPECT().Open(gomock.Any(), "controller").Return(s.DB(), nil)
    80  
    81  	w, err := s.newTrackedDBWorker(defaultPingDBFunc)
    82  	c.Assert(err, jc.ErrorIsNil)
    83  
    84  	defer workertest.DirtyKill(c, w)
    85  
    86  	err = w.TxnNoRetry(context.Background(), func(_ context.Context, tx *sql.Tx) error {
    87  		if tx == nil {
    88  			return errors.New("nil transaction")
    89  		}
    90  		return nil
    91  	})
    92  	c.Assert(err, jc.ErrorIsNil)
    93  
    94  	workertest.CleanKill(c, w)
    95  }
    96  
    97  func (s *trackedDBWorkerSuite) TestWorkerTxnIsNotNil(c *gc.C) {
    98  	defer s.setupMocks(c).Finish()
    99  
   100  	s.expectAnyLogs()
   101  	s.expectClock()
   102  	s.expectTimer(0)
   103  
   104  	s.dbApp.EXPECT().Open(gomock.Any(), "controller").Return(s.DB(), nil)
   105  
   106  	w, err := s.newTrackedDBWorker(defaultPingDBFunc)
   107  	c.Assert(err, jc.ErrorIsNil)
   108  
   109  	defer workertest.DirtyKill(c, w)
   110  
   111  	done := make(chan struct{})
   112  	err = w.Txn(context.TODO(), func(ctx context.Context, tx *sql.Tx) error {
   113  		defer close(done)
   114  
   115  		c.Assert(tx, gc.NotNil)
   116  		return nil
   117  	})
   118  	c.Assert(err, jc.ErrorIsNil)
   119  
   120  	select {
   121  	case <-done:
   122  	case <-time.After(testing.ShortWait):
   123  		c.Fatal("timed out waiting for DB callback")
   124  	}
   125  
   126  	workertest.CleanKill(c, w)
   127  }
   128  
   129  func (s *trackedDBWorkerSuite) TestWorkerAttemptsToVerifyDB(c *gc.C) {
   130  	defer s.setupMocks(c).Finish()
   131  
   132  	s.expectAnyLogs()
   133  	s.expectClock()
   134  	done := s.expectTimer(1)
   135  
   136  	s.timer.EXPECT().Reset(PollInterval).Times(1)
   137  	s.dbApp.EXPECT().Open(gomock.Any(), "controller").Return(s.DB(), nil)
   138  
   139  	var count uint64
   140  	pingFn := func(context.Context, *sql.DB) error {
   141  		atomic.AddUint64(&count, 1)
   142  		return nil
   143  	}
   144  
   145  	w, err := s.newTrackedDBWorker(pingFn)
   146  	c.Assert(err, jc.ErrorIsNil)
   147  
   148  	defer workertest.DirtyKill(c, w)
   149  
   150  	select {
   151  	case <-done:
   152  	case <-time.After(testing.ShortWait):
   153  		c.Fatal("timed out waiting for DB callback")
   154  	}
   155  
   156  	// Attempt to use the new db, note there shouldn't be any leases in this db.
   157  	tables := readTableNames(c, w)
   158  	c.Assert(tables, SliceContains, "lease")
   159  
   160  	workertest.CleanKill(c, w)
   161  
   162  	c.Assert(count, gc.Equals, uint64(1))
   163  	c.Assert(w.Err(), jc.ErrorIsNil)
   164  }
   165  
   166  func (s *trackedDBWorkerSuite) TestWorkerAttemptsToVerifyDBButSucceeds(c *gc.C) {
   167  	defer s.setupMocks(c).Finish()
   168  
   169  	s.expectAnyLogs()
   170  	s.expectClock()
   171  	done := s.expectTimer(1)
   172  
   173  	s.timer.EXPECT().Reset(PollInterval).Times(1)
   174  
   175  	dbChange := make(chan struct{})
   176  	s.dbApp.EXPECT().Open(gomock.Any(), "controller").Return(s.DB(), nil).Times(DefaultVerifyAttempts - 1)
   177  	s.dbApp.EXPECT().Open(gomock.Any(), "controller").Return(s.DB(), nil).DoAndReturn(func(_ context.Context, _ string) (*sql.DB, error) {
   178  		defer close(dbChange)
   179  		return s.DB(), nil
   180  	})
   181  
   182  	var count uint64
   183  	pingFn := func(context.Context, *sql.DB) error {
   184  		val := atomic.AddUint64(&count, 1)
   185  
   186  		if val == DefaultVerifyAttempts {
   187  			return nil
   188  		}
   189  		return errors.New("boom")
   190  	}
   191  
   192  	w, err := s.newTrackedDBWorker(pingFn)
   193  	c.Assert(err, jc.ErrorIsNil)
   194  
   195  	defer workertest.DirtyKill(c, w)
   196  
   197  	select {
   198  	case <-done:
   199  	case <-time.After(testing.ShortWait):
   200  		c.Fatal("timed out waiting for DB callback")
   201  	}
   202  
   203  	// The db should have changed to the new db.
   204  	select {
   205  	case <-dbChange:
   206  	case <-time.After(testing.ShortWait):
   207  		c.Fatal("timed out waiting for DB callback")
   208  	}
   209  
   210  	tables := readTableNames(c, w)
   211  	c.Assert(tables, SliceContains, "lease")
   212  
   213  	workertest.CleanKill(c, w)
   214  
   215  	c.Assert(w.Err(), jc.ErrorIsNil)
   216  }
   217  
   218  func (s *trackedDBWorkerSuite) TestWorkerAttemptsToVerifyDBRepeatedly(c *gc.C) {
   219  	defer s.setupMocks(c).Finish()
   220  
   221  	s.expectAnyLogs()
   222  	s.expectClock()
   223  	done := s.expectTimer(2)
   224  
   225  	s.timer.EXPECT().Reset(PollInterval).Times(2)
   226  
   227  	s.dbApp.EXPECT().Open(gomock.Any(), "controller").Return(s.DB(), nil)
   228  
   229  	var count uint64
   230  	pingFn := func(context.Context, *sql.DB) error {
   231  		atomic.AddUint64(&count, 1)
   232  		return nil
   233  	}
   234  
   235  	w, err := s.newTrackedDBWorker(pingFn)
   236  	c.Assert(err, jc.ErrorIsNil)
   237  
   238  	defer workertest.DirtyKill(c, w)
   239  
   240  	select {
   241  	case <-done:
   242  	case <-time.After(testing.ShortWait):
   243  		c.Fatal("timed out waiting for DB callback")
   244  	}
   245  
   246  	// Attempt to use the new db, note there shouldn't be any leases in this db.
   247  	tables := readTableNames(c, w)
   248  	c.Assert(tables, SliceContains, "lease")
   249  
   250  	workertest.CleanKill(c, w)
   251  
   252  	c.Assert(count, gc.Equals, uint64(2))
   253  	c.Assert(w.Err(), jc.ErrorIsNil)
   254  }
   255  
   256  func (s *trackedDBWorkerSuite) TestWorkerAttemptsToVerifyDBButSucceedsWithDifferentDB(c *gc.C) {
   257  	defer s.setupMocks(c).Finish()
   258  
   259  	s.expectAnyLogs()
   260  	s.expectClock()
   261  	done := s.expectTimer(1)
   262  
   263  	s.timer.EXPECT().Reset(PollInterval).Times(1)
   264  
   265  	dbChange := make(chan struct{})
   266  	exp := s.dbApp.EXPECT()
   267  	gomock.InOrder(
   268  		exp.Open(gomock.Any(), "controller").Return(s.DB(), nil),
   269  		exp.Open(gomock.Any(), "controller").Return(s.DB(), nil),
   270  		exp.Open(gomock.Any(), "controller").DoAndReturn(func(_ context.Context, _ string) (*sql.DB, error) {
   271  			defer close(dbChange)
   272  			return s.NewCleanDB(c), nil
   273  		}),
   274  	)
   275  
   276  	var count uint64
   277  	pingFn := func(context.Context, *sql.DB) error {
   278  		val := atomic.AddUint64(&count, 1)
   279  
   280  		if val == DefaultVerifyAttempts {
   281  			return nil
   282  		}
   283  		return errors.New("boom")
   284  	}
   285  
   286  	w, err := s.newTrackedDBWorker(pingFn)
   287  	c.Assert(err, jc.ErrorIsNil)
   288  
   289  	defer workertest.DirtyKill(c, w)
   290  
   291  	select {
   292  	case <-done:
   293  	case <-time.After(testing.ShortWait):
   294  		c.Fatal("timed out waiting for DB callback")
   295  	}
   296  
   297  	// Wait for the clean database to have been returned.
   298  	select {
   299  	case <-dbChange:
   300  	case <-time.After(testing.ShortWait):
   301  		c.Fatal("timed out waiting for DB callback")
   302  	}
   303  
   304  	// There is a race potential race with the composition here, because
   305  	// although the ping func may return a new database, it is not instantly
   306  	// set as the worker's DB reference. We need to give it a chance.
   307  	// In-theatre this will be OK, because a DB in an error state recoverable
   308  	// by reconnecting will be replaced within the default retry strategy's
   309  	// backoff/repeat loop.
   310  	timeout := time.After(testing.ShortWait)
   311  	tables := readTableNames(c, w)
   312  loop:
   313  	for {
   314  		select {
   315  		case <-timeout:
   316  			c.Fatal("did not reach expected clean DB state")
   317  		default:
   318  			if set.NewStrings(tables...).Contains("lease") {
   319  				tables = readTableNames(c, w)
   320  			} else {
   321  				break loop
   322  			}
   323  		}
   324  	}
   325  
   326  	workertest.CleanKill(c, w)
   327  	c.Assert(w.Err(), jc.ErrorIsNil)
   328  }
   329  
   330  func (s *trackedDBWorkerSuite) TestWorkerAttemptsToVerifyDBButFails(c *gc.C) {
   331  	defer s.setupMocks(c).Finish()
   332  
   333  	s.expectAnyLogs()
   334  	s.expectClock()
   335  	done := s.expectTimer(1)
   336  
   337  	s.dbApp.EXPECT().Open(gomock.Any(), "controller").Return(s.DB(), nil).Times(DefaultVerifyAttempts)
   338  
   339  	pingFn := func(context.Context, *sql.DB) error {
   340  		return errors.New("boom")
   341  	}
   342  
   343  	w, err := s.newTrackedDBWorker(pingFn)
   344  	c.Assert(err, jc.ErrorIsNil)
   345  
   346  	defer workertest.DirtyKill(c, w)
   347  
   348  	select {
   349  	case <-done:
   350  	case <-time.After(testing.ShortWait):
   351  		c.Fatal("timed out waiting for DB callback")
   352  	}
   353  
   354  	c.Assert(w.Wait(), gc.ErrorMatches, "boom")
   355  	c.Assert(w.Err(), gc.ErrorMatches, "boom")
   356  
   357  	// Ensure that the DB is dead.
   358  	err = w.Txn(context.TODO(), func(ctx context.Context, tx *sql.Tx) error {
   359  		c.Fatal("failed if called")
   360  		return nil
   361  	})
   362  	c.Assert(err, gc.ErrorMatches, "boom")
   363  }
   364  
   365  func (s *trackedDBWorkerSuite) newTrackedDBWorker(pingFn func(context.Context, *sql.DB) error) (TrackedDB, error) {
   366  	collector := NewMetricsCollector()
   367  	return NewTrackedDBWorker(context.Background(), s.dbApp, "controller",
   368  		WithClock(s.clock),
   369  		WithLogger(s.logger),
   370  		WithPingDBFunc(pingFn),
   371  		WithMetricsCollector(collector),
   372  	)
   373  }
   374  
   375  func readTableNames(c *gc.C, w coredatabase.TrackedDB) []string {
   376  	// Attempt to use the new db, note there shouldn't be any leases in this
   377  	// db.
   378  	var tables []string
   379  	err := w.Txn(context.TODO(), func(ctx context.Context, tx *sql.Tx) error {
   380  		rows, err := tx.Query("SELECT tbl_name FROM sqlite_schema")
   381  		c.Assert(err, jc.ErrorIsNil)
   382  		defer rows.Close()
   383  
   384  		for rows.Next() {
   385  			var table string
   386  			err = rows.Scan(&table)
   387  			c.Assert(err, jc.ErrorIsNil)
   388  			tables = append(tables, table)
   389  		}
   390  
   391  		return nil
   392  	})
   393  	c.Assert(err, jc.ErrorIsNil)
   394  	return set.NewStrings(tables...).SortedValues()
   395  }
   396  
   397  type sliceContainsChecker[T comparable] struct {
   398  	*gc.CheckerInfo
   399  }
   400  
   401  var SliceContains gc.Checker = &sliceContainsChecker[string]{
   402  	&gc.CheckerInfo{Name: "SliceContains", Params: []string{"obtained", "expected"}},
   403  }
   404  
   405  func (checker *sliceContainsChecker[T]) Check(params []interface{}, names []string) (result bool, error string) {
   406  	expected, ok := params[1].(T)
   407  	if !ok {
   408  		var t T
   409  		return false, fmt.Sprintf("expected must be %T", t)
   410  	}
   411  
   412  	obtained, ok := params[0].([]T)
   413  	if !ok {
   414  		var t T
   415  		return false, fmt.Sprintf("Obtained value is not a []%T", t)
   416  	}
   417  
   418  	for _, o := range obtained {
   419  		if o == expected {
   420  			return true, ""
   421  		}
   422  	}
   423  	return false, ""
   424  }
   425  
   426  type hasKeysChecker[T comparable] struct {
   427  	*gc.CheckerInfo
   428  }
   429  
   430  var MapHasKeys gc.Checker = &hasKeysChecker[string]{
   431  	&gc.CheckerInfo{Name: "hasKeysChecker", Params: []string{"obtained", "expected"}},
   432  }
   433  
   434  func (checker *hasKeysChecker[T]) Check(params []interface{}, names []string) (result bool, error string) {
   435  	expected, ok := params[1].([]T)
   436  	if !ok {
   437  		var t T
   438  		return false, fmt.Sprintf("expected must be %T", t)
   439  	}
   440  
   441  	obtained, ok := params[0].(map[T]any)
   442  	if !ok {
   443  		var t T
   444  		return false, fmt.Sprintf("Obtained value is not a map[%T]any", t)
   445  	}
   446  
   447  	for _, k := range expected {
   448  		if _, ok := obtained[k]; !ok {
   449  			return false, fmt.Sprintf("expected key %v not found", k)
   450  		}
   451  	}
   452  	return true, ""
   453  }