github.com/viant/toolbox@v0.34.5/kms/aws/service.go (about)

     1  package aws
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"github.com/aws/aws-sdk-go/aws"
     8  	"github.com/aws/aws-sdk-go/aws/session"
     9  	akms "github.com/aws/aws-sdk-go/service/kms"
    10  	"github.com/aws/aws-sdk-go/service/ssm"
    11  	"github.com/pkg/errors"
    12  	"strings"
    13  
    14  	"github.com/viant/toolbox"
    15  	"github.com/viant/toolbox/kms"
    16  )
    17  
// service implements kms.Service on top of AWS Systems Manager Parameter
// Store (SSM) and AWS KMS. Both SDK clients are embedded, so their API
// methods (e.g. PutParameter, GetParameter, ListAliases) are promoted onto
// service and called directly as s.PutParameter, s.ListAliases, etc.
type service struct {
	*ssm.SSM
	*akms.KMS
}
    22  
    23  func (s *service) Encrypt(ctx context.Context, request *kms.EncryptRequest) (*kms.EncryptResponse, error) {
    24  	err := request.Validate()
    25  	if err != nil {
    26  		return nil, errors.Wrap(err, "invalid encrypt request")
    27  	}
    28  	if request.Parameter == "" {
    29  		return nil, errors.New("parameter was empty")
    30  	}
    31  	response := &kms.EncryptResponse{}
    32  	err = s.putParameters(request.Key, request.Parameter, string(request.Data))
    33  	if err == nil {
    34  		parameter, err := s.getParameters(request.Parameter, false)
    35  		if err != nil {
    36  			return nil, err
    37  		}
    38  		response.EncryptedText = *parameter.Value
    39  		response.EncryptedData = []byte(response.EncryptedText)
    40  	}
    41  	return response, err
    42  }
    43  
    44  func (s *service) Decrypt(ctx context.Context, request *kms.DecryptRequest) (*kms.DecryptResponse, error) {
    45  	err := request.Validate()
    46  	if err != nil {
    47  		return nil, errors.Wrap(err, "invalid encrypt request")
    48  	}
    49  	if request.Parameter == "" {
    50  		return nil, errors.New("parameter was empty")
    51  	}
    52  	response := &kms.DecryptResponse{}
    53  	parameter, err := s.getParameters(request.Parameter, true)
    54  	if err != nil {
    55  		return nil, err
    56  	}
    57  	response.Text = *parameter.Value
    58  	response.Data = []byte(response.Text)
    59  	return response, nil
    60  }
    61  
    62  func (s *service) Decode(ctx context.Context, decryptRequest *kms.DecryptRequest, factory toolbox.DecoderFactory, target interface{}) error {
    63  	response, err := s.Decrypt(ctx, decryptRequest)
    64  	if err != nil {
    65  		return err
    66  	}
    67  	reader := bytes.NewReader(response.Data)
    68  	return factory.Create(reader).Decode(target)
    69  }
    70  
    71  func (s *service) putParameters(keyOrAlias, name, value string) error {
    72  	targetKeyID, err := s.getKeyByAlias(keyOrAlias)
    73  	if err != nil {
    74  		return err
    75  	}
    76  	_, err = s.PutParameter(&ssm.PutParameterInput{
    77  		Name:  aws.String(name),
    78  		KeyId: &targetKeyID,
    79  		Value: &value,
    80  	})
    81  	return err
    82  }
    83  
    84  func (s *service) getKeyByAlias(keyOrAlias string) (string, error) {
    85  	if strings.Count(keyOrAlias, ":") > 0 {
    86  		return keyOrAlias, nil
    87  	}
    88  	var nextMarker *string
    89  	for {
    90  		output, err := s.ListAliases(&akms.ListAliasesInput{
    91  			Marker: nextMarker,
    92  		})
    93  		if err != nil {
    94  			return "", err
    95  		}
    96  		if len(output.Aliases) == 0 {
    97  			break
    98  		}
    99  		for _, candidate := range output.Aliases {
   100  			if *candidate.AliasName == keyOrAlias {
   101  				return *candidate.TargetKeyId, nil
   102  			}
   103  		}
   104  		nextMarker = output.NextMarker
   105  		if nextMarker == nil {
   106  			break
   107  		}
   108  	}
   109  	return "", fmt.Errorf("key for alias %v no found", keyOrAlias)
   110  }
   111  
   112  func (s *service) getParameters(name string, withDecryption bool) (*ssm.Parameter, error) {
   113  	output, err := s.GetParameter(&ssm.GetParameterInput{
   114  		Name:           aws.String(name),
   115  		WithDecryption: &withDecryption,
   116  	})
   117  	if err != nil {
   118  		return nil, err
   119  	}
   120  	return output.Parameter, nil
   121  }
   122  
   123  //New returns new kms service
   124  func New() (kms.Service, error) {
   125  	sess, err := session.NewSession()
   126  	if err != nil {
   127  		return nil, err
   128  	}
   129  	return &service{
   130  		SSM: ssm.New(sess),
   131  		KMS: akms.New(sess),
   132  	}, nil
   133  }