github.com/blend/go-sdk@v1.20220411.3/vault/api_client.go (about)

     1  /*
     2  
     3  Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     5  
     6  */
     7  
     8  package vault
     9  
    10  import (
    11  	"context"
    12  	"crypto/x509"
    13  	"encoding/json"
    14  	"io"
    15  	"net/http"
    16  	"net/url"
    17  	"path/filepath"
    18  	"time"
    19  
    20  	"golang.org/x/net/http2"
    21  
    22  	"github.com/blend/go-sdk/bufferutil"
    23  	"github.com/blend/go-sdk/ex"
    24  	"github.com/blend/go-sdk/logger"
    25  )
    26  
    27  // Assert APIClient implements Client
    28  var (
    29  	_ Client = (*APIClient)(nil)
    30  )
    31  
    32  // New creates a new vault client with a default set of options.
    33  func New(options ...Option) (*APIClient, error) {
    34  	remote, err := url.ParseRequestURI(DefaultAddr)
    35  	if err != nil {
    36  		return nil, err
    37  	}
    38  
    39  	client := &APIClient{
    40  		Timeout:    DefaultTimeout,
    41  		Mount:      DefaultMount,
    42  		Remote:     remote,
    43  		BufferPool: bufferutil.NewPool(DefaultBufferPoolSize),
    44  	}
    45  
    46  	client.KV1 = &KV1{Client: client}
    47  	client.KV2 = &KV2{Client: client}
    48  	client.Transit = &Transit{Client: client}
    49  	client.AWSAuth, err = NewAWSAuth()
    50  	if err != nil {
    51  		return nil, err
    52  	}
    53  
    54  	for _, option := range options {
    55  		if err = option(client); err != nil {
    56  			return nil, err
    57  		}
    58  	}
    59  
    60  	xport := client.Transport
    61  	if xport == nil {
    62  		xport = &http.Transport{}
    63  		err = http2.ConfigureTransport(xport)
    64  		if err != nil {
    65  			return nil, err
    66  		}
    67  	}
    68  
    69  	client.Client = &http.Client{
    70  		Transport: xport,
    71  		Timeout:   client.Timeout,
    72  	}
    73  
    74  	return client, nil
    75  }
    76  
