github.com/juju/juju@v0.0.0-20240327075706-a90865de2538/worker/dbaccessor/tracker.go (about)

     1  // Copyright 2023 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package dbaccessor
     5  
     6  import (
     7  	"context"
     8  	"database/sql"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/juju/clock"
    13  	"github.com/juju/errors"
    14  	"github.com/juju/worker/v3"
    15  	"gopkg.in/tomb.v2"
    16  
    17  	coredatabase "github.com/juju/juju/core/database"
    18  	"github.com/juju/juju/database"
    19  	"github.com/juju/juju/database/pragma"
    20  )
    21  
    22  const (
    23  	// PollInterval is the amount of time to wait between polling the database.
    24  	PollInterval = time.Second * 10
    25  
    26  	// DefaultVerifyAttempts is the number of attempts to verify the database,
    27  	// by opening a new database on verification failure.
    28  	DefaultVerifyAttempts = 3
    29  )
    30  
// TrackedDB combines the coredatabase.TrackedDB interface with
// worker.Worker, so a single value can both serve transactions and be
// managed (killed and waited on) as a worker.
// It is declared locally in this package to make the underlying
// trackedDBWorker easier to test.
type TrackedDB interface {
	coredatabase.TrackedDB
	worker.Worker
}
    38  
// TrackedDBWorkerOption is a functional option that configures a
// trackedDBWorker. Options are applied by NewTrackedDBWorker, in the order
// given, before the database is opened.
type TrackedDBWorkerOption func(*trackedDBWorker)
    41  
    42  // WithPingDBFunc sets the function used to verify the database connection.
    43  func WithPingDBFunc(f func(context.Context, *sql.DB) error) TrackedDBWorkerOption {
    44  	return func(w *trackedDBWorker) {
    45  		w.pingDBFunc = f
    46  	}
    47  }
    48  
    49  // WithClock sets the clock used by the worker.
    50  func WithClock(clock clock.Clock) TrackedDBWorkerOption {
    51  	return func(w *trackedDBWorker) {
    52  		w.clock = clock
    53  	}
    54  }
    55  
    56  // WithLogger sets the logger used by the worker.
    57  func WithLogger(logger Logger) TrackedDBWorkerOption {
    58  	return func(w *trackedDBWorker) {
    59  		w.logger = logger
    60  	}
    61  }
    62  
    63  // WithMetricsCollector sets the metrics collector used by the worker.
    64  func WithMetricsCollector(metrics *Collector) TrackedDBWorkerOption {
    65  	return func(w *trackedDBWorker) {
    66  		w.metrics = metrics
    67  	}
    68  }
    69  
