github.com/hashicorp/vault/sdk@v0.13.0/database/dbplugin/v5/grpc_server.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package dbplugin
     5  
     6  import (
     7  	"context"
     8  	"errors"
     9  	"fmt"
    10  	"sync"
    11  	"time"
    12  
    13  	"github.com/golang/protobuf/ptypes"
    14  	"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
    15  	"github.com/hashicorp/vault/sdk/helper/base62"
    16  	"github.com/hashicorp/vault/sdk/helper/pluginutil"
    17  	"github.com/hashicorp/vault/sdk/logical"
    18  	"google.golang.org/grpc/codes"
    19  	"google.golang.org/grpc/status"
    20  )
    21  
    22  var _ proto.DatabaseServer = &gRPCServer{}
    23  
    24  type gRPCServer struct {
    25  	proto.UnimplementedDatabaseServer
    26  	logical.UnimplementedPluginVersionServer
    27  
    28  	// holds the non-multiplexed Database
    29  	// when this is set the plugin does not support multiplexing
    30  	singleImpl Database
    31  
    32  	// instances holds the multiplexed Databases
    33  	instances   map[string]Database
    34  	factoryFunc func() (interface{}, error)
    35  
    36  	sync.RWMutex
    37  }
    38  
    39  func (g *gRPCServer) getOrCreateDatabase(ctx context.Context) (Database, error) {
    40  	g.Lock()
    41  	defer g.Unlock()
    42  
    43  	if g.singleImpl != nil {
    44  		return g.singleImpl, nil
    45  	}
    46  
    47  	id, err := pluginutil.GetMultiplexIDFromContext(ctx)
    48  	if err != nil {
    49  		return nil, err
    50  	}
    51  	if db, ok := g.instances[id]; ok {
    52  		return db, nil
    53  	}
    54  	return g.createDatabase(id)
    55  }
    56  
    57  // must hold the g.Lock() to call this function
    58  func (g *gRPCServer) createDatabase(id string) (Database, error) {
    59  	db, err := g.factoryFunc()
    60  	if err != nil {
    61  		return nil, err
    62  	}
    63  
    64  	database := db.(Database)
    65  	g.instances[id] = database
    66  
    67  	return database, nil
    68  }
    69  
    70  // getDatabaseInternal returns the database but does not hold a lock
    71  func (g *gRPCServer) getDatabaseInternal(ctx context.Context) (Database, error) {
    72  	if g.singleImpl != nil {
    73  		return g.singleImpl, nil
    74  	}
    75  
    76  	id, err := pluginutil.GetMultiplexIDFromContext(ctx)
    77  	if err != nil {
    78  		return nil, err
    79  	}
    80  
    81  	if db, ok := g.instances[id]; ok {
    82  		return db, nil
    83  	}
    84  
    85  	return nil, fmt.Errorf("no database instance found")
    86  }
    87  
    88  // getDatabase holds a read lock and returns the database
    89  func (g *gRPCServer) getDatabase(ctx context.Context) (Database, error) {
    90  	g.RLock()
    91  	impl, err := g.getDatabaseInternal(ctx)
    92  	g.RUnlock()
    93  	return impl, err
    94  }
    95  
    96  // Initialize the database plugin
    97  func (g *gRPCServer) Initialize(ctx context.Context, request *proto.InitializeRequest) (*proto.InitializeResponse, error) {
    98  	impl, err := g.getOrCreateDatabase(ctx)
    99  	if err != nil {
   100  		return nil, err
   101  	}
   102  
   103  	rawConfig := structToMap(request.ConfigData)
   104  
   105  	dbReq := InitializeRequest{
   106  		Config:           rawConfig,
   107  		VerifyConnection: request.VerifyConnection,
   108  	}
   109  
   110  	dbResp, err := impl.Initialize(ctx, dbReq)
   111  	if err != nil {
   112  		return &proto.InitializeResponse{}, status.Errorf(codes.Internal, "failed to initialize: %s", err)
   113  	}
   114  
   115  	newConfig, err := mapToStruct(dbResp.Config)
   116  	if err != nil {
   117  		return &proto.InitializeResponse{}, status.Errorf(codes.Internal, "failed to marshal new config to JSON: %s", err)
   118  	}
   119  
   120  	resp := &proto.InitializeResponse{
   121  		ConfigData: newConfig,
   122  	}
   123  
   124  	return resp, nil
   125  }
   126  
   127  func (g *gRPCServer) NewUser(ctx context.Context, req *proto.NewUserRequest) (*proto.NewUserResponse, error) {
   128  	if req.GetUsernameConfig() == nil {
   129  		return &proto.NewUserResponse{}, status.Errorf(codes.InvalidArgument, "missing username config")
   130  	}
   131  
   132  	var expiration time.Time
   133  
   134  	if req.GetExpiration() != nil {
   135  		exp, err := ptypes.Timestamp(req.GetExpiration())
   136  		if err != nil {
   137  			return &proto.NewUserResponse{}, status.Errorf(codes.InvalidArgument, "unable to parse expiration date: %s", err)
   138  		}
   139  		expiration = exp
   140  	}
   141  
   142  	impl, err := g.getDatabase(ctx)
   143  	if err != nil {
   144  		return nil, err
   145  	}
   146  
   147  	dbReq := NewUserRequest{
   148  		UsernameConfig: UsernameMetadata{
   149  			DisplayName: req.GetUsernameConfig().GetDisplayName(),
   150  			RoleName:    req.GetUsernameConfig().GetRoleName(),
   151  		},
   152  		CredentialType:     CredentialType(req.GetCredentialType()),
   153  		Password:           req.GetPassword(),
   154  		PublicKey:          req.GetPublicKey(),
   155  		Subject:            req.GetSubject(),
   156  		Expiration:         expiration,
   157  		Statements:         getStatementsFromProto(req.GetStatements()),
   158  		RollbackStatements: getStatementsFromProto(req.GetRollbackStatements()),
   159  	}
   160  
   161  	dbResp, err := impl.NewUser(ctx, dbReq)
   162  	if err != nil {
   163  		return &proto.NewUserResponse{}, status.Errorf(codes.Internal, "unable to create new user: %s", err)
   164  	}
   165  
   166  	resp := &proto.NewUserResponse{
   167  		Username: dbResp.Username,
   168  	}
   169  	return resp, nil
   170  }
   171  
   172  func (g *gRPCServer) UpdateUser(ctx context.Context, req *proto.UpdateUserRequest) (*proto.UpdateUserResponse, error) {
   173  	if req.GetUsername() == "" {
   174  		return &proto.UpdateUserResponse{}, status.Errorf(codes.InvalidArgument, "no username provided")
   175  	}
   176  
   177  	dbReq, err := getUpdateUserRequest(req)
   178  	if err != nil {
   179  		return &proto.UpdateUserResponse{}, status.Errorf(codes.InvalidArgument, err.Error())
   180  	}
   181  
   182  	impl, err := g.getDatabase(ctx)
   183  	if err != nil {
   184  		return nil, err
   185  	}
   186  
   187  	_, err = impl.UpdateUser(ctx, dbReq)
   188  	if err != nil {
   189  		return &proto.UpdateUserResponse{}, status.Errorf(codes.Internal, "unable to update user: %s", err)
   190  	}
   191  	return &proto.UpdateUserResponse{}, nil
   192  }
   193  
   194  func getUpdateUserRequest(req *proto.UpdateUserRequest) (UpdateUserRequest, error) {
   195  	var password *ChangePassword
   196  	if req.GetPassword() != nil && req.GetPassword().GetNewPassword() != "" {
   197  		password = &ChangePassword{
   198  			NewPassword: req.GetPassword().GetNewPassword(),
   199  			Statements:  getStatementsFromProto(req.GetPassword().GetStatements()),
   200  		}
   201  	}
   202  
   203  	var publicKey *ChangePublicKey
   204  	if req.GetPublicKey() != nil && len(req.GetPublicKey().GetNewPublicKey()) > 0 {
   205  		publicKey = &ChangePublicKey{
   206  			NewPublicKey: req.GetPublicKey().GetNewPublicKey(),
   207  			Statements:   getStatementsFromProto(req.GetPublicKey().GetStatements()),
   208  		}
   209  	}
   210  
   211  	var expiration *ChangeExpiration
   212  	if req.GetExpiration() != nil && req.GetExpiration().GetNewExpiration() != nil {
   213  		newExpiration, err := ptypes.Timestamp(req.GetExpiration().GetNewExpiration())
   214  		if err != nil {
   215  			return UpdateUserRequest{}, fmt.Errorf("unable to parse new expiration: %w", err)
   216  		}
   217  
   218  		expiration = &ChangeExpiration{
   219  			NewExpiration: newExpiration,
   220  			Statements:    getStatementsFromProto(req.GetExpiration().GetStatements()),
   221  		}
   222  	}
   223  
   224  	dbReq := UpdateUserRequest{
   225  		Username:       req.GetUsername(),
   226  		CredentialType: CredentialType(req.GetCredentialType()),
   227  		Password:       password,
   228  		PublicKey:      publicKey,
   229  		Expiration:     expiration,
   230  	}
   231  
   232  	if !hasChange(dbReq) {
   233  		return UpdateUserRequest{}, fmt.Errorf("update user request has no changes")
   234  	}
   235  
   236  	return dbReq, nil
   237  }
   238  
   239  func hasChange(dbReq UpdateUserRequest) bool {
   240  	if dbReq.Password != nil && dbReq.Password.NewPassword != "" {
   241  		return true
   242  	}
   243  	if dbReq.PublicKey != nil && len(dbReq.PublicKey.NewPublicKey) > 0 {
   244  		return true
   245  	}
   246  	if dbReq.Expiration != nil && !dbReq.Expiration.NewExpiration.IsZero() {
   247  		return true
   248  	}
   249  	return false
   250  }
   251  
   252  func (g *gRPCServer) DeleteUser(ctx context.Context, req *proto.DeleteUserRequest) (*proto.DeleteUserResponse, error) {
   253  	if req.GetUsername() == "" {
   254  		return &proto.DeleteUserResponse{}, status.Errorf(codes.InvalidArgument, "no username provided")
   255  	}
   256  	dbReq := DeleteUserRequest{
   257  		Username:   req.GetUsername(),
   258  		Statements: getStatementsFromProto(req.GetStatements()),
   259  	}
   260  
   261  	impl, err := g.getDatabase(ctx)
   262  	if err != nil {
   263  		return nil, err
   264  	}
   265  
   266  	_, err = impl.DeleteUser(ctx, dbReq)
   267  	if err != nil {
   268  		return &proto.DeleteUserResponse{}, status.Errorf(codes.Internal, "unable to delete user: %s", err)
   269  	}
   270  	return &proto.DeleteUserResponse{}, nil
   271  }
   272  
   273  func (g *gRPCServer) Type(ctx context.Context, _ *proto.Empty) (*proto.TypeResponse, error) {
   274  	impl, err := g.getOrCreateDatabase(ctx)
   275  	if err != nil {
   276  		return nil, err
   277  	}
   278  
   279  	t, err := impl.Type()
   280  	if err != nil {
   281  		return &proto.TypeResponse{}, status.Errorf(codes.Internal, "unable to retrieve type: %s", err)
   282  	}
   283  
   284  	resp := &proto.TypeResponse{
   285  		Type: t,
   286  	}
   287  	return resp, nil
   288  }
   289  
   290  func (g *gRPCServer) Close(ctx context.Context, _ *proto.Empty) (*proto.Empty, error) {
   291  	g.Lock()
   292  	defer g.Unlock()
   293  
   294  	impl, err := g.getDatabaseInternal(ctx)
   295  	if err != nil {
   296  		return nil, err
   297  	}
   298  
   299  	err = impl.Close()
   300  	if err != nil {
   301  		return &proto.Empty{}, status.Errorf(codes.Internal, "unable to close database plugin: %s", err)
   302  	}
   303  
   304  	if g.singleImpl == nil {
   305  		// only cleanup instances map when multiplexing is supported
   306  		id, err := pluginutil.GetMultiplexIDFromContext(ctx)
   307  		if err != nil {
   308  			return nil, err
   309  		}
   310  		delete(g.instances, id)
   311  	}
   312  
   313  	return &proto.Empty{}, nil
   314  }
   315  
   316  // getOrForceCreateDatabase will create a database even if the multiplexing ID is not present
   317  func (g *gRPCServer) getOrForceCreateDatabase(ctx context.Context) (Database, error) {
   318  	impl, err := g.getOrCreateDatabase(ctx)
   319  	if errors.Is(err, pluginutil.ErrNoMultiplexingIDFound) {
   320  		// if this is called without a multiplexing context, like from the plugin catalog directly,
   321  		// then we won't have a database ID, so let's generate a new database instance
   322  		id, err := base62.Random(10)
   323  		if err != nil {
   324  			return nil, err
   325  		}
   326  
   327  		g.Lock()
   328  		defer g.Unlock()
   329  		impl, err = g.createDatabase(id)
   330  		if err != nil {
   331  			return nil, err
   332  		}
   333  	} else if err != nil {
   334  		return nil, err
   335  	}
   336  	return impl, nil
   337  }
   338  
   339  // Version forwards the version request to the underlying Database implementation.
   340  func (g *gRPCServer) Version(ctx context.Context, _ *logical.Empty) (*logical.VersionReply, error) {
   341  	impl, err := g.getOrForceCreateDatabase(ctx)
   342  	if err != nil {
   343  		return nil, err
   344  	}
   345  
   346  	if versioner, ok := impl.(logical.PluginVersioner); ok {
   347  		return &logical.VersionReply{PluginVersion: versioner.PluginVersion().Version}, nil
   348  	}
   349  	return &logical.VersionReply{}, nil
   350  }
   351  
   352  func getStatementsFromProto(protoStmts *proto.Statements) (statements Statements) {
   353  	if protoStmts == nil {
   354  		return statements
   355  	}
   356  	cmds := protoStmts.GetCommands()
   357  	statements = Statements{
   358  		Commands: cmds,
   359  	}
   360  	return statements
   361  }