github.com/hashicorp/vault/sdk@v0.13.0/database/dbplugin/databasemiddleware.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package dbplugin
     5  
import (
	"context"
	"errors"
	"fmt"
	"net/url"
	"sort"
	"strings"
	"sync"
	"time"

	metrics "github.com/armon/go-metrics"
	"github.com/hashicorp/errwrap"
	log "github.com/hashicorp/go-hclog"
	"google.golang.org/grpc/status"
)
    19  
    20  // ---- Tracing Middleware Domain ----
    21  
    22  // databaseTracingMiddleware wraps a implementation of Database and executes
    23  // trace logging on function call.
    24  type databaseTracingMiddleware struct {
    25  	next   Database
    26  	logger log.Logger
    27  }
    28  
    29  func (mw *databaseTracingMiddleware) Type() (string, error) {
    30  	return mw.next.Type()
    31  }
    32  
    33  func (mw *databaseTracingMiddleware) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
    34  	defer func(then time.Time) {
    35  		mw.logger.Trace("create user", "status", "finished", "err", err, "took", time.Since(then))
    36  	}(time.Now())
    37  
    38  	mw.logger.Trace("create user", "status", "started")
    39  	return mw.next.CreateUser(ctx, statements, usernameConfig, expiration)
    40  }
    41  
    42  func (mw *databaseTracingMiddleware) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) (err error) {
    43  	defer func(then time.Time) {
    44  		mw.logger.Trace("renew user", "status", "finished", "err", err, "took", time.Since(then))
    45  	}(time.Now())
    46  
    47  	mw.logger.Trace("renew user", "status", "started")
    48  	return mw.next.RenewUser(ctx, statements, username, expiration)
    49  }
    50  
    51  func (mw *databaseTracingMiddleware) RevokeUser(ctx context.Context, statements Statements, username string) (err error) {
    52  	defer func(then time.Time) {
    53  		mw.logger.Trace("revoke user", "status", "finished", "err", err, "took", time.Since(then))
    54  	}(time.Now())
    55  
    56  	mw.logger.Trace("revoke user", "status", "started")
    57  	return mw.next.RevokeUser(ctx, statements, username)
    58  }
    59  
    60  func (mw *databaseTracingMiddleware) RotateRootCredentials(ctx context.Context, statements []string) (conf map[string]interface{}, err error) {
    61  	defer func(then time.Time) {
    62  		mw.logger.Trace("rotate root credentials", "status", "finished", "err", err, "took", time.Since(then))
    63  	}(time.Now())
    64  
    65  	mw.logger.Trace("rotate root credentials", "status", "started")
    66  	return mw.next.RotateRootCredentials(ctx, statements)
    67  }
    68  
    69  func (mw *databaseTracingMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
    70  	_, err := mw.Init(ctx, conf, verifyConnection)
    71  	return err
    72  }
    73  
    74  func (mw *databaseTracingMiddleware) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (saveConf map[string]interface{}, err error) {
    75  	defer func(then time.Time) {
    76  		mw.logger.Trace("initialize", "status", "finished", "verify", verifyConnection, "err", err, "took", time.Since(then))
    77  	}(time.Now())
    78  
    79  	mw.logger.Trace("initialize", "status", "started")
    80  	return mw.next.Init(ctx, conf, verifyConnection)
    81  }
    82  
    83  func (mw *databaseTracingMiddleware) Close() (err error) {
    84  	defer func(then time.Time) {
    85  		mw.logger.Trace("close", "status", "finished", "err", err, "took", time.Since(then))
    86  	}(time.Now())
    87  
    88  	mw.logger.Trace("close", "status", "started")
    89  	return mw.next.Close()
    90  }
    91  
    92  func (mw *databaseTracingMiddleware) GenerateCredentials(ctx context.Context) (password string, err error) {
    93  	defer func(then time.Time) {
    94  		mw.logger.Trace("generate credentials", "status", "finished", "err", err, "took", time.Since(then))
    95  	}(time.Now())
    96  
    97  	mw.logger.Trace("generate credentials", "status", "started")
    98  	return mw.next.GenerateCredentials(ctx)
    99  }
   100  
   101  func (mw *databaseTracingMiddleware) SetCredentials(ctx context.Context, statements Statements, staticConfig StaticUserConfig) (username, password string, err error) {
   102  	defer func(then time.Time) {
   103  		mw.logger.Trace("set credentials", "status", "finished", "err", err, "took", time.Since(then))
   104  	}(time.Now())
   105  
   106  	mw.logger.Trace("set credentials", "status", "started")
   107  	return mw.next.SetCredentials(ctx, statements, staticConfig)
   108  }
   109  
   110  // ---- Metrics Middleware Domain ----
   111  
   112  // databaseMetricsMiddleware wraps an implementation of Databases and on
   113  // function call logs metrics about this instance.
   114  type databaseMetricsMiddleware struct {
   115  	next Database
   116  
   117  	typeStr string
   118  }
   119  
   120  func (mw *databaseMetricsMiddleware) Type() (string, error) {
   121  	return mw.next.Type()
   122  }
   123  
   124  func (mw *databaseMetricsMiddleware) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
   125  	defer func(now time.Time) {
   126  		metrics.MeasureSince([]string{"database", "CreateUser"}, now)
   127  		metrics.MeasureSince([]string{"database", mw.typeStr, "CreateUser"}, now)
   128  
   129  		if err != nil {
   130  			metrics.IncrCounter([]string{"database", "CreateUser", "error"}, 1)
   131  			metrics.IncrCounter([]string{"database", mw.typeStr, "CreateUser", "error"}, 1)
   132  		}
   133  	}(time.Now())
   134  
   135  	metrics.IncrCounter([]string{"database", "CreateUser"}, 1)
   136  	metrics.IncrCounter([]string{"database", mw.typeStr, "CreateUser"}, 1)
   137  	return mw.next.CreateUser(ctx, statements, usernameConfig, expiration)
   138  }
   139  
   140  func (mw *databaseMetricsMiddleware) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) (err error) {
   141  	defer func(now time.Time) {
   142  		metrics.MeasureSince([]string{"database", "RenewUser"}, now)
   143  		metrics.MeasureSince([]string{"database", mw.typeStr, "RenewUser"}, now)
   144  
   145  		if err != nil {
   146  			metrics.IncrCounter([]string{"database", "RenewUser", "error"}, 1)
   147  			metrics.IncrCounter([]string{"database", mw.typeStr, "RenewUser", "error"}, 1)
   148  		}
   149  	}(time.Now())
   150  
   151  	metrics.IncrCounter([]string{"database", "RenewUser"}, 1)
   152  	metrics.IncrCounter([]string{"database", mw.typeStr, "RenewUser"}, 1)
   153  	return mw.next.RenewUser(ctx, statements, username, expiration)
   154  }
   155  
   156  func (mw *databaseMetricsMiddleware) RevokeUser(ctx context.Context, statements Statements, username string) (err error) {
   157  	defer func(now time.Time) {
   158  		metrics.MeasureSince([]string{"database", "RevokeUser"}, now)
   159  		metrics.MeasureSince([]string{"database", mw.typeStr, "RevokeUser"}, now)
   160  
   161  		if err != nil {
   162  			metrics.IncrCounter([]string{"database", "RevokeUser", "error"}, 1)
   163  			metrics.IncrCounter([]string{"database", mw.typeStr, "RevokeUser", "error"}, 1)
   164  		}
   165  	}(time.Now())
   166  
   167  	metrics.IncrCounter([]string{"database", "RevokeUser"}, 1)
   168  	metrics.IncrCounter([]string{"database", mw.typeStr, "RevokeUser"}, 1)
   169  	return mw.next.RevokeUser(ctx, statements, username)
   170  }
   171  
   172  func (mw *databaseMetricsMiddleware) RotateRootCredentials(ctx context.Context, statements []string) (conf map[string]interface{}, err error) {
   173  	defer func(now time.Time) {
   174  		metrics.MeasureSince([]string{"database", "RotateRootCredentials"}, now)
   175  		metrics.MeasureSince([]string{"database", mw.typeStr, "RotateRootCredentials"}, now)
   176  
   177  		if err != nil {
   178  			metrics.IncrCounter([]string{"database", "RotateRootCredentials", "error"}, 1)
   179  			metrics.IncrCounter([]string{"database", mw.typeStr, "RotateRootCredentials", "error"}, 1)
   180  		}
   181  	}(time.Now())
   182  
   183  	metrics.IncrCounter([]string{"database", "RotateRootCredentials"}, 1)
   184  	metrics.IncrCounter([]string{"database", mw.typeStr, "RotateRootCredentials"}, 1)
   185  	return mw.next.RotateRootCredentials(ctx, statements)
   186  }
   187  
   188  func (mw *databaseMetricsMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
   189  	_, err := mw.Init(ctx, conf, verifyConnection)
   190  	return err
   191  }
   192  
   193  func (mw *databaseMetricsMiddleware) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (saveConf map[string]interface{}, err error) {
   194  	defer func(now time.Time) {
   195  		metrics.MeasureSince([]string{"database", "Initialize"}, now)
   196  		metrics.MeasureSince([]string{"database", mw.typeStr, "Initialize"}, now)
   197  
   198  		if err != nil {
   199  			metrics.IncrCounter([]string{"database", "Initialize", "error"}, 1)
   200  			metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize", "error"}, 1)
   201  		}
   202  	}(time.Now())
   203  
   204  	metrics.IncrCounter([]string{"database", "Initialize"}, 1)
   205  	metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize"}, 1)
   206  	return mw.next.Init(ctx, conf, verifyConnection)
   207  }
   208  
   209  func (mw *databaseMetricsMiddleware) Close() (err error) {
   210  	defer func(now time.Time) {
   211  		metrics.MeasureSince([]string{"database", "Close"}, now)
   212  		metrics.MeasureSince([]string{"database", mw.typeStr, "Close"}, now)
   213  
   214  		if err != nil {
   215  			metrics.IncrCounter([]string{"database", "Close", "error"}, 1)
   216  			metrics.IncrCounter([]string{"database", mw.typeStr, "Close", "error"}, 1)
   217  		}
   218  	}(time.Now())
   219  
   220  	metrics.IncrCounter([]string{"database", "Close"}, 1)
   221  	metrics.IncrCounter([]string{"database", mw.typeStr, "Close"}, 1)
   222  	return mw.next.Close()
   223  }
   224  
   225  func (mw *databaseMetricsMiddleware) GenerateCredentials(ctx context.Context) (password string, err error) {
   226  	defer func(now time.Time) {
   227  		metrics.MeasureSince([]string{"database", "GenerateCredentials"}, now)
   228  		metrics.MeasureSince([]string{"database", mw.typeStr, "GenerateCredentials"}, now)
   229  
   230  		if err != nil {
   231  			metrics.IncrCounter([]string{"database", "GenerateCredentials", "error"}, 1)
   232  			metrics.IncrCounter([]string{"database", mw.typeStr, "GenerateCredentials", "error"}, 1)
   233  		}
   234  	}(time.Now())
   235  
   236  	metrics.IncrCounter([]string{"database", "GenerateCredentials"}, 1)
   237  	metrics.IncrCounter([]string{"database", mw.typeStr, "GenerateCredentials"}, 1)
   238  	return mw.next.GenerateCredentials(ctx)
   239  }
   240  
   241  func (mw *databaseMetricsMiddleware) SetCredentials(ctx context.Context, statements Statements, staticConfig StaticUserConfig) (username, password string, err error) {
   242  	defer func(now time.Time) {
   243  		metrics.MeasureSince([]string{"database", "SetCredentials"}, now)
   244  		metrics.MeasureSince([]string{"database", mw.typeStr, "SetCredentials"}, now)
   245  
   246  		if err != nil {
   247  			metrics.IncrCounter([]string{"database", "SetCredentials", "error"}, 1)
   248  			metrics.IncrCounter([]string{"database", mw.typeStr, "SetCredentials", "error"}, 1)
   249  		}
   250  	}(time.Now())
   251  
   252  	metrics.IncrCounter([]string{"database", "SetCredentials"}, 1)
   253  	metrics.IncrCounter([]string{"database", mw.typeStr, "SetCredentials"}, 1)
   254  	return mw.next.SetCredentials(ctx, statements, staticConfig)
   255  }
   256  
   257  // ---- Error Sanitizer Middleware Domain ----
   258  
   259  // DatabaseErrorSanitizerMiddleware wraps an implementation of Databases and
   260  // sanitizes returned error messages
   261  type DatabaseErrorSanitizerMiddleware struct {
   262  	l         sync.RWMutex
   263  	next      Database
   264  	secretsFn func() map[string]interface{}
   265  }
   266  
   267  func NewDatabaseErrorSanitizerMiddleware(next Database, secretsFn func() map[string]interface{}) *DatabaseErrorSanitizerMiddleware {
   268  	return &DatabaseErrorSanitizerMiddleware{
   269  		next:      next,
   270  		secretsFn: secretsFn,
   271  	}
   272  }
   273  
   274  func (mw *DatabaseErrorSanitizerMiddleware) Type() (string, error) {
   275  	dbType, err := mw.next.Type()
   276  	return dbType, mw.sanitize(err)
   277  }
   278  
   279  func (mw *DatabaseErrorSanitizerMiddleware) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
   280  	username, password, err = mw.next.CreateUser(ctx, statements, usernameConfig, expiration)
   281  	return username, password, mw.sanitize(err)
   282  }
   283  
   284  func (mw *DatabaseErrorSanitizerMiddleware) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) (err error) {
   285  	return mw.sanitize(mw.next.RenewUser(ctx, statements, username, expiration))
   286  }
   287  
   288  func (mw *DatabaseErrorSanitizerMiddleware) RevokeUser(ctx context.Context, statements Statements, username string) (err error) {
   289  	return mw.sanitize(mw.next.RevokeUser(ctx, statements, username))
   290  }
   291  
   292  func (mw *DatabaseErrorSanitizerMiddleware) RotateRootCredentials(ctx context.Context, statements []string) (conf map[string]interface{}, err error) {
   293  	conf, err = mw.next.RotateRootCredentials(ctx, statements)
   294  	return conf, mw.sanitize(err)
   295  }
   296  
   297  func (mw *DatabaseErrorSanitizerMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
   298  	_, err := mw.Init(ctx, conf, verifyConnection)
   299  	return err
   300  }
   301  
   302  func (mw *DatabaseErrorSanitizerMiddleware) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (saveConf map[string]interface{}, err error) {
   303  	saveConf, err = mw.next.Init(ctx, conf, verifyConnection)
   304  	return saveConf, mw.sanitize(err)
   305  }
   306  
   307  func (mw *DatabaseErrorSanitizerMiddleware) Close() (err error) {
   308  	return mw.sanitize(mw.next.Close())
   309  }
   310  
   311  // sanitize
   312  func (mw *DatabaseErrorSanitizerMiddleware) sanitize(err error) error {
   313  	if err == nil {
   314  		return nil
   315  	}
   316  	if errwrap.ContainsType(err, new(url.Error)) {
   317  		return errors.New("unable to parse connection url")
   318  	}
   319  	if mw.secretsFn != nil {
   320  		for k, v := range mw.secretsFn() {
   321  			if k == "" {
   322  				continue
   323  			}
   324  
   325  			// Attempt to keep the status code attached to the
   326  			// error without changing the actual error message
   327  			s, ok := status.FromError(err)
   328  			if ok {
   329  				err = status.Error(s.Code(), strings.ReplaceAll(s.Message(), k, v.(string)))
   330  				continue
   331  			}
   332  
   333  			err = errors.New(strings.ReplaceAll(err.Error(), k, v.(string)))
   334  		}
   335  	}
   336  	return err
   337  }
   338  
   339  func (mw *DatabaseErrorSanitizerMiddleware) GenerateCredentials(ctx context.Context) (password string, err error) {
   340  	password, err = mw.next.GenerateCredentials(ctx)
   341  	return password, mw.sanitize(err)
   342  }
   343  
   344  func (mw *DatabaseErrorSanitizerMiddleware) SetCredentials(ctx context.Context, statements Statements, staticConfig StaticUserConfig) (username, password string, err error) {
   345  	username, password, err = mw.next.SetCredentials(ctx, statements, staticConfig)
   346  	return username, password, mw.sanitize(err)
   347  }