github.com/google/osv-scalibr@v0.4.1/veles/secrets/stripeapikeys/validator_test.go (about)

     1  // Copyright 2025 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Copyright 2025 Google LLC
    16  //
    17  // Licensed under the Apache License, Version 2.0 (the "License");
    18  // you may not use this file except in compliance with the License.
    19  // You may obtain a copy of the License at
    20  //
    21  // http://www.apache.org/licenses/LICENSE-2.0
    22  //
    23  // Unless required by applicable law or agreed to in writing, software
    24  // distributed under the License is distributed on an "AS IS" BASIS,
    25  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    26  // See the License for the specific language governing permissions and
    27  // limitations under the License.
    28  
    29  package stripeapikeys_test
    30  
    31  import (
    32  	"context"
    33  	"net/http"
    34  	"net/http/httptest"
    35  	"net/url"
    36  	"testing"
    37  
    38  	"github.com/google/go-cmp/cmp"
    39  	"github.com/google/go-cmp/cmp/cmpopts"
    40  	"github.com/google/osv-scalibr/veles"
    41  	stripeapikeys "github.com/google/osv-scalibr/veles/secrets/stripeapikeys"
    42  )
    43  
    44  const (
    45  	validatorTestSK = "sk_live_51PvZzqABcD1234EfGhIjKlMnOpQrStUvWxYz0123456789abcdefghijklmnopQRSTuvWXYZabcd12345678"
    46  	validatorTestRK = "rk_live_51PvZzABcDEfGhIjKlMnOpQrStUvWxYz0123456789abcdefGHIJKLMNOPQRSTUVWXYZabcd12345678"
    47  )
    48  
    49  // mockTransport redirects requests to the test server for the configured hosts.
    50  type mockTransport struct {
    51  	testServer *httptest.Server
    52  }
    53  
    54  func (m *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) {
    55  	// Replace the original URL with our test server URL for Stripe API hosts.
    56  	if req.URL.Host == "api.stripe.com" {
    57  		testURL, _ := url.Parse(m.testServer.URL)
    58  		req.URL.Scheme = testURL.Scheme
    59  		req.URL.Host = testURL.Host
    60  	}
    61  	return http.DefaultTransport.RoundTrip(req)
    62  }
    63  
    64  // mockStripeAPIServer creates a mock Stripe /v1/accounts endpoint for testing validators.
    65  func mockStripeAPIServer(t *testing.T, expectedKey string, statusCode int) *httptest.Server {
    66  	t.Helper()
    67  
    68  	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    69  		// Expect a GET to /v1/accounts
    70  		if r.Method != http.MethodGet || r.URL.Path != "/v1/accounts" {
    71  			t.Errorf("unexpected request: %s %s, expected: GET /v1/accounts", r.Method, r.URL.Path)
    72  			http.Error(w, "not found", http.StatusNotFound)
    73  			return
    74  		}
    75  
    76  		// Check Basic Auth header contains the expected key
    77  		username, _, ok := r.BasicAuth()
    78  		if !ok || username != expectedKey {
    79  			t.Errorf("expected Basic Auth username to be %s, got: %s", expectedKey, username)
    80  		}
    81  
    82  		w.WriteHeader(statusCode)
    83  	}))
    84  }
    85  
    86  func TestValidatorSecretKey(t *testing.T) {
    87  	cases := []struct {
    88  		name       string
    89  		statusCode int
    90  		want       veles.ValidationStatus
    91  		wantErr    error
    92  	}{
    93  		{
    94  			name:       "valid_key",
    95  			statusCode: http.StatusOK,
    96  			want:       veles.ValidationValid,
    97  		},
    98  		{
    99  			name:       "invalid_key_unauthorized",
   100  			statusCode: http.StatusUnauthorized,
   101  			want:       veles.ValidationInvalid,
   102  		},
   103  		{
   104  			name:       "server_error",
   105  			statusCode: http.StatusInternalServerError,
   106  			want:       veles.ValidationInvalid,
   107  		},
   108  		{
   109  			name:       "forbidden_error",
   110  			statusCode: http.StatusForbidden,
   111  			want:       veles.ValidationInvalid,
   112  		},
   113  	}
   114  
   115  	for _, tc := range cases {
   116  		t.Run(tc.name, func(t *testing.T) {
   117  			// Create mock server
   118  			server := mockStripeAPIServer(t, validatorTestSK, tc.statusCode)
   119  			defer server.Close()
   120  
   121  			// Create client with custom transport
   122  			client := &http.Client{
   123  				Transport: &mockTransport{testServer: server},
   124  			}
   125  
   126  			// Create validator with mock client
   127  			validator := stripeapikeys.NewSecretKeyValidator()
   128  			validator.HTTPC = client
   129  
   130  			// Create test key
   131  			key := stripeapikeys.StripeSecretKey{Key: validatorTestSK}
   132  
   133  			// Test validation
   134  			got, err := validator.Validate(t.Context(), key)
   135  
   136  			if diff := cmp.Diff(tc.wantErr, err, cmpopts.EquateErrors()); diff != "" {
   137  				t.Errorf("Validate() error mismatch (-want +got):\n%s", diff)
   138  			}
   139  
   140  			// Check validation status
   141  			if got != tc.want {
   142  				t.Errorf("Validate() = %v, want %v", got, tc.want)
   143  			}
   144  		})
   145  	}
   146  }
   147  
   148  func TestValidatorSecretKey_ContextCancellation(t *testing.T) {
   149  	server := httptest.NewServer(nil)
   150  	t.Cleanup(func() {
   151  		server.Close()
   152  	})
   153  
   154  	// Create client with custom transport
   155  	client := &http.Client{
   156  		Transport: &mockTransport{testServer: server},
   157  	}
   158  
   159  	validator := stripeapikeys.NewSecretKeyValidator()
   160  	validator.HTTPC = client
   161  
   162  	key := stripeapikeys.StripeSecretKey{Key: validatorTestSK}
   163  
   164  	// Create context that is immediately cancelled
   165  	ctx, cancel := context.WithCancel(t.Context())
   166  	cancel()
   167  
   168  	// Test validation with cancelled context
   169  	got, err := validator.Validate(ctx, key)
   170  
   171  	if diff := cmp.Diff(cmpopts.AnyError, err, cmpopts.EquateErrors()); diff != "" {
   172  		t.Errorf("Validate() error mismatch (-want +got):\n%s", diff)
   173  	}
   174  	if got != veles.ValidationFailed {
   175  		t.Errorf("Validate() = %v, want %v", got, veles.ValidationFailed)
   176  	}
   177  }
   178  
   179  func TestValidatorRestrictedKey(t *testing.T) {
   180  	cases := []struct {
   181  		name       string
   182  		statusCode int
   183  		want       veles.ValidationStatus
   184  		wantErr    error
   185  	}{
   186  		{
   187  			name:       "valid_key_ok",
   188  			statusCode: http.StatusOK,
   189  			want:       veles.ValidationValid,
   190  		},
   191  		{
   192  			name:       "valid_key_forbidden",
   193  			statusCode: http.StatusForbidden,
   194  			want:       veles.ValidationValid,
   195  		},
   196  		{
   197  			name:       "invalid_key_unauthorized",
   198  			statusCode: http.StatusUnauthorized,
   199  			want:       veles.ValidationInvalid,
   200  		},
   201  		{
   202  			name:       "server_error",
   203  			statusCode: http.StatusInternalServerError,
   204  			want:       veles.ValidationInvalid,
   205  		},
   206  	}
   207  
   208  	for _, tc := range cases {
   209  		t.Run(tc.name, func(t *testing.T) {
   210  			// Create mock server
   211  			server := mockStripeAPIServer(t, validatorTestRK, tc.statusCode)
   212  			defer server.Close()
   213  
   214  			// Create client with custom transport
   215  			client := &http.Client{
   216  				Transport: &mockTransport{testServer: server},
   217  			}
   218  
   219  			// Create validator with mock client
   220  			validator := stripeapikeys.NewRestrictedKeyValidator()
   221  			validator.HTTPC = client
   222  
   223  			// Create test key
   224  			key := stripeapikeys.StripeRestrictedKey{Key: validatorTestRK}
   225  
   226  			// Test validation
   227  			got, err := validator.Validate(t.Context(), key)
   228  
   229  			if diff := cmp.Diff(tc.wantErr, err, cmpopts.EquateErrors()); diff != "" {
   230  				t.Errorf("Validate() error mismatch (-want +got):\n%s", diff)
   231  			}
   232  
   233  			// Check validation status
   234  			if got != tc.want {
   235  				t.Errorf("Validate() = %v, want %v", got, tc.want)
   236  			}
   237  		})
   238  	}
   239  }
   240  
   241  func TestValidatorRestrictedKey_ContextCancellation(t *testing.T) {
   242  	server := httptest.NewServer(nil)
   243  	t.Cleanup(func() {
   244  		server.Close()
   245  	})
   246  
   247  	// Create client with custom transport
   248  	client := &http.Client{
   249  		Transport: &mockTransport{testServer: server},
   250  	}
   251  
   252  	validator := stripeapikeys.NewRestrictedKeyValidator()
   253  	validator.HTTPC = client
   254  
   255  	key := stripeapikeys.StripeRestrictedKey{Key: validatorTestRK}
   256  
   257  	// Create context that is immediately cancelled
   258  	ctx, cancel := context.WithCancel(t.Context())
   259  	cancel()
   260  
   261  	// Test validation with cancelled context
   262  	got, err := validator.Validate(ctx, key)
   263  
   264  	if diff := cmp.Diff(cmpopts.AnyError, err, cmpopts.EquateErrors()); diff != "" {
   265  		t.Errorf("Validate() error mismatch (-want +got):\n%s", diff)
   266  	}
   267  	if got != veles.ValidationFailed {
   268  		t.Errorf("Validate() = %v, want %v", got, veles.ValidationFailed)
   269  	}
   270  }