github.com/zntrio/harp/v2@v2.0.9/pkg/vault/transit/service.go (about)

     1  // Licensed to Elasticsearch B.V. under one or more contributor
     2  // license agreements. See the NOTICE file distributed with
     3  // this work for additional information regarding copyright
     4  // ownership. Elasticsearch B.V. licenses this file to you under
     5  // the Apache License, Version 2.0 (the "License"); you may
     6  // not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //     http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing,
    12  // software distributed under the License is distributed on an
    13  // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    14  // KIND, either express or implied.  See the License for the
    15  // specific language governing permissions and limitations
    16  // under the License.
    17  
    18  package transit
    19  
    20  import (
    21  	"context"
    22  	"encoding/base64"
    23  	"errors"
    24  	"fmt"
    25  	"net/url"
    26  	"path"
    27  	"strings"
    28  
    29  	"github.com/hashicorp/vault/api"
    30  
    31  	"github.com/zntrio/harp/v2/pkg/vault/logical"
    32  	vpath "github.com/zntrio/harp/v2/pkg/vault/path"
    33  )
    34  
    35  type service struct {
    36  	logical   logical.Logical
    37  	mountPath string
    38  	keyName   string
    39  }
    40  
    41  // New instantiates a Vault transit backend encryption service.
    42  func New(client *api.Client, mountPath, keyName string) (Service, error) {
    43  	return &service{
    44  		logical:   client.Logical(),
    45  		mountPath: strings.TrimSuffix(path.Clean(mountPath), "/"),
    46  		keyName:   keyName,
    47  	}, nil
    48  }
    49  
    50  // -----------------------------------------------------------------------------
    51  
    52  func (s *service) Encrypt(ctx context.Context, cleartext []byte) ([]byte, error) {
    53  	// Prepare query
    54  	encryptPath := vpath.SanitizePath(path.Join(url.PathEscape(s.mountPath), "encrypt", url.PathEscape(s.keyName)))
    55  	data := map[string]interface{}{
    56  		"plaintext": base64.StdEncoding.EncodeToString(cleartext),
    57  	}
    58  
    59  	// Send to Vault.
    60  	secret, err := s.logical.Write(encryptPath, data)
    61  	if err != nil {
    62  		return nil, fmt.Errorf("unable to encrypt with %q key: %w", s.keyName, err)
    63  	}
    64  
    65  	// Check response wrapping
    66  	if secret.WrapInfo != nil {
    67  		// Unwrap with response token
    68  		secret, err = s.logical.Unwrap(secret.WrapInfo.Token)
    69  		if err != nil {
    70  			return nil, fmt.Errorf("unable to unwrap the response: %w", err)
    71  		}
    72  	}
    73  
    74  	// Parse server response.
    75  	if cipherText, ok := secret.Data["ciphertext"].(string); ok && cipherText != "" {
    76  		return []byte(cipherText), nil
    77  	}
    78  
    79  	// Return error.
    80  	return nil, errors.New("could not encrypt given data")
    81  }
    82  
    83  func (s *service) Decrypt(ctx context.Context, ciphertext []byte) ([]byte, error) {
    84  	// Prepare query
    85  	decryptPath := vpath.SanitizePath(path.Join(url.PathEscape(s.mountPath), "decrypt", url.PathEscape(s.keyName)))
    86  	data := map[string]interface{}{
    87  		"ciphertext": string(ciphertext),
    88  	}
    89  
    90  	// Send to Vault.
    91  	secret, err := s.logical.Write(decryptPath, data)
    92  	if err != nil {
    93  		return nil, fmt.Errorf("unable to decrypt with %q key: %w", s.keyName, err)
    94  	}
    95  
    96  	// Check response wrapping
    97  	if secret.WrapInfo != nil {
    98  		// Unwrap with response token
    99  		secret, err = s.logical.Unwrap(secret.WrapInfo.Token)
   100  		if err != nil {
   101  			return nil, fmt.Errorf("unable to unwrap the response: %w", err)
   102  		}
   103  	}
   104  
   105  	// Parse server response.
   106  	if plainText64, ok := secret.Data["plaintext"].(string); ok && plainText64 != "" {
   107  		plainText, err := base64.StdEncoding.DecodeString(plainText64)
   108  		if err != nil {
   109  			return nil, fmt.Errorf("unable to decode secret: %w", err)
   110  		}
   111  
   112  		// Return no error
   113  		return plainText, nil
   114  	}
   115  
   116  	// Return error.
   117  	return nil, errors.New("could not decrypt given data")
   118  }