github.com/hashicorp/vault/sdk@v0.13.0/database/dbplugin/v5/grpc_server.go (about) 1 // Copyright (c) HashiCorp, Inc. 2 // SPDX-License-Identifier: MPL-2.0 3 4 package dbplugin 5 6 import ( 7 "context" 8 "errors" 9 "fmt" 10 "sync" 11 "time" 12 13 "github.com/golang/protobuf/ptypes" 14 "github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto" 15 "github.com/hashicorp/vault/sdk/helper/base62" 16 "github.com/hashicorp/vault/sdk/helper/pluginutil" 17 "github.com/hashicorp/vault/sdk/logical" 18 "google.golang.org/grpc/codes" 19 "google.golang.org/grpc/status" 20 ) 21 22 var _ proto.DatabaseServer = &gRPCServer{} 23 24 type gRPCServer struct { 25 proto.UnimplementedDatabaseServer 26 logical.UnimplementedPluginVersionServer 27 28 // holds the non-multiplexed Database 29 // when this is set the plugin does not support multiplexing 30 singleImpl Database 31 32 // instances holds the multiplexed Databases 33 instances map[string]Database 34 factoryFunc func() (interface{}, error) 35 36 sync.RWMutex 37 } 38 39 func (g *gRPCServer) getOrCreateDatabase(ctx context.Context) (Database, error) { 40 g.Lock() 41 defer g.Unlock() 42 43 if g.singleImpl != nil { 44 return g.singleImpl, nil 45 } 46 47 id, err := pluginutil.GetMultiplexIDFromContext(ctx) 48 if err != nil { 49 return nil, err 50 } 51 if db, ok := g.instances[id]; ok { 52 return db, nil 53 } 54 return g.createDatabase(id) 55 } 56 57 // must hold the g.Lock() to call this function 58 func (g *gRPCServer) createDatabase(id string) (Database, error) { 59 db, err := g.factoryFunc() 60 if err != nil { 61 return nil, err 62 } 63 64 database := db.(Database) 65 g.instances[id] = database 66 67 return database, nil 68 } 69 70 // getDatabaseInternal returns the database but does not hold a lock 71 func (g *gRPCServer) getDatabaseInternal(ctx context.Context) (Database, error) { 72 if g.singleImpl != nil { 73 return g.singleImpl, nil 74 } 75 76 id, err := pluginutil.GetMultiplexIDFromContext(ctx) 77 if err != nil { 78 return nil, err 79 } 80 81 if db, ok := g.instances[id]; ok { 82 return db, nil 83 } 84 85 return nil, fmt.Errorf("no database instance found") 86 } 87 88 // getDatabase holds a read lock and returns the database 89 func (g *gRPCServer) getDatabase(ctx context.Context) (Database, error) { 90 g.RLock() 91 impl, err := g.getDatabaseInternal(ctx) 92 g.RUnlock() 93 return impl, err 94 } 95 96 // Initialize the database plugin 97 func (g *gRPCServer) Initialize(ctx context.Context, request *proto.InitializeRequest) (*proto.InitializeResponse, error) { 98 impl, err := g.getOrCreateDatabase(ctx) 99 if err != nil { 100 return nil, err 101 } 102 103 rawConfig := structToMap(request.ConfigData) 104 105 dbReq := InitializeRequest{ 106 Config: rawConfig, 107 VerifyConnection: request.VerifyConnection, 108 } 109 110 dbResp, err := impl.Initialize(ctx, dbReq) 111 if err != nil { 112 return &proto.InitializeResponse{}, status.Errorf(codes.Internal, "failed to initialize: %s", err) 113 } 114 115 newConfig, err := mapToStruct(dbResp.Config) 116 if err != nil { 117 return &proto.InitializeResponse{}, status.Errorf(codes.Internal, "failed to marshal new config to JSON: %s", err) 118 } 119 120 resp := &proto.InitializeResponse{ 121 ConfigData: newConfig, 122 } 123 124 return resp, nil 125 } 126 127 func (g *gRPCServer) NewUser(ctx context.Context, req *proto.NewUserRequest) (*proto.NewUserResponse, error) { 128 if req.GetUsernameConfig() == nil { 129 return &proto.NewUserResponse{}, status.Errorf(codes.InvalidArgument, "missing username config") 130 } 131 132 var expiration time.Time 133 134 if req.GetExpiration() != nil { 135 exp, err := ptypes.Timestamp(req.GetExpiration()) 136 if err != nil { 137 return &proto.NewUserResponse{}, status.Errorf(codes.InvalidArgument, "unable to parse expiration date: %s", err) 138 } 139 expiration = exp 140 } 141 142 impl, err := g.getDatabase(ctx) 143 if err != nil { 144 return nil, err 145 } 146 147 dbReq := NewUserRequest{ 148 UsernameConfig: UsernameMetadata{ 149 DisplayName: req.GetUsernameConfig().GetDisplayName(), 150 RoleName: req.GetUsernameConfig().GetRoleName(), 151 }, 152 CredentialType: CredentialType(req.GetCredentialType()), 153 Password: req.GetPassword(), 154 PublicKey: req.GetPublicKey(), 155 Subject: req.GetSubject(), 156 Expiration: expiration, 157 Statements: getStatementsFromProto(req.GetStatements()), 158 RollbackStatements: getStatementsFromProto(req.GetRollbackStatements()), 159 } 160 161 dbResp, err := impl.NewUser(ctx, dbReq) 162 if err != nil { 163 return &proto.NewUserResponse{}, status.Errorf(codes.Internal, "unable to create new user: %s", err) 164 } 165 166 resp := &proto.NewUserResponse{ 167 Username: dbResp.Username, 168 } 169 return resp, nil 170 } 171 172 func (g *gRPCServer) UpdateUser(ctx context.Context, req *proto.UpdateUserRequest) (*proto.UpdateUserResponse, error) { 173 if req.GetUsername() == "" { 174 return &proto.UpdateUserResponse{}, status.Errorf(codes.InvalidArgument, "no username provided") 175 } 176 177 dbReq, err := getUpdateUserRequest(req) 178 if err != nil { 179 return &proto.UpdateUserResponse{}, status.Errorf(codes.InvalidArgument, err.Error()) 180 } 181 182 impl, err := g.getDatabase(ctx) 183 if err != nil { 184 return nil, err 185 } 186 187 _, err = impl.UpdateUser(ctx, dbReq) 188 if err != nil { 189 return &proto.UpdateUserResponse{}, status.Errorf(codes.Internal, "unable to update user: %s", err) 190 } 191 return &proto.UpdateUserResponse{}, nil 192 } 193 194 func getUpdateUserRequest(req *proto.UpdateUserRequest) (UpdateUserRequest, error) { 195 var password *ChangePassword 196 if req.GetPassword() != nil && req.GetPassword().GetNewPassword() != "" { 197 password = &ChangePassword{ 198 NewPassword: req.GetPassword().GetNewPassword(), 199 Statements: getStatementsFromProto(req.GetPassword().GetStatements()), 200 } 201 } 202 203 var publicKey *ChangePublicKey 204 if req.GetPublicKey() != nil && len(req.GetPublicKey().GetNewPublicKey()) > 0 { 205 publicKey = &ChangePublicKey{ 206 NewPublicKey: req.GetPublicKey().GetNewPublicKey(), 207 Statements: getStatementsFromProto(req.GetPublicKey().GetStatements()), 208 } 209 } 210 211 var expiration *ChangeExpiration 212 if req.GetExpiration() != nil && req.GetExpiration().GetNewExpiration() != nil { 213 newExpiration, err := ptypes.Timestamp(req.GetExpiration().GetNewExpiration()) 214 if err != nil { 215 return UpdateUserRequest{}, fmt.Errorf("unable to parse new expiration: %w", err) 216 } 217 218 expiration = &ChangeExpiration{ 219 NewExpiration: newExpiration, 220 Statements: getStatementsFromProto(req.GetExpiration().GetStatements()), 221 } 222 } 223 224 dbReq := UpdateUserRequest{ 225 Username: req.GetUsername(), 226 CredentialType: CredentialType(req.GetCredentialType()), 227 Password: password, 228 PublicKey: publicKey, 229 Expiration: expiration, 230 } 231 232 if !hasChange(dbReq) { 233 return UpdateUserRequest{}, fmt.Errorf("update user request has no changes") 234 } 235 236 return dbReq, nil 237 } 238 239 func hasChange(dbReq UpdateUserRequest) bool { 240 if dbReq.Password != nil && dbReq.Password.NewPassword != "" { 241 return true 242 } 243 if dbReq.PublicKey != nil && len(dbReq.PublicKey.NewPublicKey) > 0 { 244 return true 245 } 246 if dbReq.Expiration != nil && !dbReq.Expiration.NewExpiration.IsZero() { 247 return true 248 } 249 return false 250 } 251 252 func (g *gRPCServer) DeleteUser(ctx context.Context, req *proto.DeleteUserRequest) (*proto.DeleteUserResponse, error) { 253 if req.GetUsername() == "" { 254 return &proto.DeleteUserResponse{}, status.Errorf(codes.InvalidArgument, "no username provided") 255 } 256 dbReq := DeleteUserRequest{ 257 Username: req.GetUsername(), 258 Statements: getStatementsFromProto(req.GetStatements()), 259 } 260 261 impl, err := g.getDatabase(ctx) 262 if err != nil { 263 return nil, err 264 } 265 266 _, err = impl.DeleteUser(ctx, dbReq) 267 if err != nil { 268 return &proto.DeleteUserResponse{}, status.Errorf(codes.Internal, "unable to delete user: %s", err) 269 } 270 return &proto.DeleteUserResponse{}, nil 271 } 272 273 func (g *gRPCServer) Type(ctx context.Context, _ *proto.Empty) (*proto.TypeResponse, error) { 274 impl, err := g.getOrCreateDatabase(ctx) 275 if err != nil { 276 return nil, err 277 } 278 279 t, err := impl.Type() 280 if err != nil { 281 return &proto.TypeResponse{}, status.Errorf(codes.Internal, "unable to retrieve type: %s", err) 282 } 283 284 resp := &proto.TypeResponse{ 285 Type: t, 286 } 287 return resp, nil 288 } 289 290 func (g *gRPCServer) Close(ctx context.Context, _ *proto.Empty) (*proto.Empty, error) { 291 g.Lock() 292 defer g.Unlock() 293 294 impl, err := g.getDatabaseInternal(ctx) 295 if err != nil { 296 return nil, err 297 } 298 299 err = impl.Close() 300 if err != nil { 301 return &proto.Empty{}, status.Errorf(codes.Internal, "unable to close database plugin: %s", err) 302 } 303 304 if g.singleImpl == nil { 305 // only cleanup instances map when multiplexing is supported 306 id, err := pluginutil.GetMultiplexIDFromContext(ctx) 307 if err != nil { 308 return nil, err 309 } 310 delete(g.instances, id) 311 } 312 313 return &proto.Empty{}, nil 314 } 315 316 // getOrForceCreateDatabase will create a database even if the multiplexing ID is not present 317 func (g *gRPCServer) getOrForceCreateDatabase(ctx context.Context) (Database, error) { 318 impl, err := g.getOrCreateDatabase(ctx) 319 if errors.Is(err, pluginutil.ErrNoMultiplexingIDFound) { 320 // if this is called without a multiplexing context, like from the plugin catalog directly, 321 // then we won't have a database ID, so let's generate a new database instance 322 id, err := base62.Random(10) 323 if err != nil { 324 return nil, err 325 } 326 327 g.Lock() 328 defer g.Unlock() 329 impl, err = g.createDatabase(id) 330 if err != nil { 331 return nil, err 332 } 333 } else if err != nil { 334 return nil, err 335 } 336 return impl, nil 337 } 338 339 // Version forwards the version request to the underlying Database implementation. 340 func (g *gRPCServer) Version(ctx context.Context, _ *logical.Empty) (*logical.VersionReply, error) { 341 impl, err := g.getOrForceCreateDatabase(ctx) 342 if err != nil { 343 return nil, err 344 } 345 346 if versioner, ok := impl.(logical.PluginVersioner); ok { 347 return &logical.VersionReply{PluginVersion: versioner.PluginVersion().Version}, nil 348 } 349 return &logical.VersionReply{}, nil 350 } 351 352 func getStatementsFromProto(protoStmts *proto.Statements) (statements Statements) { 353 if protoStmts == nil { 354 return statements 355 } 356 cmds := protoStmts.GetCommands() 357 statements = Statements{ 358 Commands: cmds, 359 } 360 return statements 361 }