github.com/hashicorp/vault/sdk@v0.13.0/logical/response_util_test.go (about) 1 // Copyright (c) HashiCorp, Inc. 2 // SPDX-License-Identifier: MPL-2.0 3 4 package logical 5 6 import ( 7 "errors" 8 "strings" 9 "testing" 10 11 "github.com/hashicorp/vault/sdk/helper/consts" 12 ) 13 14 func TestResponseUtil_RespondErrorCommon_basic(t *testing.T) { 15 testCases := []struct { 16 title string 17 req *Request 18 resp *Response 19 respErr error 20 expectedStatus int 21 expectedErr error 22 }{ 23 { 24 title: "Throttled, no error", 25 respErr: ErrUpstreamRateLimited, 26 resp: &Response{}, 27 expectedStatus: 502, 28 }, 29 { 30 title: "Throttled, with error", 31 respErr: ErrUpstreamRateLimited, 32 resp: &Response{ 33 Data: map[string]interface{}{ 34 "error": "rate limited", 35 }, 36 }, 37 expectedStatus: 502, 38 }, 39 { 40 title: "Read not found", 41 req: &Request{ 42 Operation: ReadOperation, 43 }, 44 respErr: nil, 45 expectedStatus: 404, 46 }, 47 { 48 title: "Header not found", 49 req: &Request{ 50 Operation: HeaderOperation, 51 }, 52 respErr: nil, 53 expectedStatus: 404, 54 }, 55 { 56 title: "List with response and no keys", 57 req: &Request{ 58 Operation: ListOperation, 59 }, 60 resp: &Response{}, 61 respErr: nil, 62 expectedStatus: 404, 63 }, 64 { 65 title: "List with response and keys", 66 req: &Request{ 67 Operation: ListOperation, 68 }, 69 resp: &Response{ 70 Data: map[string]interface{}{ 71 "keys": []string{"some", "things", "here"}, 72 }, 73 }, 74 respErr: nil, 75 expectedStatus: 0, 76 }, 77 { 78 title: "Invalid Credentials error ", 79 respErr: ErrInvalidCredentials, 80 resp: &Response{ 81 Data: map[string]interface{}{ 82 "error": "error due to wrong credentials", 83 }, 84 }, 85 expectedErr: errors.New("error due to wrong credentials"), 86 expectedStatus: 400, 87 }, 88 { 89 title: "Overloaded error", 90 respErr: consts.ErrOverloaded, 91 resp: &Response{ 92 Data: map[string]interface{}{ 93 "error": "overloaded, try again later", 94 }, 95 }, 96 expectedErr: consts.ErrOverloaded, 97 expectedStatus: 503, 98 }, 99 } 100 101 for _, tc := range testCases { 102 t.Run(tc.title, func(t *testing.T) { 103 var status int 104 var err, respErr error 105 if tc.respErr != nil { 106 respErr = tc.respErr 107 } 108 status, err = RespondErrorCommon(tc.req, tc.resp, respErr) 109 if status != tc.expectedStatus { 110 t.Fatalf("Expected (%d) status code, got (%d)", tc.expectedStatus, status) 111 } 112 if tc.expectedErr != nil { 113 if !strings.Contains(tc.expectedErr.Error(), err.Error()) { 114 t.Fatalf("Expected error to contain:\n%s\n\ngot:\n%s\n", tc.expectedErr, err) 115 } 116 } 117 }) 118 } 119 }