github.com/blend/go-sdk@v1.20220411.3/vault/transit.go (about)

     1  /*
     2  
     3  Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     5  
     6  */
     7  
     8  package vault
     9  
    10  import (
    11  	"context"
    12  	"encoding/base64"
    13  	"encoding/json"
    14  	"path/filepath"
    15  
    16  	"github.com/blend/go-sdk/ex"
    17  )
    18  
    19  // Assert Transit implements TransitClient
    20  var (
    21  	_ TransitClient = (*Transit)(nil)
    22  )
    23  
    // Transit defines vault transit interactions.
    //
    // It implements TransitClient by issuing requests against the
    // /v1/transit/ endpoints through Client. Every method calls through
    // Client, so it presumably must be non-nil before use.
    type Transit struct {
    	// Client is the Vault API client used to build and send every request.
    	Client *APIClient
    }
    28  
    29  // CreateTransitKey creates a transit key path
    30  func (vt Transit) CreateTransitKey(ctx context.Context, key string, options ...CreateTransitKeyOption) error {
    31  	var config CreateTransitKeyConfig
    32  	for _, o := range options {
    33  		err := o(&config)
    34  		if err != nil {
    35  			return err
    36  		}
    37  	}
    38  
    39  	req := vt.Client.createRequest(MethodPost, filepath.Join("/v1/transit/keys/", key)).WithContext(ctx)
    40  
    41  	body, err := vt.Client.jsonBody(config)
    42  	if err != nil {
    43  		return err
    44  	}
    45  	req.Body = body
    46  
    47  	res, err := vt.Client.send(req, OptTraceVaultOperation("transit.create"), OptTraceKeyName(key))
    48  	if err != nil {
    49  		return err
    50  	}
    51  	defer res.Close()
    52  
    53  	return nil
    54  }
    55  
    56  // ConfigureTransitKey configures a transit key path
    57  func (vt Transit) ConfigureTransitKey(ctx context.Context, key string, options ...UpdateTransitKeyOption) error {
    58  	var config UpdateTransitKeyConfig
    59  	for _, o := range options {
    60  		err := o(&config)
    61  		if err != nil {
    62  			return err
    63  		}
    64  	}
    65  
    66  	req := vt.Client.createRequest(MethodPost, filepath.Join("/v1/transit/keys/", key, "config")).WithContext(ctx)
    67  
    68  	body, err := vt.Client.jsonBody(config)
    69  	if err != nil {
    70  		return err
    71  	}
    72  	req.Body = body
    73  
    74  	res, err := vt.Client.send(req, OptTraceVaultOperation("transit.configure"), OptTraceKeyName(key))
    75  	if err != nil {
    76  		return err
    77  	}
    78  	defer res.Close()
    79  
    80  	return nil
    81  }
    82  
    83  // ReadTransitKey returns data about a transit key path
    84  func (vt Transit) ReadTransitKey(ctx context.Context, key string) (map[string]interface{}, error) {
    85  	req := vt.Client.createRequest(MethodGet, filepath.Join("/v1/transit/keys/", key)).WithContext(ctx)
    86  
    87  	res, err := vt.Client.send(req, OptTraceVaultOperation("transit.read"), OptTraceKeyName(key))
    88  	if err != nil {
    89  		return map[string]interface{}{}, err
    90  	}
    91  	defer res.Close()
    92  
    93  	var keyResult TransitKey
    94  	if err = json.NewDecoder(res).Decode(&keyResult); err != nil {
    95  		return nil, err
    96  	}
    97  
    98  	return keyResult.Data, nil
    99  }
   100  
   101  // DeleteTransitKey deletes a transit key path
   102  func (vt Transit) DeleteTransitKey(ctx context.Context, key string) error {
   103  	req := vt.Client.createRequest(MethodDelete, filepath.Join("/v1/transit/keys/", key)).WithContext(ctx)
   104  
   105  	res, err := vt.Client.send(req, OptTraceVaultOperation("transit.delete"), OptTraceKeyName(key))
   106  	if err != nil {
   107  		return err
   108  	}
   109  	defer res.Close()
   110  
   111  	return nil
   112  }
   113  
   114  // Encrypt encrypts a given set of data
   115  //
   116  // It is required to create the transit key *before* you use it to encrypt or decrypt data.
   117  func (vt Transit) Encrypt(ctx context.Context, key string, context, data []byte) (string, error) {
   118  	req := vt.Client.createRequest(MethodPost, filepath.Join("/v1/transit/encrypt/", key)).WithContext(ctx)
   119  
   120  	payload := map[string]interface{}{
   121  		"plaintext": base64.StdEncoding.EncodeToString(data),
   122  	}
   123  	if context != nil {
   124  		contextEncoded := base64.StdEncoding.EncodeToString(context)
   125  		payload["context"] = contextEncoded
   126  	}
   127  	body, err := vt.Client.jsonBody(payload)
   128  	if err != nil {
   129  		return "", err
   130  	}
   131  	req.Body = body
   132  
   133  	res, err := vt.Client.send(req, OptTraceVaultOperation("transit.encrypt"), OptTraceKeyName(key))
   134  	if err != nil {
   135  		return "", err
   136  	}
   137  	defer res.Close()
   138  
   139  	var encryptionResult TransitResult
   140  	if err = json.NewDecoder(res).Decode(&encryptionResult); err != nil {
   141  		return "", err
   142  	}
   143  
   144  	return encryptionResult.Data.Ciphertext, nil
   145  }
   146  
   147  // Decrypt decrypts a given set of data.
   148  //
   149  // It is required to create the transit key *before* you use it to encrypt or decrypt data.
   150  func (vt Transit) Decrypt(ctx context.Context, key string, context []byte, ciphertext string) ([]byte, error) {
   151  	req := vt.Client.createRequest(MethodPost, filepath.Join("/v1/transit/decrypt/", key)).WithContext(ctx)
   152  
   153  	payload := map[string]interface{}{
   154  		"ciphertext": ciphertext,
   155  	}
   156  	if context != nil {
   157  		contextEncoded := base64.StdEncoding.EncodeToString(context)
   158  		payload["context"] = contextEncoded
   159  	}
   160  	body, err := vt.Client.jsonBody(payload)
   161  	if err != nil {
   162  		return nil, err
   163  	}
   164  	req.Body = body
   165  
   166  	res, err := vt.Client.send(req, OptTraceVaultOperation("transit.decrypt"), OptTraceKeyName(key))
   167  	if err != nil {
   168  		return nil, err
   169  	}
   170  	defer res.Close()
   171  
   172  	var decryptionResult TransitResult
   173  	if err = json.NewDecoder(res).Decode(&decryptionResult); err != nil {
   174  		return nil, err
   175  	}
   176  
   177  	return base64.StdEncoding.DecodeString(decryptionResult.Data.Plaintext)
   178  }
   179  
   180  // TransitHMAC batch encrypts a given set of data
   181  // It is required to create the transit key *before* you use it to encrypt or decrypt data.
   182  func (vt Transit) TransitHMAC(ctx context.Context, key string, input []byte) ([]byte, error) {
   183  	req := vt.Client.createRequest(MethodPost, filepath.Join("/v1/transit/hmac/", key, "sha2-256")).WithContext(ctx)
   184  
   185  	inputString := base64.StdEncoding.EncodeToString(input)
   186  	body, err := vt.Client.jsonBody(
   187  		map[string]string{
   188  			"input": inputString,
   189  		},
   190  	)
   191  
   192  	if err != nil {
   193  		return nil, err
   194  	}
   195  	req.Body = body
   196  
   197  	res, err := vt.Client.send(req, OptTraceVaultOperation("transit.hmac"), OptTraceKeyName(key))
   198  	if err != nil {
   199  		return nil, err
   200  	}
   201  	defer res.Close()
   202  
   203  	var decryptionResult TransitHmacResult
   204  	if err = json.NewDecoder(res).Decode(&decryptionResult); err != nil {
   205  		return nil, err
   206  	}
   207  	return []byte(decryptionResult.Data.Hmac), nil
   208  }
   209  
   210  // BatchEncrypt batch encrypts a given set of data
   211  // It is required to create the transit key *before* you use it to encrypt or decrypt data.
   212  func (vt Transit) BatchEncrypt(ctx context.Context, key string, batchInput BatchTransitInput) ([]string, error) {
   213  	if len(batchInput.BatchTransitInputItems) == 0 {
   214  		return []string{}, nil
   215  	}
   216  
   217  	req := vt.Client.createRequest(MethodPost, filepath.Join("/v1/transit/encrypt/", key)).WithContext(ctx)
   218  
   219  	body, err := vt.Client.jsonBody(batchInput)
   220  	if err != nil {
   221  		return nil, err
   222  	}
   223  	req.Body = body
   224  
   225  	res, err := vt.Client.send(req, OptTraceVaultOperation("transit.batch.encrypt"), OptTraceKeyName(key))
   226  	if err != nil {
   227  		return nil, err
   228  	}
   229  	defer res.Close()
   230  
   231  	var batchEncryptionResult BatchTransitResult
   232  	if err = json.NewDecoder(res).Decode(&batchEncryptionResult); err != nil {
   233  		return nil, err
   234  	}
   235  
   236  	var ciphertextResults []string
   237  	for _, result := range batchEncryptionResult.Data.BatchTransitResult {
   238  		if result.Error != "" {
   239  			return nil, ex.New(ErrBatchTransitEncryptError, ex.OptMessage(result.Error))
   240  		}
   241  		ciphertextResults = append(ciphertextResults, result.Ciphertext)
   242  	}
   243  
   244  	return ciphertextResults, nil
   245  
   246  }
   247  
   248  // BatchDecrypt batch decrypts a given set of data
   249  // It is required to create the transit key *before* you use it to encrypt or decrypt data.
   250  func (vt Transit) BatchDecrypt(ctx context.Context, key string, batchInput BatchTransitInput) ([][]byte, error) {
   251  	if len(batchInput.BatchTransitInputItems) == 0 {
   252  		return [][]byte{}, nil
   253  	}
   254  
   255  	req := vt.Client.createRequest(MethodPost, filepath.Join("/v1/transit/decrypt/", key)).WithContext(ctx)
   256  
   257  	body, err := vt.Client.jsonBody(batchInput)
   258  	if err != nil {
   259  		return nil, err
   260  	}
   261  	req.Body = body
   262  
   263  	res, err := vt.Client.send(req, OptTraceVaultOperation("transit.batch.decrypt"), OptTraceKeyName(key))
   264  	if err != nil {
   265  		return nil, err
   266  	}
   267  	defer res.Close()
   268  
   269  	var batchDecryptionResult BatchTransitResult
   270  	if err = json.NewDecoder(res).Decode(&batchDecryptionResult); err != nil {
   271  		return nil, err
   272  	}
   273  
   274  	var plaintextResults [][]byte
   275  	for _, result := range batchDecryptionResult.Data.BatchTransitResult {
   276  		if result.Error != "" {
   277  			return nil, ex.New(ErrBatchTransitDecryptError, ex.OptMessage(result.Error))
   278  		}
   279  		plaintext, err := base64.StdEncoding.DecodeString(result.Plaintext)
   280  		if err != nil {
   281  			return nil, ex.New(ErrBatchTransitDecryptError, ex.OptInnerClass(err))
   282  		}
   283  		plaintextResults = append(plaintextResults, plaintext)
   284  	}
   285  
   286  	return plaintextResults, nil
   287  
   288  }