github.com/hashicorp/vault/sdk@v0.13.0/database/dbplugin/v5/middleware.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package dbplugin
     5  
import (
	"context"
	"errors"
	"net/url"
	"sort"
	"strings"
	"time"

	"github.com/armon/go-metrics"
	"github.com/hashicorp/errwrap"
	log "github.com/hashicorp/go-hclog"
	"github.com/hashicorp/vault/sdk/logical"
	"google.golang.org/grpc/status"
)
    19  
    20  // ///////////////////////////////////////////////////
    21  // Tracing Middleware
    22  // ///////////////////////////////////////////////////
    23  
    24  var (
    25  	_ Database                = databaseTracingMiddleware{}
    26  	_ logical.PluginVersioner = databaseTracingMiddleware{}
    27  )
    28  
    29  // databaseTracingMiddleware wraps a implementation of Database and executes
    30  // trace logging on function call.
    31  type databaseTracingMiddleware struct {
    32  	next   Database
    33  	logger log.Logger
    34  }
    35  
    36  func (mw databaseTracingMiddleware) PluginVersion() (resp logical.PluginVersion) {
    37  	defer func(then time.Time) {
    38  		mw.logger.Trace("version",
    39  			"status", "finished",
    40  			"version", resp,
    41  			"took", time.Since(then))
    42  	}(time.Now())
    43  
    44  	mw.logger.Trace("version", "status", "started")
    45  	if versioner, ok := mw.next.(logical.PluginVersioner); ok {
    46  		return versioner.PluginVersion()
    47  	}
    48  	return logical.EmptyPluginVersion
    49  }
    50  
    51  func (mw databaseTracingMiddleware) Initialize(ctx context.Context, req InitializeRequest) (resp InitializeResponse, err error) {
    52  	defer func(then time.Time) {
    53  		mw.logger.Trace("initialize",
    54  			"status", "finished",
    55  			"verify", req.VerifyConnection,
    56  			"err", err,
    57  			"took", time.Since(then))
    58  	}(time.Now())
    59  
    60  	mw.logger.Trace("initialize", "status", "started")
    61  	return mw.next.Initialize(ctx, req)
    62  }
    63  
    64  func (mw databaseTracingMiddleware) NewUser(ctx context.Context, req NewUserRequest) (resp NewUserResponse, err error) {
    65  	defer func(then time.Time) {
    66  		mw.logger.Trace("create user",
    67  			"status", "finished",
    68  			"err", err,
    69  			"took", time.Since(then))
    70  	}(time.Now())
    71  
    72  	mw.logger.Trace("create user",
    73  		"status", "started")
    74  	return mw.next.NewUser(ctx, req)
    75  }
    76  
    77  func (mw databaseTracingMiddleware) UpdateUser(ctx context.Context, req UpdateUserRequest) (resp UpdateUserResponse, err error) {
    78  	defer func(then time.Time) {
    79  		mw.logger.Trace("update user",
    80  			"status", "finished",
    81  			"err", err,
    82  			"took", time.Since(then))
    83  	}(time.Now())
    84  
    85  	mw.logger.Trace("update user", "status", "started")
    86  	return mw.next.UpdateUser(ctx, req)
    87  }
    88  
    89  func (mw databaseTracingMiddleware) DeleteUser(ctx context.Context, req DeleteUserRequest) (resp DeleteUserResponse, err error) {
    90  	defer func(then time.Time) {
    91  		mw.logger.Trace("delete user",
    92  			"status", "finished",
    93  			"err", err,
    94  			"took", time.Since(then))
    95  	}(time.Now())
    96  
    97  	mw.logger.Trace("delete user",
    98  		"status", "started")
    99  	return mw.next.DeleteUser(ctx, req)
   100  }
   101  
   102  func (mw databaseTracingMiddleware) Type() (string, error) {
   103  	return mw.next.Type()
   104  }
   105  
   106  func (mw databaseTracingMiddleware) Close() (err error) {
   107  	defer func(then time.Time) {
   108  		mw.logger.Trace("close",
   109  			"status", "finished",
   110  			"err", err,
   111  			"took", time.Since(then))
   112  	}(time.Now())
   113  
   114  	mw.logger.Trace("close",
   115  		"status", "started")
   116  	return mw.next.Close()
   117  }
   118  
   119  // ///////////////////////////////////////////////////
   120  // Metrics Middleware Domain
   121  // ///////////////////////////////////////////////////
   122  
   123  var (
   124  	_ Database                = databaseMetricsMiddleware{}
   125  	_ logical.PluginVersioner = databaseMetricsMiddleware{}
   126  )
   127  
   128  // databaseMetricsMiddleware wraps an implementation of Databases and on
   129  // function call logs metrics about this instance.
   130  type databaseMetricsMiddleware struct {
   131  	next Database
   132  
   133  	typeStr string
   134  }
   135  
   136  func (mw databaseMetricsMiddleware) PluginVersion() logical.PluginVersion {
   137  	defer func(now time.Time) {
   138  		metrics.MeasureSince([]string{"database", "PluginVersion"}, now)
   139  		metrics.MeasureSince([]string{"database", mw.typeStr, "PluginVersion"}, now)
   140  	}(time.Now())
   141  
   142  	metrics.IncrCounter([]string{"database", "PluginVersion"}, 1)
   143  	metrics.IncrCounter([]string{"database", mw.typeStr, "PluginVersion"}, 1)
   144  
   145  	if versioner, ok := mw.next.(logical.PluginVersioner); ok {
   146  		return versioner.PluginVersion()
   147  	}
   148  	return logical.EmptyPluginVersion
   149  }
   150  
   151  func (mw databaseMetricsMiddleware) Initialize(ctx context.Context, req InitializeRequest) (resp InitializeResponse, err error) {
   152  	defer func(now time.Time) {
   153  		metrics.MeasureSince([]string{"database", "Initialize"}, now)
   154  		metrics.MeasureSince([]string{"database", mw.typeStr, "Initialize"}, now)
   155  
   156  		if err != nil {
   157  			metrics.IncrCounter([]string{"database", "Initialize", "error"}, 1)
   158  			metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize", "error"}, 1)
   159  		}
   160  	}(time.Now())
   161  
   162  	metrics.IncrCounter([]string{"database", "Initialize"}, 1)
   163  	metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize"}, 1)
   164  	return mw.next.Initialize(ctx, req)
   165  }
   166  
   167  func (mw databaseMetricsMiddleware) NewUser(ctx context.Context, req NewUserRequest) (resp NewUserResponse, err error) {
   168  	defer func(start time.Time) {
   169  		metrics.MeasureSince([]string{"database", "NewUser"}, start)
   170  		metrics.MeasureSince([]string{"database", mw.typeStr, "NewUser"}, start)
   171  
   172  		if err != nil {
   173  			metrics.IncrCounter([]string{"database", "NewUser", "error"}, 1)
   174  			metrics.IncrCounter([]string{"database", mw.typeStr, "NewUser", "error"}, 1)
   175  		}
   176  	}(time.Now())
   177  
   178  	metrics.IncrCounter([]string{"database", "NewUser"}, 1)
   179  	metrics.IncrCounter([]string{"database", mw.typeStr, "NewUser"}, 1)
   180  	return mw.next.NewUser(ctx, req)
   181  }
   182  
   183  func (mw databaseMetricsMiddleware) UpdateUser(ctx context.Context, req UpdateUserRequest) (resp UpdateUserResponse, err error) {
   184  	defer func(now time.Time) {
   185  		metrics.MeasureSince([]string{"database", "UpdateUser"}, now)
   186  		metrics.MeasureSince([]string{"database", mw.typeStr, "UpdateUser"}, now)
   187  
   188  		if err != nil {
   189  			metrics.IncrCounter([]string{"database", "UpdateUser", "error"}, 1)
   190  			metrics.IncrCounter([]string{"database", mw.typeStr, "UpdateUser", "error"}, 1)
   191  		}
   192  	}(time.Now())
   193  
   194  	metrics.IncrCounter([]string{"database", "UpdateUser"}, 1)
   195  	metrics.IncrCounter([]string{"database", mw.typeStr, "UpdateUser"}, 1)
   196  	return mw.next.UpdateUser(ctx, req)
   197  }
   198  
   199  func (mw databaseMetricsMiddleware) DeleteUser(ctx context.Context, req DeleteUserRequest) (resp DeleteUserResponse, err error) {
   200  	defer func(now time.Time) {
   201  		metrics.MeasureSince([]string{"database", "DeleteUser"}, now)
   202  		metrics.MeasureSince([]string{"database", mw.typeStr, "DeleteUser"}, now)
   203  
   204  		if err != nil {
   205  			metrics.IncrCounter([]string{"database", "DeleteUser", "error"}, 1)
   206  			metrics.IncrCounter([]string{"database", mw.typeStr, "DeleteUser", "error"}, 1)
   207  		}
   208  	}(time.Now())
   209  
   210  	metrics.IncrCounter([]string{"database", "DeleteUser"}, 1)
   211  	metrics.IncrCounter([]string{"database", mw.typeStr, "DeleteUser"}, 1)
   212  	return mw.next.DeleteUser(ctx, req)
   213  }
   214  
   215  func (mw databaseMetricsMiddleware) Type() (string, error) {
   216  	return mw.next.Type()
   217  }
   218  
   219  func (mw databaseMetricsMiddleware) Close() (err error) {
   220  	defer func(now time.Time) {
   221  		metrics.MeasureSince([]string{"database", "Close"}, now)
   222  		metrics.MeasureSince([]string{"database", mw.typeStr, "Close"}, now)
   223  
   224  		if err != nil {
   225  			metrics.IncrCounter([]string{"database", "Close", "error"}, 1)
   226  			metrics.IncrCounter([]string{"database", mw.typeStr, "Close", "error"}, 1)
   227  		}
   228  	}(time.Now())
   229  
   230  	metrics.IncrCounter([]string{"database", "Close"}, 1)
   231  	metrics.IncrCounter([]string{"database", mw.typeStr, "Close"}, 1)
   232  	return mw.next.Close()
   233  }
   234  
   235  // ///////////////////////////////////////////////////
   236  // Error Sanitizer Middleware Domain
   237  // ///////////////////////////////////////////////////
   238  
   239  var (
   240  	_ Database                = (*DatabaseErrorSanitizerMiddleware)(nil)
   241  	_ logical.PluginVersioner = (*DatabaseErrorSanitizerMiddleware)(nil)
   242  )
   243  
   244  // DatabaseErrorSanitizerMiddleware wraps an implementation of Databases and
   245  // sanitizes returned error messages
   246  type DatabaseErrorSanitizerMiddleware struct {
   247  	next      Database
   248  	secretsFn secretsFn
   249  }
   250  
   251  type secretsFn func() map[string]string
   252  
   253  func NewDatabaseErrorSanitizerMiddleware(next Database, secrets secretsFn) DatabaseErrorSanitizerMiddleware {
   254  	return DatabaseErrorSanitizerMiddleware{
   255  		next:      next,
   256  		secretsFn: secrets,
   257  	}
   258  }
   259  
   260  func (mw DatabaseErrorSanitizerMiddleware) Initialize(ctx context.Context, req InitializeRequest) (resp InitializeResponse, err error) {
   261  	resp, err = mw.next.Initialize(ctx, req)
   262  	return resp, mw.sanitize(err)
   263  }
   264  
   265  func (mw DatabaseErrorSanitizerMiddleware) NewUser(ctx context.Context, req NewUserRequest) (resp NewUserResponse, err error) {
   266  	resp, err = mw.next.NewUser(ctx, req)
   267  	return resp, mw.sanitize(err)
   268  }
   269  
   270  func (mw DatabaseErrorSanitizerMiddleware) UpdateUser(ctx context.Context, req UpdateUserRequest) (UpdateUserResponse, error) {
   271  	resp, err := mw.next.UpdateUser(ctx, req)
   272  	return resp, mw.sanitize(err)
   273  }
   274  
   275  func (mw DatabaseErrorSanitizerMiddleware) DeleteUser(ctx context.Context, req DeleteUserRequest) (DeleteUserResponse, error) {
   276  	resp, err := mw.next.DeleteUser(ctx, req)
   277  	return resp, mw.sanitize(err)
   278  }
   279  
   280  func (mw DatabaseErrorSanitizerMiddleware) Type() (string, error) {
   281  	dbType, err := mw.next.Type()
   282  	return dbType, mw.sanitize(err)
   283  }
   284  
   285  func (mw DatabaseErrorSanitizerMiddleware) Close() (err error) {
   286  	return mw.sanitize(mw.next.Close())
   287  }
   288  
   289  func (mw DatabaseErrorSanitizerMiddleware) PluginVersion() logical.PluginVersion {
   290  	if versioner, ok := mw.next.(logical.PluginVersioner); ok {
   291  		return versioner.PluginVersion()
   292  	}
   293  	return logical.EmptyPluginVersion
   294  }
   295  
   296  // sanitize errors by removing any sensitive strings within their messages. This uses
   297  // the secretsFn to determine what fields should be sanitized.
   298  func (mw DatabaseErrorSanitizerMiddleware) sanitize(err error) error {
   299  	if err == nil {
   300  		return nil
   301  	}
   302  	if errwrap.ContainsType(err, new(url.Error)) {
   303  		return errors.New("unable to parse connection url")
   304  	}
   305  	if mw.secretsFn == nil {
   306  		return err
   307  	}
   308  	for find, replace := range mw.secretsFn() {
   309  		if find == "" {
   310  			continue
   311  		}
   312  
   313  		// Attempt to keep the status code attached to the
   314  		// error while changing the actual error message
   315  		s, ok := status.FromError(err)
   316  		if ok {
   317  			err = status.Error(s.Code(), strings.ReplaceAll(s.Message(), find, replace))
   318  			continue
   319  		}
   320  
   321  		err = errors.New(strings.ReplaceAll(err.Error(), find, replace))
   322  	}
   323  	return err
   324  }