github.com/viant/toolbox@v0.34.5/kms/gcp/service.go (about)

     1  package gcp
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/base64"
     7  	"fmt"
     8  	"github.com/pkg/errors"
     9  	"github.com/viant/toolbox"
    10  	"github.com/viant/toolbox/kms"
    11  	"github.com/viant/toolbox/storage"
    12  	"github.com/viant/toolbox/url"
    13  	"google.golang.org/api/cloudkms/v1"
    14  	"google.golang.org/api/option"
    15  	"io/ioutil"
    16  )
    17  
    18  type KmsService interface {
    19  	Encrypt(ctx context.Context, key string, value string) (string, error)
    20  	Decrypt(ctx context.Context, key string, value string) (string, error)
    21  }
    22  
    23  func (k *kmsService) Encrypt(ctx context.Context, key string, plainText string) (string, error) {
    24  	kms, err := cloudkms.NewService(ctx, option.WithScopes(cloudkms.CloudPlatformScope, cloudkms.CloudkmsScope))
    25  	if err != nil {
    26  		return "", errors.Wrap(err, fmt.Sprintf("failed to create kms server for key %v", key))
    27  	}
    28  	service := cloudkms.NewProjectsLocationsKeyRingsCryptoKeysService(kms)
    29  
    30  	response, err := service.Encrypt(key, &cloudkms.EncryptRequest{Plaintext: plainText}).Context(ctx).Do()
    31  	if err != nil {
    32  		return "", errors.Wrap(err, fmt.Sprintf("failed to encrypt with key %v", key))
    33  	}
    34  	return response.Ciphertext, nil
    35  }
    36  
    37  func (k *kmsService) Decrypt(ctx context.Context, key string, plainText string) (string, error) {
    38  	kms, err := cloudkms.NewService(ctx, option.WithScopes(cloudkms.CloudPlatformScope, cloudkms.CloudkmsScope))
    39  	if err != nil {
    40  		return "", errors.Wrap(err, fmt.Sprintf("failed to create kms server for key %v", key))
    41  	}
    42  	service := cloudkms.NewProjectsLocationsKeyRingsCryptoKeysService(kms)
    43  	response, err := service.Decrypt(key, &cloudkms.DecryptRequest{Ciphertext: plainText}).Context(ctx).Do()
    44  	if err != nil {
    45  		return "", errors.Wrap(err, fmt.Sprintf("failed to encrypt with key %v", key))
    46  	}
    47  	return response.Plaintext, nil
    48  }
    49  
    50  type kmsService struct{}
    51  
    52  type service struct {
    53  	KmsService
    54  }
    55  
    56  //New returns service
    57  func New() kms.Service {
    58  	return newService()
    59  }
    60  
    61  func newService() kms.Service {
    62  	return &service{KmsService: &kmsService{}}
    63  }
    64  
    65  func (s *service) Decode(ctx context.Context, decryptRequest *kms.DecryptRequest, factory toolbox.DecoderFactory, target interface{}) error {
    66  	response, err := s.Decrypt(ctx, decryptRequest)
    67  	if err != nil {
    68  		return err
    69  	}
    70  	reader := bytes.NewReader(response.Data)
    71  	return factory.Create(reader).Decode(target)
    72  }
    73  
    74  func (s *service) Encrypt(ctx context.Context, request *kms.EncryptRequest) (*kms.EncryptResponse, error) {
    75  
    76  	if request.URL != "" {
    77  		data, err := getDataFromURL(request.URL)
    78  		if err != nil {
    79  			return nil, err
    80  		}
    81  		if data == nil || len(data) == 0 {
    82  			return nil, fmt.Errorf("data empty in the encrypt")
    83  		}
    84  		request.Data = data
    85  
    86  	}
    87  	plainText := getBase64(request.Data)
    88  	encryptedText, err := s.KmsService.Encrypt(ctx, request.Key, plainText)
    89  	if err != nil {
    90  		return nil, err
    91  	}
    92  	if encryptedText == "" {
    93  		return nil, fmt.Errorf("encryptedText was empty")
    94  	}
    95  
    96  	if request.TargetURL != "" {
    97  		err = upload(request.TargetURL, encryptedText)
    98  		if err != nil {
    99  			return nil, err
   100  		}
   101  	}
   102  	encryptedData, err := base64.StdEncoding.DecodeString(encryptedText)
   103  	if err != nil {
   104  		return nil, err
   105  	}
   106  	return &kms.EncryptResponse{
   107  		EncryptedData: encryptedData,
   108  		EncryptedText: encryptedText,
   109  	}, nil
   110  }
   111  
   112  func (s *service) Decrypt(ctx context.Context, request *kms.DecryptRequest) (*kms.DecryptResponse, error) {
   113  	if request.URL != "" {
   114  		resource := url.NewResource(request.URL)
   115  		base64Text, err := resource.DownloadBase64()
   116  		if err != nil {
   117  			return nil, err
   118  		}
   119  		request.Data = []byte(base64Text)
   120  	} else if len(request.Data) > 0 {
   121  		base64Text := getBase64(request.Data)
   122  		request.Data = []byte(base64Text)
   123  	}
   124  	plainText := string(request.Data)
   125  	text, err := s.KmsService.Decrypt(ctx, request.Key, plainText)
   126  	if err != nil {
   127  		return nil, err
   128  	}
   129  	if text == "" {
   130  		return nil, fmt.Errorf("no text in the decrypt")
   131  	}
   132  
   133  	data, err := base64.StdEncoding.DecodeString(text)
   134  	if err != nil {
   135  		return nil, errors.Wrap(err, fmt.Sprintf("failed to base64 decode text %v", text))
   136  	}
   137  	decryptResponse := &kms.DecryptResponse{
   138  		Data: data,
   139  		Text: text,
   140  	}
   141  	return decryptResponse, nil
   142  }
   143  
   144  func getBase64(data []byte) string {
   145  	plainText := string(data)
   146  	isBase64 := false
   147  	if _, err := base64.StdEncoding.DecodeString(string(data)); err == nil {
   148  		isBase64 = true
   149  	}
   150  
   151  	if !isBase64 {
   152  		plainText = base64.StdEncoding.EncodeToString(data)
   153  	}
   154  	return plainText
   155  }
   156  
   157  func upload(targetURL string, encryptedText string) error {
   158  	storageService, err := storage.NewServiceForURL(targetURL, "")
   159  	if err != nil {
   160  		return errors.Wrap(err, fmt.Sprintf("failed to get storage for url %v", targetURL))
   161  	}
   162  	return storageService.Upload(targetURL, bytes.NewReader([]byte(encryptedText)))
   163  }
   164  
   165  func getDataFromURL(URL string) ([]byte, error) {
   166  	storageService, err := storage.NewServiceForURL(URL, "")
   167  	if err != nil {
   168  		return nil, errors.Wrap(err, fmt.Sprintf("failed to create storage for url: %v", URL))
   169  	}
   170  	reader, err := storageService.DownloadWithURL(URL)
   171  	if err != nil {
   172  		return nil, errors.Wrap(err, fmt.Sprintf("failed to download url: %v", URL))
   173  	}
   174  	data, err := ioutil.ReadAll(reader)
   175  	if err != nil {
   176  		return nil, errors.Wrap(err, fmt.Sprintf("failed to read data from %v", URL))
   177  	}
   178  	return data, nil
   179  
   180  }