github.com/hashicorp/vault/sdk@v0.13.0/database/dbplugin/v5/grpc_client.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package dbplugin
     5  
     6  import (
     7  	"context"
     8  	"errors"
     9  	"fmt"
    10  	"time"
    11  
    12  	"github.com/golang/protobuf/ptypes"
    13  	"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
    14  	"github.com/hashicorp/vault/sdk/helper/pluginutil"
    15  	"github.com/hashicorp/vault/sdk/logical"
    16  )
    17  
    18  var (
    19  	_ Database                = gRPCClient{}
    20  	_ logical.PluginVersioner = gRPCClient{}
    21  
    22  	ErrPluginShutdown = errors.New("plugin shutdown")
    23  )
    24  
// gRPCClient is the client side of the database plugin gRPC transport: each
// Database method is translated into an RPC on client.
type gRPCClient struct {
	// client issues the Database RPCs (Initialize, NewUser, UpdateUser,
	// DeleteUser, Type, Close).
	client proto.DatabaseClient

	// versionClient is queried by PluginVersion.
	versionClient logical.PluginVersionClient

	// doneCtx is checked after a failed RPC; if it has been canceled the
	// method returns ErrPluginShutdown. NewUser also uses it to cancel its
	// in-flight request.
	doneCtx context.Context
}
    30  
    31  func (c gRPCClient) PluginVersion() logical.PluginVersion {
    32  	version, _ := c.versionClient.Version(context.Background(), &logical.Empty{})
    33  	if version != nil {
    34  		return logical.PluginVersion{Version: version.PluginVersion}
    35  	}
    36  	return logical.EmptyPluginVersion
    37  }
    38  
    39  func (c gRPCClient) Initialize(ctx context.Context, req InitializeRequest) (InitializeResponse, error) {
    40  	rpcReq, err := initReqToProto(req)
    41  	if err != nil {
    42  		return InitializeResponse{}, err
    43  	}
    44  
    45  	rpcResp, err := c.client.Initialize(ctx, rpcReq)
    46  	if err != nil {
    47  		return InitializeResponse{}, fmt.Errorf("unable to initialize: %s", err.Error())
    48  	}
    49  
    50  	return initRespFromProto(rpcResp)
    51  }
    52  
    53  func initReqToProto(req InitializeRequest) (*proto.InitializeRequest, error) {
    54  	config, err := mapToStruct(req.Config)
    55  	if err != nil {
    56  		return nil, fmt.Errorf("unable to marshal config: %w", err)
    57  	}
    58  
    59  	rpcReq := &proto.InitializeRequest{
    60  		ConfigData:       config,
    61  		VerifyConnection: req.VerifyConnection,
    62  	}
    63  	return rpcReq, nil
    64  }
    65  
    66  func initRespFromProto(rpcResp *proto.InitializeResponse) (InitializeResponse, error) {
    67  	newConfig := structToMap(rpcResp.GetConfigData())
    68  
    69  	resp := InitializeResponse{
    70  		Config: newConfig,
    71  	}
    72  	return resp, nil
    73  }
    74  
    75  func (c gRPCClient) NewUser(ctx context.Context, req NewUserRequest) (NewUserResponse, error) {
    76  	ctx, cancel := context.WithCancel(ctx)
    77  	quitCh := pluginutil.CtxCancelIfCanceled(cancel, c.doneCtx)
    78  	defer close(quitCh)
    79  	defer cancel()
    80  
    81  	rpcReq, err := newUserReqToProto(req)
    82  	if err != nil {
    83  		return NewUserResponse{}, err
    84  	}
    85  
    86  	rpcResp, err := c.client.NewUser(ctx, rpcReq)
    87  	if err != nil {
    88  		if c.doneCtx.Err() != nil {
    89  			return NewUserResponse{}, ErrPluginShutdown
    90  		}
    91  		return NewUserResponse{}, fmt.Errorf("unable to create new user: %w", err)
    92  	}
    93  
    94  	return newUserRespFromProto(rpcResp)
    95  }
    96  
    97  func newUserReqToProto(req NewUserRequest) (*proto.NewUserRequest, error) {
    98  	switch req.CredentialType {
    99  	case CredentialTypePassword:
   100  		if req.Password == "" {
   101  			return nil, fmt.Errorf("missing password credential")
   102  		}
   103  	case CredentialTypeRSAPrivateKey:
   104  		if len(req.PublicKey) == 0 {
   105  			return nil, fmt.Errorf("missing public key credential")
   106  		}
   107  	case CredentialTypeClientCertificate:
   108  		if req.Subject == "" {
   109  			return nil, fmt.Errorf("missing certificate subject")
   110  		}
   111  	default:
   112  		return nil, fmt.Errorf("unknown credential type")
   113  	}
   114  
   115  	expiration, err := ptypes.TimestampProto(req.Expiration)
   116  	if err != nil {
   117  		return nil, fmt.Errorf("unable to marshal expiration date: %w", err)
   118  	}
   119  
   120  	rpcReq := &proto.NewUserRequest{
   121  		UsernameConfig: &proto.UsernameConfig{
   122  			DisplayName: req.UsernameConfig.DisplayName,
   123  			RoleName:    req.UsernameConfig.RoleName,
   124  		},
   125  		CredentialType: int32(req.CredentialType),
   126  		Password:       req.Password,
   127  		PublicKey:      req.PublicKey,
   128  		Subject:        req.Subject,
   129  		Expiration:     expiration,
   130  		Statements: &proto.Statements{
   131  			Commands: req.Statements.Commands,
   132  		},
   133  		RollbackStatements: &proto.Statements{
   134  			Commands: req.RollbackStatements.Commands,
   135  		},
   136  	}
   137  	return rpcReq, nil
   138  }
   139  
   140  func newUserRespFromProto(rpcResp *proto.NewUserResponse) (NewUserResponse, error) {
   141  	resp := NewUserResponse{
   142  		Username: rpcResp.GetUsername(),
   143  	}
   144  	return resp, nil
   145  }
   146  
   147  func (c gRPCClient) UpdateUser(ctx context.Context, req UpdateUserRequest) (UpdateUserResponse, error) {
   148  	rpcReq, err := updateUserReqToProto(req)
   149  	if err != nil {
   150  		return UpdateUserResponse{}, err
   151  	}
   152  
   153  	rpcResp, err := c.client.UpdateUser(ctx, rpcReq)
   154  	if err != nil {
   155  		if c.doneCtx.Err() != nil {
   156  			return UpdateUserResponse{}, ErrPluginShutdown
   157  		}
   158  
   159  		return UpdateUserResponse{}, fmt.Errorf("unable to update user: %w", err)
   160  	}
   161  
   162  	return updateUserRespFromProto(rpcResp)
   163  }
   164  
   165  func updateUserReqToProto(req UpdateUserRequest) (*proto.UpdateUserRequest, error) {
   166  	if req.Username == "" {
   167  		return nil, fmt.Errorf("missing username")
   168  	}
   169  
   170  	if (req.Password == nil || req.Password.NewPassword == "") &&
   171  		(req.PublicKey == nil || len(req.PublicKey.NewPublicKey) == 0) &&
   172  		(req.Expiration == nil || req.Expiration.NewExpiration.IsZero()) {
   173  		return nil, fmt.Errorf("missing changes")
   174  	}
   175  
   176  	expiration, err := expirationToProto(req.Expiration)
   177  	if err != nil {
   178  		return nil, fmt.Errorf("unable to parse new expiration date: %w", err)
   179  	}
   180  
   181  	var password *proto.ChangePassword
   182  	if req.Password != nil && req.Password.NewPassword != "" {
   183  		password = &proto.ChangePassword{
   184  			NewPassword: req.Password.NewPassword,
   185  			Statements: &proto.Statements{
   186  				Commands: req.Password.Statements.Commands,
   187  			},
   188  		}
   189  	}
   190  
   191  	var publicKey *proto.ChangePublicKey
   192  	if req.PublicKey != nil && len(req.PublicKey.NewPublicKey) > 0 {
   193  		publicKey = &proto.ChangePublicKey{
   194  			NewPublicKey: req.PublicKey.NewPublicKey,
   195  			Statements: &proto.Statements{
   196  				Commands: req.PublicKey.Statements.Commands,
   197  			},
   198  		}
   199  	}
   200  
   201  	rpcReq := &proto.UpdateUserRequest{
   202  		Username:       req.Username,
   203  		CredentialType: int32(req.CredentialType),
   204  		Password:       password,
   205  		PublicKey:      publicKey,
   206  		Expiration:     expiration,
   207  	}
   208  	return rpcReq, nil
   209  }
   210  
   211  func updateUserRespFromProto(rpcResp *proto.UpdateUserResponse) (UpdateUserResponse, error) {
   212  	// Placeholder for future conversion if data is returned
   213  	return UpdateUserResponse{}, nil
   214  }
   215  
   216  func expirationToProto(exp *ChangeExpiration) (*proto.ChangeExpiration, error) {
   217  	if exp == nil {
   218  		return nil, nil
   219  	}
   220  
   221  	expiration, err := ptypes.TimestampProto(exp.NewExpiration)
   222  	if err != nil {
   223  		return nil, err
   224  	}
   225  
   226  	changeExp := &proto.ChangeExpiration{
   227  		NewExpiration: expiration,
   228  		Statements: &proto.Statements{
   229  			Commands: exp.Statements.Commands,
   230  		},
   231  	}
   232  	return changeExp, nil
   233  }
   234  
   235  func (c gRPCClient) DeleteUser(ctx context.Context, req DeleteUserRequest) (DeleteUserResponse, error) {
   236  	rpcReq, err := deleteUserReqToProto(req)
   237  	if err != nil {
   238  		return DeleteUserResponse{}, err
   239  	}
   240  
   241  	rpcResp, err := c.client.DeleteUser(ctx, rpcReq)
   242  	if err != nil {
   243  		if c.doneCtx.Err() != nil {
   244  			return DeleteUserResponse{}, ErrPluginShutdown
   245  		}
   246  		return DeleteUserResponse{}, fmt.Errorf("unable to delete user: %w", err)
   247  	}
   248  
   249  	return deleteUserRespFromProto(rpcResp)
   250  }
   251  
   252  func deleteUserReqToProto(req DeleteUserRequest) (*proto.DeleteUserRequest, error) {
   253  	if req.Username == "" {
   254  		return nil, fmt.Errorf("missing username")
   255  	}
   256  
   257  	rpcReq := &proto.DeleteUserRequest{
   258  		Username: req.Username,
   259  		Statements: &proto.Statements{
   260  			Commands: req.Statements.Commands,
   261  		},
   262  	}
   263  	return rpcReq, nil
   264  }
   265  
   266  func deleteUserRespFromProto(rpcResp *proto.DeleteUserResponse) (DeleteUserResponse, error) {
   267  	// Placeholder for future conversion if data is returned
   268  	return DeleteUserResponse{}, nil
   269  }
   270  
   271  func (c gRPCClient) Type() (string, error) {
   272  	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
   273  	defer cancel()
   274  
   275  	typeResp, err := c.client.Type(ctx, &proto.Empty{})
   276  	if err != nil {
   277  		if c.doneCtx.Err() != nil {
   278  			return "", ErrPluginShutdown
   279  		}
   280  		return "", fmt.Errorf("unable to get database plugin type: %w", err)
   281  	}
   282  	return typeResp.GetType(), nil
   283  }
   284  
   285  func (c gRPCClient) Close() error {
   286  	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
   287  	defer cancel()
   288  
   289  	_, err := c.client.Close(ctx, &proto.Empty{})
   290  	if err != nil {
   291  		if c.doneCtx.Err() != nil {
   292  			return ErrPluginShutdown
   293  		}
   294  		return err
   295  	}
   296  	return nil
   297  }