// trackedDBWorker is the worker behind TrackedDB. It holds a database handle
// for a single namespace, periodically verifies that the handle is still
// usable, and swaps in a freshly opened handle when verification fails.
type trackedDBWorker struct {
	// tomb tracks the lifetime of the loop goroutine.
	tomb tomb.Tomb

	// dbApp opens databases; namespace is the name passed to it when
	// opening, and is also used as the label on recorded metrics.
	dbApp     DBApp
	namespace string

	// mutex guards db and err. db is the handle currently handed out to
	// transactions; err is set by the loop when the health check fails.
	mutex sync.RWMutex
	db    *sql.DB
	err   error

	// clock defaults to clock.WallClock in NewTrackedDBWorker. logger and
	// metrics have no default there and are only set through options.
	clock   clock.Clock
	logger  Logger
	metrics *Collector

	// pingDBFunc is the health probe; it defaults to defaultPingDBFunc.
	pingDBFunc func(context.Context, *sql.DB) error

	// report accumulates ping statistics for the engine report.
	report *report
}
    88  
    89  // NewTrackedDBWorker creates a new TrackedDBWorker
    90  func NewTrackedDBWorker(ctx context.Context, dbApp DBApp, namespace string, opts ...TrackedDBWorkerOption) (TrackedDB, error) {
    91  	w := &trackedDBWorker{
    92  		dbApp:      dbApp,
    93  		namespace:  namespace,
    94  		clock:      clock.WallClock,
    95  		pingDBFunc: defaultPingDBFunc,
    96  		report:     &report{},
    97  	}
    98  
    99  	for _, opt := range opts {
   100  		opt(w)
   101  	}
   102  
   103  	var err error
   104  	w.db, err = w.dbApp.Open(context.TODO(), w.namespace)
   105  	if err != nil {
   106  		return nil, errors.Trace(err)
   107  	}
   108  
   109  	if err := pragma.SetPragma(ctx, w.db, pragma.ForeignKeysPragma, true); err != nil {
   110  		return nil, errors.Annotate(err, "setting foreign keys pragma")
   111  	}
   112  
   113  	w.tomb.Go(w.loop)
   114  
   115  	return w, nil
   116  }
   117  
   118  // Txn executes the input function against the tracked database,
   119  // within a transaction that depends on the input context.
   120  // Retry semantics are applied automatically based on transient failures.
   121  // This is the function that almost all downstream database consumers
   122  // should use.
   123  func (w *trackedDBWorker) Txn(ctx context.Context, fn func(context.Context, *sql.Tx) error) error {
   124  	return database.Retry(ctx, func() error {
   125  		return errors.Trace(w.TxnNoRetry(ctx, fn))
   126  	})
   127  }
   128  
   129  // TxnNoRetry executes the input function against the tracked database,
   130  // within a transaction that depends on the input context.
   131  // We meter both the total transaction count and active operations.
   132  func (w *trackedDBWorker) TxnNoRetry(ctx context.Context, fn func(context.Context, *sql.Tx) error) (err error) {
   133  	begin := w.clock.Now()
   134  	w.metrics.TxnRequests.WithLabelValues(w.namespace).Inc()
   135  	w.metrics.DBRequests.WithLabelValues(w.namespace).Inc()
   136  	defer w.meterDBOpResult(begin, err)
   137  
   138  	// If the DB health check failed, the worker's error will be set,
   139  	// and we will be without a usable database reference. Return the error.
   140  	w.mutex.RLock()
   141  	if w.err != nil {
   142  		w.mutex.RUnlock()
   143  		return errors.Trace(w.err)
   144  	}
   145  
   146  	db := w.db
   147  	w.mutex.RUnlock()
   148  
   149  	return errors.Trace(database.Txn(ctx, db, fn))
   150  }
   151  
   152  // meterDBOpResults decrements the active DB operation count,
   153  // and records the result and duration of the completed operation.
   154  func (w *trackedDBWorker) meterDBOpResult(begin time.Time, err error) {
   155  	w.metrics.DBRequests.WithLabelValues(w.namespace).Dec()
   156  	result := "success"
   157  	if err != nil {
   158  		result = "error"
   159  	}
   160  	w.metrics.DBDuration.WithLabelValues(w.namespace, result).Observe(w.clock.Now().Sub(begin).Seconds())
   161  }
   162  
   163  // Err will return any fatal errors that have occurred on the worker, trying
   164  // to acquire the database.
   165  func (w *trackedDBWorker) Err() error {
   166  	w.mutex.RLock()
   167  	defer w.mutex.RUnlock()
   168  
   169  	return w.err
   170  }
   171  
// Kill implements worker.Worker. It kills the worker's tomb with a nil
// reason, which makes the loop goroutine return.
func (w *trackedDBWorker) Kill() {
	w.tomb.Kill(nil)
}
   176  
// Wait implements worker.Worker. It returns the result of waiting on the
// worker's tomb.
func (w *trackedDBWorker) Wait() error {
	return w.tomb.Wait()
}
   181  
