github.com/blend/go-sdk@v1.20220411.3/vault/api_client_test.go (about)

     1  /*
     2  
     3  Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     5  
     6  */
     7  
     8  package vault
     9  
    10  import (
    11  	"bytes"
    12  	"context"
    13  	"encoding/base64"
    14  	"fmt"
    15  	"io"
    16  	"net/http"
    17  	"net/http/httptest"
    18  	"net/url"
    19  	"testing"
    20  
    21  	"github.com/blend/go-sdk/assert"
    22  	"github.com/blend/go-sdk/webutil"
    23  )
    24  
    25  func mustURLf(format string, args ...interface{}) *url.URL {
    26  	return webutil.MustParseURL(fmt.Sprintf(format, args...))
    27  }
    28  
    29  func TestVaultClientBackendKV(t *testing.T) {
    30  	assert := assert.New(t)
    31  	todo := context.TODO()
    32  
    33  	client, err := New()
    34  	assert.Nil(err)
    35  
    36  	mountMetaJSON := `{"request_id":"e114c628-6493-28ed-0975-418a75c7976f","lease_id":"","renewable":false,"lease_duration":0,"data":{"accessor":"kv_45f6a162","config":{"default_lease_ttl":0,"force_no_cache":false,"max_lease_ttl":0,"plugin_name":""},"description":"key/value secret storage","local":false,"options":{"version":"2"},"path":"secret/","seal_wrap":false,"type":"kv"},"wrap_info":null,"warnings":null,"auth":null}`
    37  
    38  	m := NewMockHTTPClient().WithString("GET", mustURLf("%s/v1/sys/internal/ui/mounts/secret/foo/bar", client.Remote.String()), mountMetaJSON)
    39  	client.Client = m
    40  
    41  	backend, err := client.backendKV(todo, "foo/bar")
    42  	assert.Nil(err)
    43  	assert.NotNil(backend)
    44  }
    45  
    46  func TestVaultClientGetVersion(t *testing.T) {
    47  	assert := assert.New(t)
    48  	todo := context.TODO()
    49  
    50  	client, err := New()
    51  	assert.Nil(err)
    52  
    53  	mountMetaJSONV1 := `{"request_id":"e114c628-6493-28ed-0975-418a75c7976f","lease_id":"","renewable":false,"lease_duration":0,"data":{"accessor":"kv_45f6a162","config":{"default_lease_ttl":0,"force_no_cache":false,"max_lease_ttl":0,"plugin_name":""},"description":"key/value secret storage","local":false,"options":{"version":"1"},"path":"secret/","seal_wrap":false,"type":"kv"},"wrap_info":null,"warnings":null,"auth":null}`
    54  	mountMetaJSONV2 := `{"request_id":"e114c628-6493-28ed-0975-418a75c7976f","lease_id":"","renewable":false,"lease_duration":0,"data":{"accessor":"kv_45f6a162","config":{"default_lease_ttl":0,"force_no_cache":false,"max_lease_ttl":0,"plugin_name":""},"description":"key/value secret storage","local":false,"options":{"version":"2"},"path":"secret/","seal_wrap":false,"type":"kv"},"wrap_info":null,"warnings":null,"auth":null}`
    55  
    56  	m := NewMockHTTPClient().
    57  		WithString("GET", mustURLf("%s/v1/sys/internal/ui/mounts/secret/foo/bar", client.Remote.String()), mountMetaJSONV1)
    58  
    59  	client.Client = m
    60  
    61  	version, err := client.getVersion(todo, "foo/bar")
    62  	assert.Nil(err)
    63  	assert.Equal(Version1, version)
    64  
    65  	m.WithString("GET", mustURLf("%s/v1/sys/internal/ui/mounts/secret/foo/bar", client.Remote.String()), mountMetaJSONV2)
    66  
    67  	version, err = client.getVersion(todo, "foo/bar")
    68  	assert.Nil(err)
    69  	assert.Equal(Version2, version)
    70  }
    71  
    72  func TestVaultClientGetMountMeta(t *testing.T) {
    73  	assert := assert.New(t)
    74  	todo := context.TODO()
    75  
    76  	client, err := New()
    77  	assert.Nil(err)
    78  
    79  	mountMetaJSON := `{"request_id":"e114c628-6493-28ed-0975-418a75c7976f","lease_id":"","renewable":false,"lease_duration":0,"data":{"accessor":"kv_45f6a162","config":{"default_lease_ttl":0,"force_no_cache":false,"max_lease_ttl":0,"plugin_name":""},"description":"key/value secret storage","local":false,"options":{"version":"2"},"path":"secret/","seal_wrap":false,"type":"kv"},"wrap_info":null,"warnings":null,"auth":null}`
    80  
    81  	m := NewMockHTTPClient().WithString("GET", mustURLf("%s/v1/sys/internal/ui/mounts/secret/foo/bar", client.Remote.String()), mountMetaJSON)
    82  	client.Client = m
    83  
    84  	mountMeta, err := client.getMountMeta(todo, "secret/foo/bar")
    85  	assert.Nil(err)
    86  	assert.NotNil(mountMeta)
    87  	assert.Equal(Version2, mountMeta.Data.Options["version"])
    88  }
    89  
    90  func TestVaultClientJSONBody(t *testing.T) {
    91  	assert := assert.New(t)
    92  
    93  	client, err := New()
    94  	assert.Nil(err)
    95  
    96  	output, err := client.jsonBody(map[string]interface{}{
    97  		"foo": "bar",
    98  	})
    99  	assert.Nil(err)
   100  	defer output.Close()
   101  
   102  	contents, err := io.ReadAll(output)
   103  	assert.Nil(err)
   104  	assert.Equal("{\"foo\":\"bar\"}\n", string(contents))
   105  }
   106  
   107  func TestVaultClientReadJSON(t *testing.T) {
   108  	assert := assert.New(t)
   109  
   110  	client, err := New()
   111  	assert.Nil(err)
   112  
   113  	jsonBody := bytes.NewBuffer([]byte(`{"foo":"bar"}`))
   114  
   115  	output := map[string]interface{}{}
   116  	assert.Nil(client.readJSON(jsonBody, &output))
   117  	assert.Equal("bar", output["foo"])
   118  }
   119  
   120  func TestVaultClientCopyRemote(t *testing.T) {
   121  	assert := assert.New(t)
   122  
   123  	client, err := New()
   124  	assert.Nil(err)
   125  
   126  	copy := client.copyRemote()
   127  	copy.Host = "not_" + copy.Host
   128  
   129  	anotherCopy := client.copyRemote()
   130  	assert.NotEqual(anotherCopy.Host, copy.Host)
   131  }
   132  
   133  func TestVaultClientDiscard(t *testing.T) {
   134  	assert := assert.New(t)
   135  
   136  	client, err := New()
   137  	assert.Nil(err)
   138  
   139  	assert.NotNil(client.discard(nil, fmt.Errorf("this is only a test")))
   140  
   141  	assert.Nil(client.discard(client.jsonBody(map[string]interface{}{
   142  		"foo": "bar",
   143  	})))
   144  }
   145  
   146  func TestVaultCreateTransitKey(t *testing.T) {
   147  	assert := assert.New(t)
   148  	todo := context.TODO()
   149  
   150  	client, err := New()
   151  	assert.Nil(err)
   152  
   153  	key := "key"
   154  
   155  	m := NewMockHTTPClient().
   156  		With(
   157  			"POST",
   158  			mustURLf("%s/v1/transit/keys/%s", client.Remote.String(), key),
   159  			&http.Response{
   160  				StatusCode: http.StatusNoContent,
   161  				Body:       io.NopCloser(bytes.NewBuffer([]byte{})),
   162  			},
   163  		)
   164  	client.Client = m
   165  
   166  	err = client.CreateTransitKey(todo, "key")
   167  	assert.Nil(err)
   168  }
   169  
   170  func TestVaultConfigureTransitKey(t *testing.T) {
   171  	assert := assert.New(t)
   172  	todo := context.TODO()
   173  
   174  	client, err := New()
   175  	assert.Nil(err)
   176  
   177  	key := "key"
   178  
   179  	m := NewMockHTTPClient().
   180  		With(
   181  			"POST",
   182  			mustURLf("%s/v1/transit/keys/%s/config", client.Remote.String(), key),
   183  			&http.Response{
   184  				StatusCode: http.StatusNoContent,
   185  				Body:       io.NopCloser(bytes.NewBuffer([]byte{})),
   186  			},
   187  		)
   188  	client.Client = m
   189  
   190  	err = client.ConfigureTransitKey(todo, "key", OptUpdateTransitDeletionAllowed(true))
   191  	assert.Nil(err)
   192  }
   193  
   194  func TestVaultReadTransitKey(t *testing.T) {
   195  	assert := assert.New(t)
   196  	todo := context.TODO()
   197  
   198  	client, err := New()
   199  	assert.Nil(err)
   200  
   201  	key := "key"
   202  	keyMetaJSON := `{"request_id":"e114c628-6493-28ed-0975-418a75c7976f","lease_id":"","renewable":false,"lease_duration":0,"data":{"deletion_allowed":true,"exportable":false,"allow_plaintext_backup":false,"keys": {"1": 1442851412},"min_decryption_version": 1,"min_encryption_version": 0,"name": "foo"},"wrap_info":null,"warnings":null,"auth":null}`
   203  
   204  	m := NewMockHTTPClient().WithString("GET", mustURLf("%s/v1/transit/keys/%s", client.Remote.String(), key), keyMetaJSON)
   205  	client.Client = m
   206  
   207  	data, err := client.ReadTransitKey(todo, "key")
   208  	assert.Nil(err)
   209  	assert.Equal(true, data["deletion_allowed"])
   210  }
   211  
   212  func TestVaultDeleteTransitKey(t *testing.T) {
   213  	assert := assert.New(t)
   214  	todo := context.TODO()
   215  
   216  	client, err := New()
   217  	assert.Nil(err)
   218  
   219  	key := "key"
   220  
   221  	m := NewMockHTTPClient().
   222  		With(
   223  			"DELETE",
   224  			mustURLf("%s/v1/transit/keys/%s", client.Remote.String(), key),
   225  			&http.Response{
   226  				StatusCode: http.StatusNoContent,
   227  				Body:       io.NopCloser(bytes.NewBuffer([]byte{})),
   228  			},
   229  		)
   230  	client.Client = m
   231  
   232  	err = client.DeleteTransitKey(todo, "key")
   233  	assert.Nil(err)
   234  }
   235  
   236  func TestVaultHandleRedirects(t *testing.T) {
   237  	assert := assert.New(t)
   238  
   239  	rawResponse := "{\"status\":\"ok!\"}\n"
   240  
   241  	inner := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   242  		w.Header().Set(webutil.HeaderContentType, webutil.ContentTypeApplicationJSON)
   243  		w.WriteHeader(http.StatusOK)
   244  		fmt.Fprint(w, rawResponse)
   245  	}))
   246  	defer inner.Close()
   247  	outer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   248  		http.Redirect(w, r, inner.URL, http.StatusTemporaryRedirect)
   249  	}))
   250  	defer outer.Close()
   251  
   252  	client, err := New(
   253  		OptRemote(outer.URL),
   254  	)
   255  	assert.Nil(err)
   256  	assert.NotNil(client)
   257  
   258  	rawURL, err := url.Parse(outer.URL)
   259  	assert.Nil(err)
   260  	res, err := client.Client.Do(&http.Request{URL: rawURL})
   261  	assert.Nil(err)
   262  	defer res.Body.Close()
   263  	assert.Equal(http.StatusOK, res.StatusCode)
   264  
   265  	contents, err := io.ReadAll(res.Body)
   266  	assert.Nil(err)
   267  	assert.Equal(rawResponse, string(contents))
   268  }
   269  
   270  func TestVaultBatchEncryptDecrypt_Happy(t *testing.T) {
   271  	assert := assert.New(t)
   272  	todo := context.Background()
   273  
   274  	client, err := New()
   275  	assert.Nil(err)
   276  
   277  	key := "key"
   278  
   279  	plaintext1 := []byte("this is plaintext")
   280  	plaintext2 := []byte("this is plaintext2")
   281  	batchInput := BatchTransitInput{
   282  		BatchTransitInputItems: []BatchTransitInputItem{
   283  			{
   284  				Context:   nil,
   285  				Plaintext: plaintext1,
   286  			},
   287  			{
   288  				Context:   nil,
   289  				Plaintext: plaintext2,
   290  			},
   291  		},
   292  	}
   293  
   294  	batchDecryptResultBytes := []byte(fmt.Sprintf(`
   295  			{
   296  			  "data": {
   297  				"batch_results": [
   298  				  {
   299  					"plaintext": "%s",
   300  					"key_version": 1
   301  				  },
   302  				  {
   303  					"plaintext": "%s",
   304  					"key_version": 1
   305  				  }
   306  				]
   307  			  }
   308  			}
   309  	`, base64.StdEncoding.EncodeToString(plaintext1), base64.StdEncoding.EncodeToString(plaintext2)))
   310  
   311  	batchEncryptResultBytes := []byte(fmt.Sprintf(`
   312  		{
   313  		  "data": {
   314  			"batch_results": [
   315  			  {
   316  				"ciphertext": "vault:%s",
   317  				"key_version": 1
   318  			  },
   319  			  {
   320  				"ciphertext": "vault:%s",
   321  				"key_version": 1
   322  			  }
   323  			]
   324  		  }
   325  		}
   326  	`, plaintext1, plaintext2))
   327  	m := NewMockHTTPClient().
   328  		With(
   329  			"POST",
   330  			mustURLf("%s/v1/transit/encrypt/%s", client.Remote.String(), key),
   331  			&http.Response{
   332  				StatusCode: http.StatusOK,
   333  				Body:       io.NopCloser(bytes.NewBuffer(batchEncryptResultBytes)),
   334  			},
   335  		).With(
   336  		"POST",
   337  		mustURLf("%s/v1/transit/decrypt/%s", client.Remote.String(), key),
   338  		&http.Response{
   339  			StatusCode: http.StatusOK,
   340  			Body:       io.NopCloser(bytes.NewBuffer(batchDecryptResultBytes)),
   341  		},
   342  	)
   343  	client.Client = m
   344  
   345  	ciphertextResults, err := client.BatchEncrypt(todo, "key", batchInput)
   346  	assert.Nil(err)
   347  	assert.Equal(fmt.Sprintf("vault:%s", plaintext1), ciphertextResults[0])
   348  	assert.Equal(fmt.Sprintf("vault:%s", plaintext2), ciphertextResults[1])
   349  
   350  	plaintextResults, err := client.BatchDecrypt(todo, "key", batchInput)
   351  	assert.Nil(err)
   352  	assert.Equal(plaintext1, plaintextResults[0])
   353  	assert.Equal(plaintext2, plaintextResults[1])
   354  }
   355  
   356  func TestVaultBatchEncryptDecrypt_EmptyInput(t *testing.T) {
   357  	assert := assert.New(t)
   358  	todo := context.Background()
   359  
   360  	client, err := New()
   361  	assert.Nil(err)
   362  
   363  	key := "key"
   364  
   365  	batchInput := BatchTransitInput{
   366  		BatchTransitInputItems: []BatchTransitInputItem{},
   367  	}
   368  
   369  	errorResultBytes := []byte(`
   370  		{
   371  		  "data": {
   372  				"error": "missing batch input to process"
   373  		  }
   374  		}
   375  	`)
   376  	m := NewMockHTTPClient().
   377  		With(
   378  			"POST",
   379  			mustURLf("%s/v1/transit/encrypt/%s", client.Remote.String(), key),
   380  			&http.Response{
   381  				StatusCode: http.StatusBadRequest,
   382  				Body:       io.NopCloser(bytes.NewBuffer(errorResultBytes)),
   383  			},
   384  		).With(
   385  		"POST",
   386  		mustURLf("%s/v1/transit/decrypt/%s", client.Remote.String(), key),
   387  		&http.Response{
   388  			StatusCode: http.StatusBadRequest,
   389  			Body:       io.NopCloser(bytes.NewBuffer(errorResultBytes)),
   390  		},
   391  	)
   392  	client.Client = m
   393  
   394  	ciphertextResults, err := client.BatchEncrypt(todo, "key", batchInput)
   395  	assert.Nil(err)
   396  	assert.Empty(ciphertextResults)
   397  
   398  	plaintextResults, err := client.BatchDecrypt(todo, "key", batchInput)
   399  	assert.Nil(err)
   400  	assert.Empty(plaintextResults)
   401  }
   402  
   403  func TestVaultBatchEncrypt_Error(t *testing.T) {
   404  	assert := assert.New(t)
   405  	todo := context.TODO()
   406  
   407  	client, err := New()
   408  	assert.Nil(err)
   409  
   410  	key := "key"
   411  
   412  	plaintext1 := []byte("this is plaintext")
   413  	plaintext2 := []byte("this is plaintext2")
   414  	batchInput := BatchTransitInput{
   415  		BatchTransitInputItems: []BatchTransitInputItem{
   416  			{
   417  				Context:   nil,
   418  				Plaintext: plaintext1,
   419  			},
   420  			{
   421  				Context:   nil,
   422  				Plaintext: plaintext2,
   423  			},
   424  		},
   425  	}
   426  
   427  	batchEncryptResultBytes := []byte(fmt.Sprintf(`
   428  		{
   429  		  "data": {
   430  			"batch_results": [
   431  			  {
   432  				"ciphertext": "vault:%s",
   433  				"key_version": 1
   434  			  },
   435  			  {
   436  				"error": "encryption error",
   437  				"ciphertext": "vault:%s",
   438  				"key_version": 1
   439  			  }
   440  			]
   441  		  }
   442  		}
   443  	`, plaintext1, plaintext2))
   444  	m := NewMockHTTPClient().
   445  		With(
   446  			"POST",
   447  			mustURLf("%s/v1/transit/encrypt/%s", client.Remote.String(), key),
   448  			&http.Response{
   449  				StatusCode: http.StatusOK,
   450  				Body:       io.NopCloser(bytes.NewBuffer(batchEncryptResultBytes)),
   451  			},
   452  		)
   453  	client.Client = m
   454  
   455  	ciphertextResults, err := client.BatchEncrypt(todo, "key", batchInput)
   456  	assert.NotNil(err)
   457  	assert.Equal(ErrBatchTransitEncryptError, err.Error())
   458  	assert.Nil(ciphertextResults)
   459  }
   460  
   461  func TestVaultBatchDecrypt_Error(t *testing.T) {
   462  	assert := assert.New(t)
   463  	todo := context.TODO()
   464  
   465  	client, err := New()
   466  	assert.Nil(err)
   467  
   468  	key := "key"
   469  
   470  	plaintext1 := []byte("this is plaintext")
   471  	plaintext2 := []byte("this is plaintext2")
   472  	batchInput := BatchTransitInput{
   473  		BatchTransitInputItems: []BatchTransitInputItem{
   474  			{
   475  				Context:   nil,
   476  				Plaintext: plaintext1,
   477  			},
   478  			{
   479  				Context:   nil,
   480  				Plaintext: plaintext2,
   481  			},
   482  		},
   483  	}
   484  
   485  	batchDecryptResultBytes := []byte(fmt.Sprintf(`
   486  			{
   487  			  "data": {
   488  				"batch_results": [
   489  				  {
   490  					"error": "error",
   491  					"plaintext": "%s",
   492  					"key_version": 1
   493  				  },
   494  				  {
   495  					"plaintext": "%s",
   496  					"key_version": 1
   497  				  }
   498  				]
   499  			  }
   500  			}
   501  	`, base64.StdEncoding.EncodeToString(plaintext1), base64.StdEncoding.EncodeToString(plaintext2)))
   502  
   503  	m := NewMockHTTPClient().
   504  		With(
   505  			"POST",
   506  			mustURLf("%s/v1/transit/decrypt/%s", client.Remote.String(), key),
   507  			&http.Response{
   508  				StatusCode: http.StatusOK,
   509  				Body:       io.NopCloser(bytes.NewBuffer(batchDecryptResultBytes)),
   510  			},
   511  		)
   512  	client.Client = m
   513  
   514  	plaintextResults, err := client.BatchDecrypt(todo, "key", batchInput)
   515  	assert.NotNil(err)
   516  	assert.Equal(ErrBatchTransitDecryptError, err.Error())
   517  	assert.Nil(plaintextResults)
   518  }
   519  
   520  func TestVaultHmac(t *testing.T) {
   521  	assert := assert.New(t)
   522  	todo := context.TODO()
   523  
   524  	client, err := New()
   525  	assert.Nil(err)
   526  
   527  	key := "key"
   528  	input := []byte("hmac!")
   529  	result := fmt.Sprintf(`{"data": {"hmac": "%s"}}`, input)
   530  
   531  	m := NewMockHTTPClient().
   532  		With(
   533  			"POST",
   534  			mustURLf("%s/v1/transit/hmac/%s/sha2-256", client.Remote.String(), key),
   535  			&http.Response{
   536  				StatusCode: http.StatusOK,
   537  				Body:       io.NopCloser(bytes.NewBuffer([]byte(result))),
   538  			},
   539  		)
   540  	client.Client = m
   541  
   542  	res, err := client.TransitHMAC(todo, "key", input)
   543  	assert.Nil(err)
   544  	assert.Equal(input, res)
   545  }
   546  
   547  func TestVaultHmacError(t *testing.T) {
   548  	assert := assert.New(t)
   549  	todo := context.TODO()
   550  
   551  	client, err := New()
   552  	assert.Nil(err)
   553  
   554  	key := "key"
   555  	input := []byte("hmac!")
   556  	result := `bad payload`
   557  
   558  	m := NewMockHTTPClient().
   559  		With(
   560  			"POST",
   561  			mustURLf("%s/v1/transit/hmac/%s/sha2-256", client.Remote.String(), key),
   562  			&http.Response{
   563  				StatusCode: http.StatusOK,
   564  				Body:       io.NopCloser(bytes.NewBuffer([]byte(result))),
   565  			},
   566  		)
   567  	client.Client = m
   568  
   569  	res, err := client.TransitHMAC(todo, "key", input)
   570  	assert.NotNil(err)
   571  	assert.Nil(res)
   572  }