github.com/hashicorp/vault/sdk@v0.13.0/database/dbplugin/v5/middleware.go (about) 1 // Copyright (c) HashiCorp, Inc. 2 // SPDX-License-Identifier: MPL-2.0 3 4 package dbplugin 5 6 import ( 7 "context" 8 "errors" 9 "net/url" 10 "strings" 11 "time" 12 13 "github.com/armon/go-metrics" 14 "github.com/hashicorp/errwrap" 15 log "github.com/hashicorp/go-hclog" 16 "github.com/hashicorp/vault/sdk/logical" 17 "google.golang.org/grpc/status" 18 ) 19 20 // /////////////////////////////////////////////////// 21 // Tracing Middleware 22 // /////////////////////////////////////////////////// 23 24 var ( 25 _ Database = databaseTracingMiddleware{} 26 _ logical.PluginVersioner = databaseTracingMiddleware{} 27 ) 28 29 // databaseTracingMiddleware wraps a implementation of Database and executes 30 // trace logging on function call. 31 type databaseTracingMiddleware struct { 32 next Database 33 logger log.Logger 34 } 35 36 func (mw databaseTracingMiddleware) PluginVersion() (resp logical.PluginVersion) { 37 defer func(then time.Time) { 38 mw.logger.Trace("version", 39 "status", "finished", 40 "version", resp, 41 "took", time.Since(then)) 42 }(time.Now()) 43 44 mw.logger.Trace("version", "status", "started") 45 if versioner, ok := mw.next.(logical.PluginVersioner); ok { 46 return versioner.PluginVersion() 47 } 48 return logical.EmptyPluginVersion 49 } 50 51 func (mw databaseTracingMiddleware) Initialize(ctx context.Context, req InitializeRequest) (resp InitializeResponse, err error) { 52 defer func(then time.Time) { 53 mw.logger.Trace("initialize", 54 "status", "finished", 55 "verify", req.VerifyConnection, 56 "err", err, 57 "took", time.Since(then)) 58 }(time.Now()) 59 60 mw.logger.Trace("initialize", "status", "started") 61 return mw.next.Initialize(ctx, req) 62 } 63 64 func (mw databaseTracingMiddleware) NewUser(ctx context.Context, req NewUserRequest) (resp NewUserResponse, err error) { 65 defer func(then time.Time) { 66 mw.logger.Trace("create user", 67 "status", "finished", 68 "err", err, 69 "took", time.Since(then)) 70 }(time.Now()) 71 72 mw.logger.Trace("create user", 73 "status", "started") 74 return mw.next.NewUser(ctx, req) 75 } 76 77 func (mw databaseTracingMiddleware) UpdateUser(ctx context.Context, req UpdateUserRequest) (resp UpdateUserResponse, err error) { 78 defer func(then time.Time) { 79 mw.logger.Trace("update user", 80 "status", "finished", 81 "err", err, 82 "took", time.Since(then)) 83 }(time.Now()) 84 85 mw.logger.Trace("update user", "status", "started") 86 return mw.next.UpdateUser(ctx, req) 87 } 88 89 func (mw databaseTracingMiddleware) DeleteUser(ctx context.Context, req DeleteUserRequest) (resp DeleteUserResponse, err error) { 90 defer func(then time.Time) { 91 mw.logger.Trace("delete user", 92 "status", "finished", 93 "err", err, 94 "took", time.Since(then)) 95 }(time.Now()) 96 97 mw.logger.Trace("delete user", 98 "status", "started") 99 return mw.next.DeleteUser(ctx, req) 100 } 101 102 func (mw databaseTracingMiddleware) Type() (string, error) { 103 return mw.next.Type() 104 } 105 106 func (mw databaseTracingMiddleware) Close() (err error) { 107 defer func(then time.Time) { 108 mw.logger.Trace("close", 109 "status", "finished", 110 "err", err, 111 "took", time.Since(then)) 112 }(time.Now()) 113 114 mw.logger.Trace("close", 115 "status", "started") 116 return mw.next.Close() 117 } 118 119 // /////////////////////////////////////////////////// 120 // Metrics Middleware Domain 121 // /////////////////////////////////////////////////// 122 123 var ( 124 _ Database = databaseMetricsMiddleware{} 125 _ logical.PluginVersioner = databaseMetricsMiddleware{} 126 ) 127 128 // databaseMetricsMiddleware wraps an implementation of Databases and on 129 // function call logs metrics about this instance. 130 type databaseMetricsMiddleware struct { 131 next Database 132 133 typeStr string 134 } 135 136 func (mw databaseMetricsMiddleware) PluginVersion() logical.PluginVersion { 137 defer func(now time.Time) { 138 metrics.MeasureSince([]string{"database", "PluginVersion"}, now) 139 metrics.MeasureSince([]string{"database", mw.typeStr, "PluginVersion"}, now) 140 }(time.Now()) 141 142 metrics.IncrCounter([]string{"database", "PluginVersion"}, 1) 143 metrics.IncrCounter([]string{"database", mw.typeStr, "PluginVersion"}, 1) 144 145 if versioner, ok := mw.next.(logical.PluginVersioner); ok { 146 return versioner.PluginVersion() 147 } 148 return logical.EmptyPluginVersion 149 } 150 151 func (mw databaseMetricsMiddleware) Initialize(ctx context.Context, req InitializeRequest) (resp InitializeResponse, err error) { 152 defer func(now time.Time) { 153 metrics.MeasureSince([]string{"database", "Initialize"}, now) 154 metrics.MeasureSince([]string{"database", mw.typeStr, "Initialize"}, now) 155 156 if err != nil { 157 metrics.IncrCounter([]string{"database", "Initialize", "error"}, 1) 158 metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize", "error"}, 1) 159 } 160 }(time.Now()) 161 162 metrics.IncrCounter([]string{"database", "Initialize"}, 1) 163 metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize"}, 1) 164 return mw.next.Initialize(ctx, req) 165 } 166 167 func (mw databaseMetricsMiddleware) NewUser(ctx context.Context, req NewUserRequest) (resp NewUserResponse, err error) { 168 defer func(start time.Time) { 169 metrics.MeasureSince([]string{"database", "NewUser"}, start) 170 metrics.MeasureSince([]string{"database", mw.typeStr, "NewUser"}, start) 171 172 if err != nil { 173 metrics.IncrCounter([]string{"database", "NewUser", "error"}, 1) 174 metrics.IncrCounter([]string{"database", mw.typeStr, "NewUser", "error"}, 1) 175 } 176 }(time.Now()) 177 178 metrics.IncrCounter([]string{"database", "NewUser"}, 1) 179 metrics.IncrCounter([]string{"database", mw.typeStr, "NewUser"}, 1) 180 return mw.next.NewUser(ctx, req) 181 } 182 183 func (mw databaseMetricsMiddleware) UpdateUser(ctx context.Context, req UpdateUserRequest) (resp UpdateUserResponse, err error) { 184 defer func(now time.Time) { 185 metrics.MeasureSince([]string{"database", "UpdateUser"}, now) 186 metrics.MeasureSince([]string{"database", mw.typeStr, "UpdateUser"}, now) 187 188 if err != nil { 189 metrics.IncrCounter([]string{"database", "UpdateUser", "error"}, 1) 190 metrics.IncrCounter([]string{"database", mw.typeStr, "UpdateUser", "error"}, 1) 191 } 192 }(time.Now()) 193 194 metrics.IncrCounter([]string{"database", "UpdateUser"}, 1) 195 metrics.IncrCounter([]string{"database", mw.typeStr, "UpdateUser"}, 1) 196 return mw.next.UpdateUser(ctx, req) 197 } 198 199 func (mw databaseMetricsMiddleware) DeleteUser(ctx context.Context, req DeleteUserRequest) (resp DeleteUserResponse, err error) { 200 defer func(now time.Time) { 201 metrics.MeasureSince([]string{"database", "DeleteUser"}, now) 202 metrics.MeasureSince([]string{"database", mw.typeStr, "DeleteUser"}, now) 203 204 if err != nil { 205 metrics.IncrCounter([]string{"database", "DeleteUser", "error"}, 1) 206 metrics.IncrCounter([]string{"database", mw.typeStr, "DeleteUser", "error"}, 1) 207 } 208 }(time.Now()) 209 210 metrics.IncrCounter([]string{"database", "DeleteUser"}, 1) 211 metrics.IncrCounter([]string{"database", mw.typeStr, "DeleteUser"}, 1) 212 return mw.next.DeleteUser(ctx, req) 213 } 214 215 func (mw databaseMetricsMiddleware) Type() (string, error) { 216 return mw.next.Type() 217 } 218 219 func (mw databaseMetricsMiddleware) Close() (err error) { 220 defer func(now time.Time) { 221 metrics.MeasureSince([]string{"database", "Close"}, now) 222 metrics.MeasureSince([]string{"database", mw.typeStr, "Close"}, now) 223 224 if err != nil { 225 metrics.IncrCounter([]string{"database", "Close", "error"}, 1) 226 metrics.IncrCounter([]string{"database", mw.typeStr, "Close", "error"}, 1) 227 } 228 }(time.Now()) 229 230 metrics.IncrCounter([]string{"database", "Close"}, 1) 231 metrics.IncrCounter([]string{"database", mw.typeStr, "Close"}, 1) 232 return mw.next.Close() 233 } 234 235 // /////////////////////////////////////////////////// 236 // Error Sanitizer Middleware Domain 237 // /////////////////////////////////////////////////// 238 239 var ( 240 _ Database = (*DatabaseErrorSanitizerMiddleware)(nil) 241 _ logical.PluginVersioner = (*DatabaseErrorSanitizerMiddleware)(nil) 242 ) 243 244 // DatabaseErrorSanitizerMiddleware wraps an implementation of Databases and 245 // sanitizes returned error messages 246 type DatabaseErrorSanitizerMiddleware struct { 247 next Database 248 secretsFn secretsFn 249 } 250 251 type secretsFn func() map[string]string 252 253 func NewDatabaseErrorSanitizerMiddleware(next Database, secrets secretsFn) DatabaseErrorSanitizerMiddleware { 254 return DatabaseErrorSanitizerMiddleware{ 255 next: next, 256 secretsFn: secrets, 257 } 258 } 259 260 func (mw DatabaseErrorSanitizerMiddleware) Initialize(ctx context.Context, req InitializeRequest) (resp InitializeResponse, err error) { 261 resp, err = mw.next.Initialize(ctx, req) 262 return resp, mw.sanitize(err) 263 } 264 265 func (mw DatabaseErrorSanitizerMiddleware) NewUser(ctx context.Context, req NewUserRequest) (resp NewUserResponse, err error) { 266 resp, err = mw.next.NewUser(ctx, req) 267 return resp, mw.sanitize(err) 268 } 269 270 func (mw DatabaseErrorSanitizerMiddleware) UpdateUser(ctx context.Context, req UpdateUserRequest) (UpdateUserResponse, error) { 271 resp, err := mw.next.UpdateUser(ctx, req) 272 return resp, mw.sanitize(err) 273 } 274 275 func (mw DatabaseErrorSanitizerMiddleware) DeleteUser(ctx context.Context, req DeleteUserRequest) (DeleteUserResponse, error) { 276 resp, err := mw.next.DeleteUser(ctx, req) 277 return resp, mw.sanitize(err) 278 } 279 280 func (mw DatabaseErrorSanitizerMiddleware) Type() (string, error) { 281 dbType, err := mw.next.Type() 282 return dbType, mw.sanitize(err) 283 } 284 285 func (mw DatabaseErrorSanitizerMiddleware) Close() (err error) { 286 return mw.sanitize(mw.next.Close()) 287 } 288 289 func (mw DatabaseErrorSanitizerMiddleware) PluginVersion() logical.PluginVersion { 290 if versioner, ok := mw.next.(logical.PluginVersioner); ok { 291 return versioner.PluginVersion() 292 } 293 return logical.EmptyPluginVersion 294 } 295 296 // sanitize errors by removing any sensitive strings within their messages. This uses 297 // the secretsFn to determine what fields should be sanitized. 298 func (mw DatabaseErrorSanitizerMiddleware) sanitize(err error) error { 299 if err == nil { 300 return nil 301 } 302 if errwrap.ContainsType(err, new(url.Error)) { 303 return errors.New("unable to parse connection url") 304 } 305 if mw.secretsFn == nil { 306 return err 307 } 308 for find, replace := range mw.secretsFn() { 309 if find == "" { 310 continue 311 } 312 313 // Attempt to keep the status code attached to the 314 // error while changing the actual error message 315 s, ok := status.FromError(err) 316 if ok { 317 err = status.Error(s.Code(), strings.ReplaceAll(s.Message(), find, replace)) 318 continue 319 } 320 321 err = errors.New(strings.ReplaceAll(err.Error(), find, replace)) 322 } 323 return err 324 }