github.com/google/osv-scalibr@v0.4.1/veles/secrets/awsaccesskey/validator_test.go (about) 1 // Copyright 2025 Google LLC 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package awsaccesskey_test 16 17 import ( 18 "bytes" 19 "io" 20 "net/http" 21 "net/http/httptest" 22 "net/url" 23 "strings" 24 "testing" 25 26 "github.com/google/osv-scalibr/veles" 27 "github.com/google/osv-scalibr/veles/secrets/awsaccesskey" 28 ) 29 30 type fakeSigner struct{} 31 32 func (n fakeSigner) Sign(r *http.Request, accessID string, secret string) error { 33 r.Header.Set("Authorization", "Signature="+accessID+":"+secret) 34 return nil 35 } 36 37 type mockRoundTripper struct { 38 url string 39 } 40 41 func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { 42 if req.URL.Host == "sts.us-east-1.amazonaws.com" { 43 testURL, _ := url.Parse(m.url) 44 req.URL.Scheme = testURL.Scheme 45 req.URL.Host = testURL.Host 46 } 47 return http.DefaultTransport.RoundTrip(req) 48 } 49 50 // mockSTSServer returns an httptest.Server that simulates the AWS STS server 51 func mockSTSServer(signature string) func() *httptest.Server { 52 return func() *httptest.Server { 53 handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 54 // Only handle GetCallerIdentity (POST /) 55 if req.Method != http.MethodPost || req.URL.Path != "/" || req.Body == nil { 56 http.Error(w, "not found", http.StatusNotFound) 57 return 58 } 59 body, _ := io.ReadAll(req.Body) 60 if !bytes.HasPrefix(body, []byte("Action=GetCallerIdentity")) { 61 http.Error(w, "bad method", http.StatusNotFound) 62 } 63 64 if !strings.Contains(req.Header.Get("Authorization"), signature) { 65 w.WriteHeader(http.StatusForbidden) 66 _, _ = io.WriteString(w, `<ErrorResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/"> 67 <Error> 68 <Type>Sender</Type> 69 <Code>SignatureDoesNotMatch</Code> 70 <Message>The request signature we calculated does not match the signature you provided. Check your AWS Secret Access Key and signing method. Consult the service documentation for details.</Message> 71 </Error> 72 <RequestId>f7a4e6b1-9d2c-4f80-8a7e-3c5d9f1b0e2b</RequestId> 73 </ErrorResponse>`) 74 return 75 } 76 77 _, _ = io.WriteString(w, `<GetCallerIdentityResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/"> 78 <GetCallerIdentityResult> 79 <Arn>***</Arn> 80 <UserId>***</UserId> 81 <Account>***</Account> 82 </GetCallerIdentityResult> 83 <ResponseMetadata> 84 <RequestId>f7a4e6b1-9d2c-4f80-8a7e-3c5d9f1b0e2a</RequestId> 85 </ResponseMetadata> 86 </GetCallerIdentityResponse>`) 87 w.WriteHeader(http.StatusOK) 88 }) 89 90 return httptest.NewServer(handler) 91 } 92 } 93 94 const ( 95 exampleAccessID = "AIKAerkjf4f034" 96 correctSecret = "testsecret" 97 badSecret = "badSecret" 98 correctSignature = exampleAccessID + ":" + correctSecret 99 ) 100 101 func TestValidator(t *testing.T) { 102 cases := []struct { 103 name string 104 key awsaccesskey.Credentials 105 want veles.ValidationStatus 106 server func() *httptest.Server 107 }{ 108 { 109 name: "correct_secret", 110 key: awsaccesskey.Credentials{ 111 AccessID: exampleAccessID, 112 Secret: correctSecret, 113 }, 114 want: veles.ValidationValid, 115 server: mockSTSServer(correctSignature), 116 }, 117 { 118 name: "bad_secret", 119 key: awsaccesskey.Credentials{ 120 AccessID: exampleAccessID, 121 Secret: badSecret, 122 }, 123 want: veles.ValidationInvalid, 124 server: mockSTSServer(correctSignature), 125 }, 126 } 127 128 for _, tc := range cases { 129 t.Run(tc.name, func(t *testing.T) { 130 srv := tc.server() 131 client := &http.Client{ 132 Transport: &mockRoundTripper{url: srv.URL}, 133 } 134 135 validator := awsaccesskey.NewValidator() 136 validator.SetHTTPClient(client) 137 validator.SetSigner(fakeSigner{}) 138 139 got, err := validator.Validate(t.Context(), tc.key) 140 if err != nil { 141 t.Errorf("Validate() error: %v, want nil", err) 142 } 143 if got != tc.want { 144 t.Errorf("Validate() = %q, want %q", got, tc.want) 145 } 146 }) 147 } 148 }