// APIClient is a client to talk to vault.
//
// Use New to construct one; New fills in defaults, wires the KV1/KV2/Transit
// helpers back to the client, and builds the HTTP client.
type APIClient struct {
	// Timeout is used as the timeout of the http.Client built by New.
	Timeout    time.Duration
	// Transport, if non-nil when New finishes applying options, is used as-is;
	// otherwise New creates an HTTP/2-configured http.Transport.
	Transport  *http.Transport
	// Remote is the base URL; createRequest copies it and replaces only Path.
	Remote     *url.URL
	// Token is sent in the HeaderVaultToken header of every request built by createRequest.
	Token      string
	// Mount is joined in front of keys when looking up mount metadata (getVersion).
	Mount      string
	// Log, if set, receives a request event from send and tracer start errors.
	Log        logger.Log
	// BufferPool supplies buffers for JSON request bodies and non-2xx response bodies.
	BufferPool *bufferutil.Pool
	// KV1 is the backend returned by backendKV for version "1" and unknown versions.
	KV1        *KV1
	// KV2 is the backend returned by backendKV for version "2" mounts.
	KV2        *KV2
	// Transit receives all transit method calls made on the APIClient.
	Transit    TransitClient
	// Client executes HTTP requests; New always sets it.
	Client     HTTPClient
	// CertPool is not referenced in this file.
	// NOTE(review): confirm where (or whether) it is applied to the transport.
	CertPool   *x509.CertPool
	// Tracer, if set, is started and finished around each request made by send.
	Tracer     Tracer
	// AWSAuth is initialized by New via NewAWSAuth.
	AWSAuth    *AWSAuth
}
    94  
    95  // Put puts a value.
    96  func (c *APIClient) Put(ctx context.Context, key string, data Values, options ...CallOption) error {
    97  	backend, err := c.backendKV(ctx, key)
    98  	if err != nil {
    99  		return err
   100  	}
   101  
   102  	return backend.Put(ctx, key, data, options...)
   103  }
   104  
   105  // Get gets a value at a given key.
   106  func (c *APIClient) Get(ctx context.Context, key string, options ...CallOption) (Values, error) {
   107  	backend, err := c.backendKV(ctx, key)
   108  	if err != nil {
   109  		return nil, err
   110  	}
   111  
   112  	return backend.Get(ctx, key, options...)
   113  }
   114  
   115  // Delete puts a key.
   116  func (c *APIClient) Delete(ctx context.Context, key string, options ...CallOption) error {
   117  	backend, err := c.backendKV(ctx, key)
   118  	if err != nil {
   119  		return err
   120  	}
   121  	return backend.Delete(ctx, key, options...)
   122  }
   123  
   124  // List returns a slice of key and subfolder names at this path.
   125  func (c *APIClient) List(ctx context.Context, path string, options ...CallOption) ([]string, error) {
   126  	backend, err := c.backendKV(ctx, path)
   127  	if err != nil {
   128  		return nil, err
   129  	}
   130  
   131  	return backend.List(ctx, path, options...)
   132  }
   133  
   134  // ReadInto reads a secret into an object.
   135  func (c *APIClient) ReadInto(ctx context.Context, key string, obj interface{}, options ...CallOption) error {
   136  	response, err := c.Get(ctx, key, options...)
   137  	if err != nil {
   138  		return err
   139  	}
   140  	asStrings := make(map[string]string)
   141  	for k, v := range response {
   142  		if s, ok := v.(string); ok {
   143  			asStrings[k] = s
   144  		}
   145  	}
   146  	return RestoreJSON(asStrings, obj)
   147  }
   148  
   149  // WriteInto writes an object into a secret at a given key.
   150  func (c *APIClient) WriteInto(ctx context.Context, key string, obj interface{}, options ...CallOption) error {
   151  	data, err := DecomposeJSON(obj)
   152  	if err != nil {
   153  		return err
   154  	}
   155  	asData := make(map[string]interface{})
   156  	for k, v := range data {
   157  		asData[k] = v
   158  	}
   159  	return c.Put(ctx, key, asData, options...)
   160  }
   161  
   162  // CreateTransitKey creates a transit key path
   163  func (c *APIClient) CreateTransitKey(ctx context.Context, key string, options ...CreateTransitKeyOption) error {
   164  	return c.Transit.CreateTransitKey(ctx, key, options...)
   165  }
   166  
   167  // ConfigureTransitKey configures a transit key path
   168  func (c *APIClient) ConfigureTransitKey(ctx context.Context, key string, options ...UpdateTransitKeyOption) error {
   169  	return c.Transit.ConfigureTransitKey(ctx, key, options...)
   170  }
   171  
   172  // ReadTransitKey returns data about a transit key path
   173  func (c *APIClient) ReadTransitKey(ctx context.Context, key string) (map[string]interface{}, error) {
   174  	return c.Transit.ReadTransitKey(ctx, key)
   175  }
   176  
   177  // DeleteTransitKey deletes a transit key path
   178  func (c *APIClient) DeleteTransitKey(ctx context.Context, key string) error {
   179  	return c.Transit.DeleteTransitKey(ctx, key)
   180  }
   181  
   182  // Encrypt encrypts a given set of data.
   183  func (c *APIClient) Encrypt(ctx context.Context, key string, context, data []byte) (string, error) {
   184  	return c.Transit.Encrypt(ctx, key, context, data)
   185  }
   186  
   187  // Decrypt decrypts a given set of data.
   188  func (c *APIClient) Decrypt(ctx context.Context, key string, context []byte, ciphertext string) ([]byte, error) {
   189  	return c.Transit.Decrypt(ctx, key, context, ciphertext)
   190  }
   191  
   192  // TransitHMAC decrypts a given set of data.
   193  func (c *APIClient) TransitHMAC(ctx context.Context, key string, input []byte) ([]byte, error) {
   194  	return c.Transit.TransitHMAC(ctx, key, input)
   195  }
   196  
   197  // BatchEncrypt batch encrypts a given set of data.
   198  func (c *APIClient) BatchEncrypt(ctx context.Context, key string, batchInput BatchTransitInput) ([]string, error) {
   199  	return c.Transit.BatchEncrypt(ctx, key, batchInput)
   200  }
   201  
   202  // BatchDecrypt batch decrypts a given set of data.
   203  func (c *APIClient) BatchDecrypt(ctx context.Context, key string, batchInput BatchTransitInput) ([][]byte, error) {
   204  	return c.Transit.BatchDecrypt(ctx, key, batchInput)
   205  }
   206  
   207  // --------------------------------------------------------------------------------
   208  // utility methods
   209  // --------------------------------------------------------------------------------
   210  
   211  func (c *APIClient) backendKV(ctx context.Context, key string) (KV, error) {
   212  	version, err := c.getVersion(ctx, key)
   213  	if err != nil {
   214  		return nil, err
   215  	}
   216  	switch version {
   217  	case Version1:
   218  		return c.KV1, nil
   219  	case Version2:
   220  		return c.KV2, nil
   221  	default:
   222  		return c.KV1, nil
   223  	}
   224  }
   225  
   226  func (c *APIClient) getVersion(ctx context.Context, key string) (string, error) {
   227  	meta, err := c.getMountMeta(ctx, filepath.Join(c.Mount, key))
   228  	if err != nil {
   229  		return "", err
   230  	}
   231  	return meta.Data.Options["version"], nil
   232  }
   233  
   234  func (c *APIClient) getMountMeta(ctx context.Context, key string) (*MountResponse, error) {
   235  	req := c.createRequest(MethodGet, filepath.Join("/v1/sys/internal/ui/mounts/", key))
   236  	req = req.WithContext(ctx)
   237  
   238  	res, err := c.Client.Do(req)
   239  	if err != nil {
   240  		return nil, err
   241  	}
   242  	defer res.Body.Close()
   243  
   244  	var response MountResponse
   245  	if err := json.NewDecoder(res.Body).Decode(&response); err != nil {
   246  		return nil, err
   247  	}
   248  	return &response, nil
   249  }
   250  
   251  func (c *APIClient) jsonBody(input interface{}) (io.ReadCloser, error) {
   252  	buf := c.BufferPool.Get()
   253  	err := json.NewEncoder(buf).Encode(input)
   254  	if err != nil {
   255  		return nil, err
   256  	}
   257  	return bufferutil.PutOnClose(buf, c.BufferPool), nil
   258  }
   259  
   260  func (c *APIClient) readJSON(r io.Reader, output interface{}) error {
   261  	return json.NewDecoder(r).Decode(output)
   262  }
   263  
   264  // copyRemote returns a copy of our remote.
   265  func (c *APIClient) copyRemote() *url.URL {
   266  	remoteCopy := *c.Remote
   267  	return &remoteCopy
   268  }
   269  
   270  // applyOptions applies options to a request.
   271  func (c *APIClient) applyOptions(req *http.Request, options ...CallOption) error {
   272  	var err error
   273  	for _, opt := range options {
   274  		if err = opt(req); err != nil {
   275  			return err
   276  		}
   277  	}
   278  	return nil
   279  }
   280  
   281  func (c *APIClient) createRequest(method, path string, options ...CallOption) *http.Request {
   282  	remote := c.copyRemote()
   283  	remote.Path = path
   284  	req := &http.Request{
   285  		Method: method,
   286  		URL:    remote,
   287  		Header: http.Header{
   288  			HeaderVaultToken: []string{c.Token},
   289  		},
   290  	}
   291  	_ = c.applyOptions(req, options...)
   292  	return req
   293  }
   294  
   295  func (c *APIClient) send(req *http.Request, traceOptions ...TraceOption) (body io.ReadCloser, err error) {
   296  	var statusCode int
   297  	var finisher TraceFinisher
   298  	if c.Log != nil {
   299  		e := NewEvent(req)
   300  		start := time.Now()
   301  		defer func() {
   302  			e.Elapsed = time.Since(start)
   303  			logger.MaybeTriggerContext(req.Context(), c.Log, e)
   304  		}()
   305  	}
   306  	if finisher != nil {
   307  		defer func() {
   308  			finisher.Finish(req.Context(), statusCode, err)
   309  		}()
   310  	}
   311  	if c.Tracer != nil {
   312  		var traceErr error
   313  		finisher, traceErr = c.Tracer.Start(req.Context(), traceOptions...)
   314  		if traceErr != nil {
   315  			logger.MaybeError(c.Log, traceErr)
   316  		}
   317  	}
   318  
   319  	var res *http.Response
   320  	res, err = c.Client.Do(req)
   321  	if err != nil {
   322  		statusCode = 500
   323  		return
   324  	}
   325  	statusCode = res.StatusCode
   326  
   327  	if statusCode > 299 {
   328  		buf := c.BufferPool.Get()
   329  		defer c.BufferPool.Put(buf)
   330  		if _, err = io.Copy(buf, res.Body); err != nil {
   331  			err = ex.New(err)
   332  			return
   333  		}
   334  		err = ex.New(ErrClassForStatus(statusCode), ex.OptMessagef("status: %d; %v", statusCode, buf.String()))
   335  		return
   336  	}
   337  	body = res.Body
   338  	return
   339  }
   340  
   341  func (c *APIClient) discard(res io.ReadCloser, err error) error {
   342  	if err != nil {
   343  		return err
   344  	}
   345  	defer res.Close()
   346  	if _, err = io.Copy(io.Discard, res); err != nil {
   347  		return err
   348  	}
   349  	return nil
   350  }