github.com/blend/go-sdk@v1.20220411.3/vault/api_client_test.go (about) 1 /* 2 3 Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved 4 Use of this source code is governed by a MIT license that can be found in the LICENSE file. 5 6 */ 7 8 package vault 9 10 import ( 11 "bytes" 12 "context" 13 "encoding/base64" 14 "fmt" 15 "io" 16 "net/http" 17 "net/http/httptest" 18 "net/url" 19 "testing" 20 21 "github.com/blend/go-sdk/assert" 22 "github.com/blend/go-sdk/webutil" 23 ) 24 25 func mustURLf(format string, args ...interface{}) *url.URL { 26 return webutil.MustParseURL(fmt.Sprintf(format, args...)) 27 } 28 29 func TestVaultClientBackendKV(t *testing.T) { 30 assert := assert.New(t) 31 todo := context.TODO() 32 33 client, err := New() 34 assert.Nil(err) 35 36 mountMetaJSON := `{"request_id":"e114c628-6493-28ed-0975-418a75c7976f","lease_id":"","renewable":false,"lease_duration":0,"data":{"accessor":"kv_45f6a162","config":{"default_lease_ttl":0,"force_no_cache":false,"max_lease_ttl":0,"plugin_name":""},"description":"key/value secret storage","local":false,"options":{"version":"2"},"path":"secret/","seal_wrap":false,"type":"kv"},"wrap_info":null,"warnings":null,"auth":null}` 37 38 m := NewMockHTTPClient().WithString("GET", mustURLf("%s/v1/sys/internal/ui/mounts/secret/foo/bar", client.Remote.String()), mountMetaJSON) 39 client.Client = m 40 41 backend, err := client.backendKV(todo, "foo/bar") 42 assert.Nil(err) 43 assert.NotNil(backend) 44 } 45 46 func TestVaultClientGetVersion(t *testing.T) { 47 assert := assert.New(t) 48 todo := context.TODO() 49 50 client, err := New() 51 assert.Nil(err) 52 53 mountMetaJSONV1 := `{"request_id":"e114c628-6493-28ed-0975-418a75c7976f","lease_id":"","renewable":false,"lease_duration":0,"data":{"accessor":"kv_45f6a162","config":{"default_lease_ttl":0,"force_no_cache":false,"max_lease_ttl":0,"plugin_name":""},"description":"key/value secret storage","local":false,"options":{"version":"1"},"path":"secret/","seal_wrap":false,"type":"kv"},"wrap_info":null,"warnings":null,"auth":null}` 54 mountMetaJSONV2 := `{"request_id":"e114c628-6493-28ed-0975-418a75c7976f","lease_id":"","renewable":false,"lease_duration":0,"data":{"accessor":"kv_45f6a162","config":{"default_lease_ttl":0,"force_no_cache":false,"max_lease_ttl":0,"plugin_name":""},"description":"key/value secret storage","local":false,"options":{"version":"2"},"path":"secret/","seal_wrap":false,"type":"kv"},"wrap_info":null,"warnings":null,"auth":null}` 55 56 m := NewMockHTTPClient(). 57 WithString("GET", mustURLf("%s/v1/sys/internal/ui/mounts/secret/foo/bar", client.Remote.String()), mountMetaJSONV1) 58 59 client.Client = m 60 61 version, err := client.getVersion(todo, "foo/bar") 62 assert.Nil(err) 63 assert.Equal(Version1, version) 64 65 m.WithString("GET", mustURLf("%s/v1/sys/internal/ui/mounts/secret/foo/bar", client.Remote.String()), mountMetaJSONV2) 66 67 version, err = client.getVersion(todo, "foo/bar") 68 assert.Nil(err) 69 assert.Equal(Version2, version) 70 } 71 72 func TestVaultClientGetMountMeta(t *testing.T) { 73 assert := assert.New(t) 74 todo := context.TODO() 75 76 client, err := New() 77 assert.Nil(err) 78 79 mountMetaJSON := `{"request_id":"e114c628-6493-28ed-0975-418a75c7976f","lease_id":"","renewable":false,"lease_duration":0,"data":{"accessor":"kv_45f6a162","config":{"default_lease_ttl":0,"force_no_cache":false,"max_lease_ttl":0,"plugin_name":""},"description":"key/value secret storage","local":false,"options":{"version":"2"},"path":"secret/","seal_wrap":false,"type":"kv"},"wrap_info":null,"warnings":null,"auth":null}` 80 81 m := NewMockHTTPClient().WithString("GET", mustURLf("%s/v1/sys/internal/ui/mounts/secret/foo/bar", client.Remote.String()), mountMetaJSON) 82 client.Client = m 83 84 mountMeta, err := client.getMountMeta(todo, "secret/foo/bar") 85 assert.Nil(err) 86 assert.NotNil(mountMeta) 87 assert.Equal(Version2, mountMeta.Data.Options["version"]) 88 } 89 90 func TestVaultClientJSONBody(t *testing.T) { 91 assert := assert.New(t) 92 93 client, err := New() 94 assert.Nil(err) 95 96 output, err := client.jsonBody(map[string]interface{}{ 97 "foo": "bar", 98 }) 99 assert.Nil(err) 100 defer output.Close() 101 102 contents, err := io.ReadAll(output) 103 assert.Nil(err) 104 assert.Equal("{\"foo\":\"bar\"}\n", string(contents)) 105 } 106 107 func TestVaultClientReadJSON(t *testing.T) { 108 assert := assert.New(t) 109 110 client, err := New() 111 assert.Nil(err) 112 113 jsonBody := bytes.NewBuffer([]byte(`{"foo":"bar"}`)) 114 115 output := map[string]interface{}{} 116 assert.Nil(client.readJSON(jsonBody, &output)) 117 assert.Equal("bar", output["foo"]) 118 } 119 120 func TestVaultClientCopyRemote(t *testing.T) { 121 assert := assert.New(t) 122 123 client, err := New() 124 assert.Nil(err) 125 126 copy := client.copyRemote() 127 copy.Host = "not_" + copy.Host 128 129 anotherCopy := client.copyRemote() 130 assert.NotEqual(anotherCopy.Host, copy.Host) 131 } 132 133 func TestVaultClientDiscard(t *testing.T) { 134 assert := assert.New(t) 135 136 client, err := New() 137 assert.Nil(err) 138 139 assert.NotNil(client.discard(nil, fmt.Errorf("this is only a test"))) 140 141 assert.Nil(client.discard(client.jsonBody(map[string]interface{}{ 142 "foo": "bar", 143 }))) 144 } 145 146 func TestVaultCreateTransitKey(t *testing.T) { 147 assert := assert.New(t) 148 todo := context.TODO() 149 150 client, err := New() 151 assert.Nil(err) 152 153 key := "key" 154 155 m := NewMockHTTPClient(). 156 With( 157 "POST", 158 mustURLf("%s/v1/transit/keys/%s", client.Remote.String(), key), 159 &http.Response{ 160 StatusCode: http.StatusNoContent, 161 Body: io.NopCloser(bytes.NewBuffer([]byte{})), 162 }, 163 ) 164 client.Client = m 165 166 err = client.CreateTransitKey(todo, "key") 167 assert.Nil(err) 168 } 169 170 func TestVaultConfigureTransitKey(t *testing.T) { 171 assert := assert.New(t) 172 todo := context.TODO() 173 174 client, err := New() 175 assert.Nil(err) 176 177 key := "key" 178 179 m := NewMockHTTPClient(). 180 With( 181 "POST", 182 mustURLf("%s/v1/transit/keys/%s/config", client.Remote.String(), key), 183 &http.Response{ 184 StatusCode: http.StatusNoContent, 185 Body: io.NopCloser(bytes.NewBuffer([]byte{})), 186 }, 187 ) 188 client.Client = m 189 190 err = client.ConfigureTransitKey(todo, "key", OptUpdateTransitDeletionAllowed(true)) 191 assert.Nil(err) 192 } 193 194 func TestVaultReadTransitKey(t *testing.T) { 195 assert := assert.New(t) 196 todo := context.TODO() 197 198 client, err := New() 199 assert.Nil(err) 200 201 key := "key" 202 keyMetaJSON := `{"request_id":"e114c628-6493-28ed-0975-418a75c7976f","lease_id":"","renewable":false,"lease_duration":0,"data":{"deletion_allowed":true,"exportable":false,"allow_plaintext_backup":false,"keys": {"1": 1442851412},"min_decryption_version": 1,"min_encryption_version": 0,"name": "foo"},"wrap_info":null,"warnings":null,"auth":null}` 203 204 m := NewMockHTTPClient().WithString("GET", mustURLf("%s/v1/transit/keys/%s", client.Remote.String(), key), keyMetaJSON) 205 client.Client = m 206 207 data, err := client.ReadTransitKey(todo, "key") 208 assert.Nil(err) 209 assert.Equal(true, data["deletion_allowed"]) 210 } 211 212 func TestVaultDeleteTransitKey(t *testing.T) { 213 assert := assert.New(t) 214 todo := context.TODO() 215 216 client, err := New() 217 assert.Nil(err) 218 219 key := "key" 220 221 m := NewMockHTTPClient(). 222 With( 223 "DELETE", 224 mustURLf("%s/v1/transit/keys/%s", client.Remote.String(), key), 225 &http.Response{ 226 StatusCode: http.StatusNoContent, 227 Body: io.NopCloser(bytes.NewBuffer([]byte{})), 228 }, 229 ) 230 client.Client = m 231 232 err = client.DeleteTransitKey(todo, "key") 233 assert.Nil(err) 234 } 235 236 func TestVaultHandleRedirects(t *testing.T) { 237 assert := assert.New(t) 238 239 rawResponse := "{\"status\":\"ok!\"}\n" 240 241 inner := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 242 w.Header().Set(webutil.HeaderContentType, webutil.ContentTypeApplicationJSON) 243 w.WriteHeader(http.StatusOK) 244 fmt.Fprint(w, rawResponse) 245 })) 246 defer inner.Close() 247 outer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 248 http.Redirect(w, r, inner.URL, http.StatusTemporaryRedirect) 249 })) 250 defer outer.Close() 251 252 client, err := New( 253 OptRemote(outer.URL), 254 ) 255 assert.Nil(err) 256 assert.NotNil(client) 257 258 rawURL, err := url.Parse(outer.URL) 259 assert.Nil(err) 260 res, err := client.Client.Do(&http.Request{URL: rawURL}) 261 assert.Nil(err) 262 defer res.Body.Close() 263 assert.Equal(http.StatusOK, res.StatusCode) 264 265 contents, err := io.ReadAll(res.Body) 266 assert.Nil(err) 267 assert.Equal(rawResponse, string(contents)) 268 } 269 270 func TestVaultBatchEncryptDecrypt_Happy(t *testing.T) { 271 assert := assert.New(t) 272 todo := context.Background() 273 274 client, err := New() 275 assert.Nil(err) 276 277 key := "key" 278 279 plaintext1 := []byte("this is plaintext") 280 plaintext2 := []byte("this is plaintext2") 281 batchInput := BatchTransitInput{ 282 BatchTransitInputItems: []BatchTransitInputItem{ 283 { 284 Context: nil, 285 Plaintext: plaintext1, 286 }, 287 { 288 Context: nil, 289 Plaintext: plaintext2, 290 }, 291 }, 292 } 293 294 batchDecryptResultBytes := []byte(fmt.Sprintf(` 295 { 296 "data": { 297 "batch_results": [ 298 { 299 "plaintext": "%s", 300 "key_version": 1 301 }, 302 { 303 "plaintext": "%s", 304 "key_version": 1 305 } 306 ] 307 } 308 } 309 `, base64.StdEncoding.EncodeToString(plaintext1), base64.StdEncoding.EncodeToString(plaintext2))) 310 311 batchEncryptResultBytes := []byte(fmt.Sprintf(` 312 { 313 "data": { 314 "batch_results": [ 315 { 316 "ciphertext": "vault:%s", 317 "key_version": 1 318 }, 319 { 320 "ciphertext": "vault:%s", 321 "key_version": 1 322 } 323 ] 324 } 325 } 326 `, plaintext1, plaintext2)) 327 m := NewMockHTTPClient(). 328 With( 329 "POST", 330 mustURLf("%s/v1/transit/encrypt/%s", client.Remote.String(), key), 331 &http.Response{ 332 StatusCode: http.StatusOK, 333 Body: io.NopCloser(bytes.NewBuffer(batchEncryptResultBytes)), 334 }, 335 ).With( 336 "POST", 337 mustURLf("%s/v1/transit/decrypt/%s", client.Remote.String(), key), 338 &http.Response{ 339 StatusCode: http.StatusOK, 340 Body: io.NopCloser(bytes.NewBuffer(batchDecryptResultBytes)), 341 }, 342 ) 343 client.Client = m 344 345 ciphertextResults, err := client.BatchEncrypt(todo, "key", batchInput) 346 assert.Nil(err) 347 assert.Equal(fmt.Sprintf("vault:%s", plaintext1), ciphertextResults[0]) 348 assert.Equal(fmt.Sprintf("vault:%s", plaintext2), ciphertextResults[1]) 349 350 plaintextResults, err := client.BatchDecrypt(todo, "key", batchInput) 351 assert.Nil(err) 352 assert.Equal(plaintext1, plaintextResults[0]) 353 assert.Equal(plaintext2, plaintextResults[1]) 354 } 355 356 func TestVaultBatchEncryptDecrypt_EmptyInput(t *testing.T) { 357 assert := assert.New(t) 358 todo := context.Background() 359 360 client, err := New() 361 assert.Nil(err) 362 363 key := "key" 364 365 batchInput := BatchTransitInput{ 366 BatchTransitInputItems: []BatchTransitInputItem{}, 367 } 368 369 errorResultBytes := []byte(` 370 { 371 "data": { 372 "error": "missing batch input to process" 373 } 374 } 375 `) 376 m := NewMockHTTPClient(). 377 With( 378 "POST", 379 mustURLf("%s/v1/transit/encrypt/%s", client.Remote.String(), key), 380 &http.Response{ 381 StatusCode: http.StatusBadRequest, 382 Body: io.NopCloser(bytes.NewBuffer(errorResultBytes)), 383 }, 384 ).With( 385 "POST", 386 mustURLf("%s/v1/transit/decrypt/%s", client.Remote.String(), key), 387 &http.Response{ 388 StatusCode: http.StatusBadRequest, 389 Body: io.NopCloser(bytes.NewBuffer(errorResultBytes)), 390 }, 391 ) 392 client.Client = m 393 394 ciphertextResults, err := client.BatchEncrypt(todo, "key", batchInput) 395 assert.Nil(err) 396 assert.Empty(ciphertextResults) 397 398 plaintextResults, err := client.BatchDecrypt(todo, "key", batchInput) 399 assert.Nil(err) 400 assert.Empty(plaintextResults) 401 } 402 403 func TestVaultBatchEncrypt_Error(t *testing.T) { 404 assert := assert.New(t) 405 todo := context.TODO() 406 407 client, err := New() 408 assert.Nil(err) 409 410 key := "key" 411 412 plaintext1 := []byte("this is plaintext") 413 plaintext2 := []byte("this is plaintext2") 414 batchInput := BatchTransitInput{ 415 BatchTransitInputItems: []BatchTransitInputItem{ 416 { 417 Context: nil, 418 Plaintext: plaintext1, 419 }, 420 { 421 Context: nil, 422 Plaintext: plaintext2, 423 }, 424 }, 425 } 426 427 batchEncryptResultBytes := []byte(fmt.Sprintf(` 428 { 429 "data": { 430 "batch_results": [ 431 { 432 "ciphertext": "vault:%s", 433 "key_version": 1 434 }, 435 { 436 "error": "encryption error", 437 "ciphertext": "vault:%s", 438 "key_version": 1 439 } 440 ] 441 } 442 } 443 `, plaintext1, plaintext2)) 444 m := NewMockHTTPClient(). 445 With( 446 "POST", 447 mustURLf("%s/v1/transit/encrypt/%s", client.Remote.String(), key), 448 &http.Response{ 449 StatusCode: http.StatusOK, 450 Body: io.NopCloser(bytes.NewBuffer(batchEncryptResultBytes)), 451 }, 452 ) 453 client.Client = m 454 455 ciphertextResults, err := client.BatchEncrypt(todo, "key", batchInput) 456 assert.NotNil(err) 457 assert.Equal(ErrBatchTransitEncryptError, err.Error()) 458 assert.Nil(ciphertextResults) 459 } 460 461 func TestVaultBatchDecrypt_Error(t *testing.T) { 462 assert := assert.New(t) 463 todo := context.TODO() 464 465 client, err := New() 466 assert.Nil(err) 467 468 key := "key" 469 470 plaintext1 := []byte("this is plaintext") 471 plaintext2 := []byte("this is plaintext2") 472 batchInput := BatchTransitInput{ 473 BatchTransitInputItems: []BatchTransitInputItem{ 474 { 475 Context: nil, 476 Plaintext: plaintext1, 477 }, 478 { 479 Context: nil, 480 Plaintext: plaintext2, 481 }, 482 }, 483 } 484 485 batchDecryptResultBytes := []byte(fmt.Sprintf(` 486 { 487 "data": { 488 "batch_results": [ 489 { 490 "error": "error", 491 "plaintext": "%s", 492 "key_version": 1 493 }, 494 { 495 "plaintext": "%s", 496 "key_version": 1 497 } 498 ] 499 } 500 } 501 `, base64.StdEncoding.EncodeToString(plaintext1), base64.StdEncoding.EncodeToString(plaintext2))) 502 503 m := NewMockHTTPClient(). 504 With( 505 "POST", 506 mustURLf("%s/v1/transit/decrypt/%s", client.Remote.String(), key), 507 &http.Response{ 508 StatusCode: http.StatusOK, 509 Body: io.NopCloser(bytes.NewBuffer(batchDecryptResultBytes)), 510 }, 511 ) 512 client.Client = m 513 514 plaintextResults, err := client.BatchDecrypt(todo, "key", batchInput) 515 assert.NotNil(err) 516 assert.Equal(ErrBatchTransitDecryptError, err.Error()) 517 assert.Nil(plaintextResults) 518 } 519 520 func TestVaultHmac(t *testing.T) { 521 assert := assert.New(t) 522 todo := context.TODO() 523 524 client, err := New() 525 assert.Nil(err) 526 527 key := "key" 528 input := []byte("hmac!") 529 result := fmt.Sprintf(`{"data": {"hmac": "%s"}}`, input) 530 531 m := NewMockHTTPClient(). 532 With( 533 "POST", 534 mustURLf("%s/v1/transit/hmac/%s/sha2-256", client.Remote.String(), key), 535 &http.Response{ 536 StatusCode: http.StatusOK, 537 Body: io.NopCloser(bytes.NewBuffer([]byte(result))), 538 }, 539 ) 540 client.Client = m 541 542 res, err := client.TransitHMAC(todo, "key", input) 543 assert.Nil(err) 544 assert.Equal(input, res) 545 } 546 547 func TestVaultHmacError(t *testing.T) { 548 assert := assert.New(t) 549 todo := context.TODO() 550 551 client, err := New() 552 assert.Nil(err) 553 554 key := "key" 555 input := []byte("hmac!") 556 result := `bad payload` 557 558 m := NewMockHTTPClient(). 559 With( 560 "POST", 561 mustURLf("%s/v1/transit/hmac/%s/sha2-256", client.Remote.String(), key), 562 &http.Response{ 563 StatusCode: http.StatusOK, 564 Body: io.NopCloser(bytes.NewBuffer([]byte(result))), 565 }, 566 ) 567 client.Client = m 568 569 res, err := client.TransitHMAC(todo, "key", input) 570 assert.NotNil(err) 571 assert.Nil(res) 572 }