github.com/google/osv-scalibr@v0.4.1/veles/secrets/awsaccesskey/validator_test.go (about)

     1  // Copyright 2025 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package awsaccesskey_test
    16  
    17  import (
    18  	"bytes"
    19  	"io"
    20  	"net/http"
    21  	"net/http/httptest"
    22  	"net/url"
    23  	"strings"
    24  	"testing"
    25  
    26  	"github.com/google/osv-scalibr/veles"
    27  	"github.com/google/osv-scalibr/veles/secrets/awsaccesskey"
    28  )
    29  
    30  type fakeSigner struct{}
    31  
    32  func (n fakeSigner) Sign(r *http.Request, accessID string, secret string) error {
    33  	r.Header.Set("Authorization", "Signature="+accessID+":"+secret)
    34  	return nil
    35  }
    36  
    37  type mockRoundTripper struct {
    38  	url string
    39  }
    40  
    41  func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
    42  	if req.URL.Host == "sts.us-east-1.amazonaws.com" {
    43  		testURL, _ := url.Parse(m.url)
    44  		req.URL.Scheme = testURL.Scheme
    45  		req.URL.Host = testURL.Host
    46  	}
    47  	return http.DefaultTransport.RoundTrip(req)
    48  }
    49  
    50  // mockSTSServer returns an httptest.Server that simulates the AWS STS server
    51  func mockSTSServer(signature string) func() *httptest.Server {
    52  	return func() *httptest.Server {
    53  		handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
    54  			// Only handle GetCallerIdentity (POST /)
    55  			if req.Method != http.MethodPost || req.URL.Path != "/" || req.Body == nil {
    56  				http.Error(w, "not found", http.StatusNotFound)
    57  				return
    58  			}
    59  			body, _ := io.ReadAll(req.Body)
    60  			if !bytes.HasPrefix(body, []byte("Action=GetCallerIdentity")) {
    61  				http.Error(w, "bad method", http.StatusNotFound)
    62  			}
    63  
    64  			if !strings.Contains(req.Header.Get("Authorization"), signature) {
    65  				w.WriteHeader(http.StatusForbidden)
    66  				_, _ = io.WriteString(w, `<ErrorResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
    67  				  <Error>
    68  				    <Type>Sender</Type>
    69  				    <Code>SignatureDoesNotMatch</Code>
    70  				    <Message>The request signature we calculated does not match the signature you provided. Check your AWS Secret Access Key and signing method. Consult the service documentation for details.</Message>
    71  				  </Error>
    72  				  <RequestId>f7a4e6b1-9d2c-4f80-8a7e-3c5d9f1b0e2b</RequestId>
    73  				</ErrorResponse>`)
    74  				return
    75  			}
    76  
    77  			_, _ = io.WriteString(w, `<GetCallerIdentityResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
    78  			  <GetCallerIdentityResult>
    79  			    <Arn>***</Arn>
    80  			    <UserId>***</UserId>
    81  			    <Account>***</Account>
    82  			  </GetCallerIdentityResult>
    83  			  <ResponseMetadata>
    84  			    <RequestId>f7a4e6b1-9d2c-4f80-8a7e-3c5d9f1b0e2a</RequestId>
    85  			  </ResponseMetadata>
    86  			</GetCallerIdentityResponse>`)
    87  			w.WriteHeader(http.StatusOK)
    88  		})
    89  
    90  		return httptest.NewServer(handler)
    91  	}
    92  }
    93  
    94  const (
    95  	exampleAccessID  = "AIKAerkjf4f034"
    96  	correctSecret    = "testsecret"
    97  	badSecret        = "badSecret"
    98  	correctSignature = exampleAccessID + ":" + correctSecret
    99  )
   100  
   101  func TestValidator(t *testing.T) {
   102  	cases := []struct {
   103  		name   string
   104  		key    awsaccesskey.Credentials
   105  		want   veles.ValidationStatus
   106  		server func() *httptest.Server
   107  	}{
   108  		{
   109  			name: "correct_secret",
   110  			key: awsaccesskey.Credentials{
   111  				AccessID: exampleAccessID,
   112  				Secret:   correctSecret,
   113  			},
   114  			want:   veles.ValidationValid,
   115  			server: mockSTSServer(correctSignature),
   116  		},
   117  		{
   118  			name: "bad_secret",
   119  			key: awsaccesskey.Credentials{
   120  				AccessID: exampleAccessID,
   121  				Secret:   badSecret,
   122  			},
   123  			want:   veles.ValidationInvalid,
   124  			server: mockSTSServer(correctSignature),
   125  		},
   126  	}
   127  
   128  	for _, tc := range cases {
   129  		t.Run(tc.name, func(t *testing.T) {
   130  			srv := tc.server()
   131  			client := &http.Client{
   132  				Transport: &mockRoundTripper{url: srv.URL},
   133  			}
   134  
   135  			validator := awsaccesskey.NewValidator()
   136  			validator.SetHTTPClient(client)
   137  			validator.SetSigner(fakeSigner{})
   138  
   139  			got, err := validator.Validate(t.Context(), tc.key)
   140  			if err != nil {
   141  				t.Errorf("Validate() error: %v, want nil", err)
   142  			}
   143  			if got != tc.want {
   144  				t.Errorf("Validate() = %q, want %q", got, tc.want)
   145  			}
   146  		})
   147  	}
   148  }