// Report provides information for the engine report: the ping statistics
// and database replacement count accumulated by the worker.
func (w *trackedDBWorker) Report() map[string]any {
	return w.report.Report()
}
   186  
   187  func (w *trackedDBWorker) loop() error {
   188  	timer := w.clock.NewTimer(PollInterval)
   189  	defer timer.Stop()
   190  
   191  	for {
   192  		select {
   193  		case <-w.tomb.Dying():
   194  			return tomb.ErrDying
   195  		case <-timer.Chan():
   196  			// Any retryable errors are handled at the txn level. If we get an
   197  			// error returning here, we've either exhausted the number of
   198  			// retries or the error was fatal.
   199  			w.mutex.RLock()
   200  			currentDB := w.db
   201  			w.mutex.RUnlock()
   202  
   203  			newDB, err := w.ensureDBAliveAndOpenIfRequired(currentDB)
   204  			if err != nil {
   205  				// If we get an error, ensure we close the underlying db and
   206  				// mark the tracked db in an error state.
   207  				w.mutex.Lock()
   208  				if err := w.db.Close(); err != nil {
   209  					w.logger.Errorf("error closing database: %v", err)
   210  				}
   211  				w.err = errors.Trace(err)
   212  				w.mutex.Unlock()
   213  
   214  				// As we failed attempting to verify the db, we're in a fatal
   215  				// state. Collapse the worker and if required, cause the other
   216  				// workers to follow suite.
   217  				return errors.Trace(err)
   218  			}
   219  
   220  			// We've got a new DB. Close the old one and replace it with the
   221  			// new one, if they're not the same.
   222  			if newDB != currentDB {
   223  				w.mutex.Lock()
   224  				if err := w.db.Close(); err != nil {
   225  					w.logger.Errorf("error closing database: %v", err)
   226  				}
   227  				w.db = newDB
   228  				w.report.Set(func(r *report) {
   229  					r.dbReplacements++
   230  				})
   231  				w.err = nil
   232  				w.mutex.Unlock()
   233  			}
   234  
   235  			timer.Reset(PollInterval)
   236  		}
   237  	}
   238  }
   239  
   240  // ensureDBAliveAndOpenNewIfRequired is a bit long-winded, but it is a way to
   241  // ensure that the underlying database is alive and well. If it is not, we
   242  // attempt to open a new one. If that fails, we return an error.
   243  func (w *trackedDBWorker) ensureDBAliveAndOpenIfRequired(db *sql.DB) (*sql.DB, error) {
   244  	// Allow killing the tomb to cancel the context,
   245  	// so shutdown/restart can not be blocked by this call.
   246  	ctx, cancel := context.WithTimeout(context.TODO(), time.Second*10)
   247  	ctx = w.tomb.Context(ctx)
   248  	defer cancel()
   249  
   250  	if w.logger.IsTraceEnabled() {
   251  		w.logger.Tracef("ensuring database %q is alive", w.namespace)
   252  	}
   253  
   254  	// There are multiple levels of retries here.
   255  	// - We want to retry the ping function for retryable errors.
   256  	//   These might be DB-locked or busy-syncing errors for example.
   257  	// - If the error is fatal, we discard the DB instance and reconnect
   258  	//   before attempting health verification again.
   259  	for i := 0; i < DefaultVerifyAttempts; i++ {
   260  		// Verify that we don't have a potential nil database from the retry
   261  		// semantics.
   262  		if db == nil {
   263  			return nil, errors.NotFoundf("database")
   264  		}
   265  
   266  		// Record the total ping.
   267  		pingStart := w.clock.Now()
   268  		var pingAttempts uint32 = 0
   269  		err := database.Retry(ctx, func() error {
   270  			if w.logger.IsTraceEnabled() {
   271  				w.logger.Tracef("pinging database %q", w.namespace)
   272  			}
   273  			pingAttempts++
   274  			return w.pingDBFunc(ctx, db)
   275  		})
   276  		pingDur := w.clock.Now().Sub(pingStart)
   277  
   278  		// Record the ping attempt and duration.
   279  		w.report.Set(func(r *report) {
   280  			r.pingAttempts = pingAttempts
   281  			r.pingDuration = pingDur
   282  			if pingDur > r.maxPingDuration {
   283  				r.maxPingDuration = pingDur
   284  			}
   285  		})
   286  
   287  		// We were successful at requesting the schema, so we can bail out
   288  		// early.
   289  		if err == nil {
   290  			return db, nil
   291  		}
   292  
   293  		// We exhausted the retry strategy for pinging the database.
   294  		// Terminate the worker with the error.
   295  		if i == DefaultVerifyAttempts-1 {
   296  			return nil, errors.Trace(err)
   297  		}
   298  
   299  		// We got an error that is non-retryable, attempt to open a new database
   300  		// connection and see if that works.
   301  		w.logger.Warningf("unable to ping database %q: attempting to reopen the database before trying again: %v",
   302  			w.namespace, err)
   303  
   304  		// Attempt to open a new database. If there is an error, just crash
   305  		// the worker, we can't do anything else.
   306  		if db, err = w.dbApp.Open(ctx, w.namespace); err != nil {
   307  			return nil, errors.Trace(err)
   308  		}
   309  
   310  		if err := pragma.SetPragma(ctx, db, pragma.ForeignKeysPragma, true); err != nil {
   311  			return nil, errors.Annotate(err, "setting foreign keys pragma")
   312  		}
   313  	}
   314  	return nil, errors.NotValidf("database")
   315  }
   316  
// defaultPingDBFunc is the ping function used when none is supplied through
// WithPingDBFunc. It delegates to db.PingContext with the given context.
func defaultPingDBFunc(ctx context.Context, db *sql.DB) error {
	return db.PingContext(ctx)
}
   320  
   321  // report fields for the engine report.
   322  type report struct {
   323  	sync.Mutex
   324  
   325  	// pingDuration is the duration of the last ping.
   326  	pingDuration time.Duration
   327  	// pingAttempts is the number of attempts to ping the database for the
   328  	// last ping.
   329  	pingAttempts uint32
   330  	// maxPingDuration is the maximum duration of a ping for a given lifetime
   331  	// of the worker.
   332  	maxPingDuration time.Duration
   333  	// dbReplacements is the number of times the database has been replaced
   334  	// due to a failed ping.
   335  	dbReplacements uint32
   336  }
   337  
   338  // Report provides information for the engine report.
   339  func (r *report) Report() map[string]any {
   340  	r.Lock()
   341  	defer r.Unlock()
   342  
   343  	return map[string]any{
   344  		"last-ping-duration": r.pingDuration.String(),
   345  		"last-ping-attempts": r.pingAttempts,
   346  		"max-ping-duration":  r.maxPingDuration.String(),
   347  		"db-replacements":    r.dbReplacements,
   348  	}
   349  }
   350  
   351  // Set allows to set the report fields, guarded by a mutex.
   352  func (r *report) Set(f func(*report)) {
   353  	r.Lock()
   354  	defer r.Unlock()
   355  
   356  	f(r)
   357  }