github.com/tommi2day/gomodules@v1.13.2-0.20240423190010-b7d55d252a27/pwlib/kms.go (about)

     1  package pwlib
     2  
     3  import (
     4  	"context"
     5  	"crypto/rand"
     6  	"encoding/base64"
     7  	"errors"
     8  	"fmt"
     9  	"os"
    10  	"strings"
    11  
    12  	"github.com/tommi2day/gomodules/common"
    13  
    14  	"github.com/Luzifer/go-openssl/v4"
    15  	"github.com/aws/aws-sdk-go-v2/aws"
    16  	"github.com/aws/aws-sdk-go-v2/config"
    17  	"github.com/aws/aws-sdk-go-v2/service/kms"
    18  	"github.com/aws/aws-sdk-go-v2/service/kms/types"
    19  	"github.com/aws/smithy-go"
    20  	log "github.com/sirupsen/logrus"
    21  )
    22  
    23  // KmsEndpoint is the alternative endpoint for the KMS service
    24  var KmsEndpoint = ""
    25  
    26  const aliasPrefix = "alias/"
    27  
    28  // ConnectToKMS Establish a connection to AWS KMS
    29  func ConnectToKMS() (svc *kms.Client) {
    30  	log.Debugf("Connect to KMS")
    31  	cfg, err := config.LoadDefaultConfig(context.TODO())
    32  	if err != nil {
    33  		log.Fatal(err)
    34  	}
    35  	ep := common.GetStringEnv("KMS_ENDPOINT", "")
    36  	if ep != "" {
    37  		KmsEndpoint = ep
    38  	}
    39  	svc = kms.NewFromConfig(cfg, func(o *kms.Options) {
    40  		if KmsEndpoint != "" {
    41  			log.Debugf("use KMS Endpoint %s", KmsEndpoint)
    42  			o.BaseEndpoint = aws.String(KmsEndpoint)
    43  		}
    44  	})
    45  	return svc
    46  }
    47  
    48  func checkOperationError(err error) error {
    49  	var oe *smithy.OperationError
    50  	e := err
    51  	if errors.As(err, &oe) {
    52  		e = fmt.Errorf("failed to call service: %s, operation: %s, error: %v", oe.Service(), oe.Operation(), oe.Unwrap())
    53  	}
    54  	log.Debugf("OperationError:%v", e)
    55  	return e
    56  }
    57  
    58  // ListKMSKeys List all KMS keys
    59  func ListKMSKeys(svc *kms.Client) ([]types.KeyListEntry, error) {
    60  	log.Debugf("List KMS Keys")
    61  	if svc == nil {
    62  		return nil, errors.New("KMS service is nil")
    63  	}
    64  	output, err := svc.ListKeys(context.TODO(), &kms.ListKeysInput{})
    65  	if err != nil {
    66  		e := checkOperationError(err)
    67  		return nil, e
    68  	}
    69  	if output == nil {
    70  		return nil, errors.New("listKeys returned nil")
    71  	}
    72  	keys := output.Keys
    73  	log.Debugf("ListKeys returned %d entries", len(keys))
    74  	return keys, nil
    75  }
    76  
    77  // GetKMSKeyIDs Get KMS IDs from key meta data
    78  func GetKMSKeyIDs(metadata *types.KeyMetadata) (keyID string, keyARN string) {
    79  	if metadata == nil {
    80  		return
    81  	}
    82  	keyID = *metadata.KeyId
    83  	keyARN = *metadata.Arn
    84  	return
    85  }
    86  
    87  // DescribeKMSKey Describe a KMS key
    88  func DescribeKMSKey(svc *kms.Client, keyID string) (*kms.DescribeKeyOutput, error) {
    89  	if svc == nil {
    90  		return nil, errors.New("KMS service is nil")
    91  	}
    92  	if keyID == "" {
    93  		return nil, errors.New("keyID is empty")
    94  	}
    95  	log.Debugf("Describe KMS Key:%s", keyID)
    96  	output, err := svc.DescribeKey(context.TODO(), &kms.DescribeKeyInput{
    97  		KeyId: aws.String(keyID),
    98  	})
    99  	if err != nil {
   100  		e := checkOperationError(err)
   101  		return nil, e
   102  	}
   103  	log.Debugf("Describe Key was OK")
   104  	return output, nil
   105  }
   106  
   107  // ListKMSAliases List all KMS aliases
   108  func ListKMSAliases(svc *kms.Client, keyID string) ([]types.AliasListEntry, error) {
   109  	if svc == nil {
   110  		return nil, errors.New("KMS service is nil")
   111  	}
   112  	ip := &kms.ListAliasesInput{}
   113  	if keyID != "" {
   114  		ip = &kms.ListAliasesInput{
   115  			KeyId: aws.String(keyID),
   116  		}
   117  	}
   118  	log.Debugf("List KMS Key Aliases for %s", keyID)
   119  	output, err := svc.ListAliases(context.TODO(), ip)
   120  	if err != nil {
   121  		e := checkOperationError(err)
   122  		return nil, e
   123  	}
   124  	if output == nil {
   125  		return nil, errors.New("listAliases returned nil")
   126  	}
   127  	aliases := output.Aliases
   128  	log.Debugf("ListAliases returned %d entries", len(aliases))
   129  	return aliases, nil
   130  }
   131  
   132  // CreateKMSAlias Create a KMS alias
   133  func CreateKMSAlias(svc *kms.Client, aliasName string, targetKeyID string) (*kms.CreateAliasOutput, error) {
   134  	if svc == nil {
   135  		return nil, errors.New("KMS service is nil")
   136  	}
   137  	if targetKeyID == "" {
   138  		return nil, errors.New("targetKeyID is empty")
   139  	}
   140  	if aliasName == "" {
   141  		return nil, errors.New("aliasName is empty")
   142  	}
   143  	log.Debugf("Create KMS Alias:%s for Key:%s", aliasName, targetKeyID)
   144  	if !strings.HasPrefix(aliasName, aliasPrefix) {
   145  		aliasName = aliasPrefix + aliasName
   146  	}
   147  	output, err := svc.CreateAlias(context.TODO(), &kms.CreateAliasInput{
   148  		AliasName:   aws.String(aliasName),
   149  		TargetKeyId: aws.String(targetKeyID),
   150  	})
   151  	if err != nil {
   152  		e := checkOperationError(err)
   153  		return nil, e
   154  	}
   155  	log.Debugf("Create Alias was OK")
   156  	return output, nil
   157  }
   158  
   159  // DeleteKMSAlias Delete a KMS alias
   160  func DeleteKMSAlias(svc *kms.Client, aliasName string) (*kms.DeleteAliasOutput, error) {
   161  	if svc == nil {
   162  		return nil, errors.New("KMS service is nil")
   163  	}
   164  	if aliasName == "" {
   165  		return nil, errors.New("aliasName is empty")
   166  	}
   167  	if !strings.HasPrefix(aliasName, aliasPrefix) {
   168  		aliasName = aliasPrefix + aliasName
   169  	}
   170  	log.Debugf("Delete KMS Alias:%s", aliasName)
   171  	output, err := svc.DeleteAlias(context.TODO(), &kms.DeleteAliasInput{
   172  		AliasName: aws.String(aliasName),
   173  	})
   174  	if err != nil {
   175  		e := checkOperationError(err)
   176  		return nil, e
   177  	}
   178  	log.Debugf("Delete Alias was OK")
   179  	return output, nil
   180  }
   181  
   182  // GetKMSAliasIDs Get KMS AliasIDs from Alias Entry
   183  func GetKMSAliasIDs(entry *types.AliasListEntry) (targetKeyID string, aliasName string, aliasARN string) {
   184  	log.Debugf("Get KMS IDs from Aliias entry")
   185  	if entry == nil {
   186  		return
   187  	}
   188  	aliasName = *entry.AliasName
   189  	targetKeyID = *entry.TargetKeyId
   190  	aliasARN = *entry.AliasArn
   191  	log.Debugf("Entry %s point to target %s", aliasName, targetKeyID)
   192  	return
   193  }
   194  
   195  // DescribeKMSAlias Search and Describe a KMS alias
   196  func DescribeKMSAlias(svc *kms.Client, aliasName string) (*types.AliasListEntry, error) {
   197  	if svc == nil {
   198  		return nil, errors.New("KMS service is nil")
   199  	}
   200  	if aliasName == "" {
   201  		return nil, errors.New("aliasName is empty")
   202  	}
   203  	if !strings.HasPrefix(aliasName, "alias/") {
   204  		aliasName = aliasPrefix + aliasName
   205  	}
   206  	log.Debugf("Describe KMS Alias:%s", aliasName)
   207  	output, err := svc.ListAliases(context.TODO(), &kms.ListAliasesInput{})
   208  	if err != nil {
   209  		e := checkOperationError(err)
   210  		return nil, e
   211  	}
   212  	if output == nil {
   213  		return nil, errors.New("listAliases returned nil")
   214  	}
   215  	aliases := output.Aliases
   216  	for _, a := range aliases {
   217  		if *a.AliasName == aliasName {
   218  			log.Debugf("Alias %s found", aliasName)
   219  			return &a, nil
   220  		}
   221  	}
   222  	log.Debugf("Alias %s not found", aliasName)
   223  	return nil, fmt.Errorf("alias %s not found", aliasName)
   224  }
   225  
   226  // GenKMSKey Create a new KMS key
   227  func GenKMSKey(svc *kms.Client, keyspec string, description string, tags map[string]string) (*kms.CreateKeyOutput, error) {
   228  	log.Debugf("Create KMS Key")
   229  	if svc == nil {
   230  		return nil, errors.New("KMS service is nil")
   231  	}
   232  	var keytags []types.Tag
   233  	for k, v := range tags {
   234  		kt := types.Tag{
   235  			TagKey:   aws.String(k),
   236  			TagValue: aws.String(v),
   237  		}
   238  		keytags = append(keytags, kt)
   239  	}
   240  	keyOutput, err := svc.CreateKey(
   241  		context.TODO(),
   242  		&kms.CreateKeyInput{
   243  			Description: aws.String(description),
   244  			KeySpec:     types.KeySpec(keyspec),
   245  			Tags:        keytags,
   246  		},
   247  	)
   248  	if err != nil {
   249  		e := checkOperationError(err)
   250  		return nil, e
   251  	}
   252  	log.Debugf("Create Key was OK")
   253  	return keyOutput, nil
   254  }
   255  
   256  // KMSEncryptString Encrypt a string using the KMS key
   257  func KMSEncryptString(svc *kms.Client, keyID string, plaintext string) (string, error) {
   258  	if svc == nil {
   259  		return "", errors.New("KMS service is nil")
   260  	}
   261  	if keyID == "" {
   262  		return "", errors.New("keyID is empty")
   263  	}
   264  	log.Debugf("Encrypt with KMS key:%s", keyID)
   265  	output, err := svc.Encrypt(
   266  		context.TODO(),
   267  		&kms.EncryptInput{
   268  			KeyId:     aws.String(keyID),
   269  			Plaintext: []byte(plaintext),
   270  		},
   271  	)
   272  	if err != nil {
   273  		e := checkOperationError(err)
   274  		return "", e
   275  	}
   276  	log.Debugf("Encrypt was OK")
   277  	return string(output.CiphertextBlob), nil
   278  }
   279  
   280  // KMSDecryptString Decrypt a string using the KMS key
   281  func KMSDecryptString(svc *kms.Client, keyID string, ciphertext string) (string, error) {
   282  	if svc == nil {
   283  		return "", errors.New("KMS service is nil")
   284  	}
   285  	if keyID == "" {
   286  		return "", errors.New("keyID is empty")
   287  	}
   288  	if ciphertext == "" {
   289  		return "", errors.New("ciphertext is empty")
   290  	}
   291  	log.Debugf("Decrypt with KMS key:%s", keyID)
   292  	output, err := svc.Decrypt(
   293  		context.TODO(),
   294  		&kms.DecryptInput{
   295  			KeyId:          aws.String(keyID),
   296  			CiphertextBlob: []byte(ciphertext),
   297  		},
   298  	)
   299  	if err != nil {
   300  		e := checkOperationError(err)
   301  		return "", e
   302  	}
   303  	log.Debugf("Decrypt was OK")
   304  	return string(output.Plaintext), nil
   305  }
   306  
   307  // KMSEncryptFile Encrypt a file using the KMS key
   308  func KMSEncryptFile(plainFile string, targetFile string, keyID string, sessionPassFile string) (err error) {
   309  	const rb = 16
   310  	log.Debugf("Encrypt %s with KMS key %s in OpenSSL compatible format", plainFile, keyID)
   311  	if keyID == "" || plainFile == "" || targetFile == "" {
   312  		err = fmt.Errorf("keyID, plainFile or targetFile is empty")
   313  		log.Debug(err)
   314  		return
   315  	}
   316  	svc := ConnectToKMS()
   317  	if svc == nil {
   318  		err = fmt.Errorf("cannot connect to KMS")
   319  		log.Debug(err)
   320  		return
   321  	}
   322  	random := make([]byte, rb)
   323  	_, err = rand.Read(random)
   324  	if err != nil {
   325  		log.Debugf("Cannot generate session key:%s", err)
   326  		return
   327  	}
   328  	sessionKey := base64.StdEncoding.EncodeToString(random)
   329  	crypted, err := KMSEncryptString(svc, keyID, sessionKey)
   330  	if err != nil {
   331  		log.Errorf("Encrypting Keyfile failed: %v", err)
   332  	}
   333  
   334  	if len(sessionPassFile) > 0 {
   335  		//nolint gosec
   336  		err = os.WriteFile(sessionPassFile, []byte(crypted), 0644)
   337  		if err != nil {
   338  			log.Errorf("Cannot write session Key file %s:%v", sessionPassFile, err)
   339  		}
   340  	}
   341  
   342  	//nolint gosec
   343  	plaindata, err := os.ReadFile(plainFile)
   344  	if err != nil {
   345  		log.Debugf("Cannot read plaintext file %s:%s", plainFile, err)
   346  		return
   347  	}
   348  
   349  	o := openssl.New()
   350  	// openssl enc -e -aes-256-cbc -md sha246 -base64 -in $SOURCE -out $TARGET -pass pass:$PASSPHRASE
   351  	encrypted, err := o.EncryptBytes(sessionKey, plaindata, SSLDigest)
   352  	if err != nil {
   353  		log.Errorf("cannot encrypt plaintext file %s:%s", plainFile, err)
   354  		return
   355  	}
   356  	// write crypted output file
   357  	//nolint gosec
   358  	err = os.WriteFile(targetFile, encrypted, 0644)
   359  	if err != nil {
   360  		log.Errorf("Cannot write: %s", err.Error())
   361  		return
   362  	}
   363  	return
   364  }
   365  
   366  // KMSDecryptFile Decrypt a file using the KMS key
   367  func KMSDecryptFile(cryptedFile string, keyID string, sessionPassFile string) (content string, err error) {
   368  	if cryptedFile == "" || sessionPassFile == "" || keyID == "" {
   369  		err = fmt.Errorf("keyID, crypted filename or sessionpassfilename is empty")
   370  		log.Debug(err)
   371  		return
   372  	}
   373  	log.Debugf("decrypt %s with KMS key %s", cryptedFile, keyID)
   374  	svc := ConnectToKMS()
   375  	if svc == nil {
   376  		err = fmt.Errorf("cannot connect to KMS")
   377  		log.Debug(err)
   378  		return
   379  	}
   380  	//nolint gosec
   381  	cryptedData, err := os.ReadFile(cryptedFile)
   382  	if err != nil {
   383  		log.Debugf("cannot Read file '%s': %s", cryptedFile, err)
   384  		return
   385  	}
   386  	encSessionKey := ""
   387  
   388  	var sp []byte
   389  	//nolint gosec
   390  	sp, err = os.ReadFile(sessionPassFile)
   391  	if err != nil {
   392  		log.Debugf("cannot Read file '%s': %s", sessionPassFile, err)
   393  		return
   394  	}
   395  	encSessionKey = string(sp)
   396  
   397  	if err != nil {
   398  		log.Debugf("Cannot Read file '%s': %s", sessionPassFile, err)
   399  		return
   400  	}
   401  
   402  	sessionKey, err := KMSDecryptString(svc, keyID, encSessionKey)
   403  	if err != nil {
   404  		log.Debugf("decode session key failed:%s", err)
   405  		return
   406  	}
   407  	// sk := string(sessionKey)
   408  	log.Debug("Session key decrypted")
   409  
   410  	// OPENSSL enc -d -aes-256-cbc -md sha256 -base64 -in $SOURCE -pass pass:$SESSIONKEY
   411  	o := openssl.New()
   412  	decoded, err := o.DecryptBytes(sessionKey, cryptedData, SSLDigest)
   413  	if err != nil {
   414  		log.Debugf("Cannot decrypt data from '%s': %s", cryptedFile, err)
   415  		return
   416  	}
   417  	content = string(decoded)
   418  	log.Debug("Decoding successfully")
   419  	return
   420  }